Skip to content

Commit b09793a

Browse files
feature: client cache for jumpstart models
Co-authored-by: Mufaddal Rohawala <[email protected]>
1 parent 2beb91e commit b09793a

File tree

16 files changed

+2115
-1
lines changed

16 files changed

+2115
-1
lines changed

.coveragerc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[run]
2-
concurrency = threading
2+
concurrency = thread
33
omit = sagemaker/tests/*
44
timid = True

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def read_version():
4444
"packaging>=20.0",
4545
"pandas",
4646
"pathos",
47+
"semantic-version",
4748
]
4849

4950
# Specific use case dependencies

src/sagemaker/jumpstart/__init__.py

Whitespace-only changes.

src/sagemaker/jumpstart/cache.py

+327
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module defines the JumpStartModelsCache class."""
14+
from __future__ import absolute_import
15+
import datetime
16+
from typing import List, Optional
17+
import json
18+
import boto3
19+
import botocore
20+
import semantic_version
21+
from sagemaker.jumpstart.constants import (
22+
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
23+
JUMPSTART_DEFAULT_REGION_NAME,
24+
)
25+
from sagemaker.jumpstart.parameters import (
26+
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
27+
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
28+
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
29+
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
30+
)
31+
from sagemaker.jumpstart.types import (
32+
JumpStartCachedS3ContentKey,
33+
JumpStartCachedS3ContentValue,
34+
JumpStartModelHeader,
35+
JumpStartModelSpecs,
36+
JumpStartS3FileType,
37+
JumpStartVersionedModelId,
38+
)
39+
from sagemaker.jumpstart import utils
40+
from sagemaker.utilities.cache import LRUCache
41+
42+
43+
class JumpStartModelsCache:
44+
"""Class that implements a cache for JumpStart models manifests and specs.
45+
46+
The manifest and specs associated with JumpStart models provide the information necessary
47+
for launching JumpStart models from the SageMaker SDK.
48+
"""
49+
50+
def __init__(
51+
self,
52+
region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
53+
max_s3_cache_items: Optional[int] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
54+
s3_cache_expiration_horizon: Optional[
55+
datetime.timedelta
56+
] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
57+
max_semantic_version_cache_items: Optional[
58+
int
59+
] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
60+
semantic_version_cache_expiration_horizon: Optional[
61+
datetime.timedelta
62+
] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
63+
manifest_file_s3_key: Optional[str] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
64+
s3_bucket_name: Optional[str] = None,
65+
s3_client_config: Optional[botocore.config.Config] = None,
66+
) -> None:
67+
"""Initialize a ``JumpStartModelsCache`` instance.
68+
69+
Args:
70+
region (Optional[str]): AWS region to associate with cache. Default: region associated
71+
with boto3 session.
72+
max_s3_cache_items (Optional[int]): Maximum number of items to store in s3 cache.
73+
Default: 20.
74+
s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time to hold
75+
items in s3 cache before invalidation. Default: 6 hours.
76+
max_semantic_version_cache_items (Optional[int]): Maximum number of items to store in
77+
semantic version cache. Default: 20.
78+
semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]):
79+
Maximum time to hold items in semantic version cache before invalidation.
80+
Default: 6 hours.
81+
s3_bucket_name (Optional[str]): S3 bucket to associate with cache.
82+
Default: JumpStart-hosted content bucket for region.
83+
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
84+
Default: None (no config).
85+
"""
86+
87+
self._region = region
88+
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
89+
max_cache_items=max_s3_cache_items,
90+
expiration_horizon=s3_cache_expiration_horizon,
91+
retrieval_function=self._get_file_from_s3,
92+
)
93+
self._model_id_semantic_version_manifest_key_cache = LRUCache[
94+
JumpStartVersionedModelId, JumpStartVersionedModelId
95+
](
96+
max_cache_items=max_semantic_version_cache_items,
97+
expiration_horizon=semantic_version_cache_expiration_horizon,
98+
retrieval_function=self._get_manifest_key_from_model_id_semantic_version,
99+
)
100+
self._manifest_file_s3_key = manifest_file_s3_key
101+
self.s3_bucket_name = (
102+
utils.get_jumpstart_content_bucket(self._region)
103+
if s3_bucket_name is None
104+
else s3_bucket_name
105+
)
106+
self._s3_client = (
107+
boto3.client("s3", region_name=self._region, config=s3_client_config)
108+
if s3_client_config
109+
else boto3.client("s3", region_name=self._region)
110+
)
111+
112+
def set_region(self, region: str) -> None:
113+
"""Set region for cache. Clears cache after new region is set."""
114+
if region != self._region:
115+
self._region = region
116+
self.clear()
117+
118+
def get_region(self) -> str:
119+
"""Return region for cache."""
120+
return self._region
121+
122+
def set_manifest_file_s3_key(self, key: str) -> None:
123+
"""Set manifest file s3 key. Clears cache after new key is set."""
124+
if key != self._manifest_file_s3_key:
125+
self._manifest_file_s3_key = key
126+
self.clear()
127+
128+
def get_manifest_file_s3_key(self) -> None:
129+
"""Return manifest file s3 key for cache."""
130+
return self._manifest_file_s3_key
131+
132+
def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
133+
"""Set s3 bucket used for cache."""
134+
if s3_bucket_name != self.s3_bucket_name:
135+
self.s3_bucket_name = s3_bucket_name
136+
self.clear()
137+
138+
def get_bucket(self) -> None:
139+
"""Return bucket used for cache."""
140+
return self.s3_bucket_name
141+
142+
def _get_manifest_key_from_model_id_semantic_version(
143+
self,
144+
key: JumpStartVersionedModelId,
145+
value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613
146+
) -> JumpStartVersionedModelId:
147+
"""Return model id and version in manifest that matches semantic version/id.
148+
149+
Uses ``semantic_version`` to perform version comparison. The highest model version
150+
matching the semantic version is used, which is compatible with the SageMaker
151+
version.
152+
153+
Args:
154+
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
155+
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
156+
old cached model id/version.
157+
158+
Raises:
159+
KeyError: If the semantic version is not found in the manifest, or is found but
160+
the SageMaker version needs to be upgraded in order for the model to be used.
161+
"""
162+
163+
model_id, version = key.model_id, key.version
164+
165+
manifest = self._s3_cache.get(
166+
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
167+
).formatted_content
168+
169+
sm_version = utils.get_sagemaker_version()
170+
171+
versions_compatible_with_sagemaker = [
172+
semantic_version.Version(header.version)
173+
for header in manifest.values()
174+
if header.model_id == model_id
175+
and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version)
176+
]
177+
178+
spec = (
179+
semantic_version.SimpleSpec("*")
180+
if version is None
181+
else semantic_version.SimpleSpec(version)
182+
)
183+
184+
sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker)
185+
if sm_compatible_model_version is not None:
186+
return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version))
187+
188+
versions_incompatible_with_sagemaker = [
189+
semantic_version.Version(header.version)
190+
for header in manifest.values()
191+
if header.model_id == model_id
192+
]
193+
sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker)
194+
if sm_incompatible_model_version is not None:
195+
model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version)
196+
sm_version_to_use = [
197+
header.min_version
198+
for header in manifest.values()
199+
if header.model_id == model_id
200+
and header.version == model_version_to_use_incompatible_with_sagemaker
201+
]
202+
if len(sm_version_to_use) != 1:
203+
# ``manifest`` dict should already enforce this
204+
raise RuntimeError("Found more than one incompatible SageMaker version to use.")
205+
sm_version_to_use = sm_version_to_use[0]
206+
207+
error_msg = (
208+
f"Unable to find model manifest for {model_id} with version {version} "
209+
f"compatible with your SageMaker version ({sm_version}). "
210+
f"Consider upgrading your SageMaker library to at least version "
211+
f"{sm_version_to_use} so you can use version "
212+
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
213+
)
214+
raise KeyError(error_msg)
215+
error_msg = f"Unable to find model manifest for {model_id} with version {version}."
216+
raise KeyError(error_msg)
217+
218+
def _get_file_from_s3(
219+
self,
220+
key: JumpStartCachedS3ContentKey,
221+
value: Optional[JumpStartCachedS3ContentValue],
222+
) -> JumpStartCachedS3ContentValue:
223+
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
224+
225+
If a manifest file is being fetched, we only download the object if the md5 hash in
226+
``head_object`` does not match the current md5 hash for the stored value. This prevents
227+
unnecessarily downloading the full manifest when it hasn't changed.
228+
229+
Args:
230+
key (JumpStartCachedS3ContentKey): key for which to fetch s3 content.
231+
value (Optional[JumpStartVersionedModelId]): Current value of old cached
232+
s3 content. This is used for the manifest file, so that it is only
233+
downloaded when its content changes.
234+
"""
235+
236+
file_type, s3_key = key.file_type, key.s3_key
237+
238+
if file_type == JumpStartS3FileType.MANIFEST:
239+
if value is not None:
240+
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
241+
if etag == value.md5_hash:
242+
return value
243+
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
244+
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
245+
etag = response["ETag"]
246+
return JumpStartCachedS3ContentValue(
247+
formatted_content=utils.get_formatted_manifest(formatted_body),
248+
md5_hash=etag,
249+
)
250+
if file_type == JumpStartS3FileType.SPECS:
251+
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
252+
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
253+
return JumpStartCachedS3ContentValue(
254+
formatted_content=JumpStartModelSpecs(formatted_body)
255+
)
256+
raise ValueError(
257+
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"
258+
)
259+
260+
def get_manifest(self) -> List[JumpStartModelHeader]:
261+
"""Return entire JumpStart models manifest."""
262+
263+
return self._s3_cache.get(
264+
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
265+
).formatted_content.values()
266+
267+
def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
268+
"""Return header for a given JumpStart model id and semantic version.
269+
270+
Args:
271+
model_id (str): model id for which to get a header.
272+
semantic_version_str (str): The semantic version for which to get a
273+
header.
274+
"""
275+
276+
return self._get_header_impl(model_id, semantic_version_str=semantic_version_str)
277+
278+
def _get_header_impl(
279+
self,
280+
model_id: str,
281+
semantic_version_str: str,
282+
attempt: Optional[int] = 0,
283+
) -> JumpStartModelHeader:
284+
"""Lower-level function to return header.
285+
286+
Allows a single retry if the cache is old.
287+
288+
Args:
289+
model_id (str): model id for which to get a header.
290+
semantic_version_str (str): The semantic version for which to get a
291+
header.
292+
attempt (Optional[int]): attempt number at retrieving a header.
293+
"""
294+
295+
versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get(
296+
JumpStartVersionedModelId(model_id, semantic_version_str)
297+
)
298+
manifest = self._s3_cache.get(
299+
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
300+
).formatted_content
301+
try:
302+
return manifest[versioned_model_id]
303+
except KeyError:
304+
if attempt > 0:
305+
raise
306+
self.clear()
307+
return self._get_header_impl(model_id, semantic_version_str, attempt + 1)
308+
309+
def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs:
310+
"""Return specs for a given JumpStart model id and semantic version.
311+
312+
Args:
313+
model_id (str): model id for which to get specs.
314+
semantic_version_str (str): The semantic version for which to get
315+
specs.
316+
"""
317+
318+
header = self.get_header(model_id, semantic_version_str)
319+
spec_key = header.spec_key
320+
return self._s3_cache.get(
321+
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
322+
).formatted_content
323+
324+
def clear(self) -> None:
325+
"""Clears the model id/version and s3 cache."""
326+
self._s3_cache.clear()
327+
self._model_id_semantic_version_manifest_key_cache.clear()

src/sagemaker/jumpstart/constants.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores constants related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
from typing import Set
16+
import boto3
17+
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
18+
19+
20+
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set()
21+
22+
JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = {
23+
region.region_name: region for region in JUMPSTART_LAUNCHED_REGIONS
24+
}
25+
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
26+
27+
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name
28+
29+
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"

src/sagemaker/jumpstart/parameters.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores parameters related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
import datetime
16+
17+
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS = 20
18+
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
19+
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)
20+
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)

0 commit comments

Comments
 (0)