Skip to content

Commit e499bc6

Browse files
committed
feat: jumpstart notebook utils
1 parent 7b45543 commit e499bc6

File tree

5 files changed

+954
-4
lines changed

5 files changed

+954
-4
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains accessors related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15-
from typing import Any, Dict, Optional
15+
from typing import Any, Dict, List, Optional
1616
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
1717
from sagemaker.jumpstart import cache
1818
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
@@ -150,3 +150,20 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
150150
"""
151151
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
152152
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
153+
154+
@staticmethod
def get_manifest(
    cache_kwargs: Dict[str, Any] = None, region: Optional[str] = None
) -> List[JumpStartModelHeader]:
    """Return the entire JumpStart models manifest.

    Args:
        cache_kwargs (Dict[str, Any]): Cache kwargs to use. ``None`` is
            treated as an empty dictionary.
        region (str): The region to use for the cache.

    Returns:
        List[JumpStartModelHeader]: The value returned by the cache's
            ``get_manifest`` call.

    Raises:
        ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
    """
    # Fall back to an empty dict so ``set_cache_kwargs`` always receives a mapping.
    kwargs_for_cache = cache_kwargs if cache_kwargs is not None else {}
    JumpStartModelsAccessor.set_cache_kwargs(kwargs_for_cache, region)
    manifest = JumpStartModelsAccessor._cache.get_manifest()  # type: ignore
    return manifest
Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores notebook utils related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
16+
from functools import cmp_to_key
17+
from typing import Any, Collection, List, Tuple, Union, Set, Dict
18+
from packaging.version import Version
19+
from sagemaker.jumpstart import accessors
20+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
21+
from sagemaker.jumpstart.enums import JumpStartScriptScope
22+
from sagemaker.jumpstart.utils import get_sagemaker_version
23+
24+
25+
def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
    """Split a model id into its framework, task and remaining model name.

    The id is expected to look like ``<framework>-<task>-<name>``, where the
    name itself may contain further dashes.

    Args:
        model_id (str): The model id for which to extract the framework/task/model.

    Returns:
        Tuple[str, str, str]: ``(framework, task, name)``.

    Raises:
        ValueError: If the id contains fewer than three dash-separated parts.
    """
    # Split at most twice so that every dash after the task stays in the name.
    pieces = model_id.split("-", 2)
    if len(pieces) < 3:
        raise ValueError(f"incorrect model id: {model_id}.")

    framework, task, name = pieces
    return framework, task, name
41+
42+
43+
def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
44+
model_version_1: Tuple[str, str] = None, model_version_2: Tuple[str, str] = None
45+
) -> int:
46+
"""Performs comparison of sdk specs paths, in order to sort them in manifest.
47+
48+
Args:
49+
model_version_1 (Tuple[str, str]): The first model id and version tuple to compare.
50+
model_version_2 (Tuple[str, str]): The second model id and version tuple to compare.
51+
"""
52+
if model_version_1 is None or model_version_2 is None:
53+
if model_version_2 is not None:
54+
return -1
55+
if model_version_1 is not None:
56+
return 1
57+
return 0
58+
59+
version_1 = model_version_1[1]
60+
model_id_1 = model_version_1[0]
61+
62+
version_2 = model_version_2[1]
63+
model_id_2 = model_version_2[0]
64+
65+
if model_id_1 < model_id_2:
66+
return -1
67+
68+
if model_id_2 < model_id_1:
69+
return 1
70+
71+
if Version(version_1) < Version(version_2):
72+
return 1
73+
74+
if Version(version_2) < Version(version_1):
75+
return -1
76+
77+
return 0
78+
79+
80+
def list_jumpstart_frameworks(
    **kwargs: Dict[str, Any],
) -> List[str]:
    """Return the sorted, de-duplicated frameworks of the listed JumpStart models.

    The framework is the first dash-separated component of each model id
    returned by ``list_jumpstart_models``.

    Args:
        kwargs (Dict[str, Any]): kwarg arguments to supply to
            ``list_jumpstart_models``.
    """
    unique_frameworks = {
        extract_framework_task_model(model_id)[0]
        for model_id, _ in list_jumpstart_models(**kwargs)
    }
    return sorted(unique_frameworks)
95+
96+
97+
def list_jumpstart_tasks(
    **kwargs: Dict[str, Any],
) -> List[str]:
    """Return the sorted, de-duplicated tasks of the listed JumpStart models.

    The task is the second dash-separated component of each model id
    returned by ``list_jumpstart_models``.

    Args:
        kwargs (Dict[str, Any]): kwarg arguments to supply to
            ``list_jumpstart_models``.
    """
    unique_tasks = {
        extract_framework_task_model(model_id)[1]
        for model_id, _ in list_jumpstart_models(**kwargs)
    }
    return sorted(unique_tasks)
112+
113+
114+
def list_jumpstart_scripts(
    **kwargs: Dict[str, Any],
) -> List[str]:
    """List scripts actively in use by JumpStart.

    Note: Using this function will result in slow execution speed, as it requires
    making many http calls and parsing metadata files. To-Do: store script
    information for all models in a single file.

    Check ``sagemaker.jumpstart.enums.JumpStartScriptScope`` for possible types
    of JumpStart scripts.

    Args:
        kwargs (Dict[str, Any]): kwarg arguments to supply to
            ``list_jumpstart_models``.

    Returns:
        List[str]: Sorted script scope values used by at least one listed model.
            Empty when no model is listed.
    """
    models_list = list_jumpstart_models(**kwargs)

    # Loop invariants: compute them once instead of on every iteration.
    region = kwargs.get("region", JUMPSTART_DEFAULT_REGION_NAME)
    all_scripts = {scope.value for scope in JumpStartScriptScope}

    scripts = set()
    for model_id, version in models_list:
        # Every listed model contributes the inference scope.
        scripts.add(JumpStartScriptScope.INFERENCE.value)
        model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
            region=region,
            model_id=model_id,
            version=version,
        )
        if model_specs.training_supported:
            scripts.add(JumpStartScriptScope.TRAINING.value)

        # Stop fetching specs as soon as every known scope has been seen.
        if scripts == all_scripts:
            break
    return sorted(scripts)
145+
146+
147+
def _normalize_jumpstart_filter(filter_value):
    """Convert an allowlist/denylist argument to a set, leaving ``None`` untouched."""
    if filter_value is None:
        return None
    if isinstance(filter_value, str):
        # A bare string is a single entry, not a collection of characters.
        return {filter_value}
    return set(filter_value)


def list_jumpstart_models(
    script_allowlist: Union[str, Collection[str]] = None,
    task_allowlist: Union[str, Collection[str]] = None,
    framework_allowlist: Union[str, Collection[str]] = None,
    model_id_allowlist: Union[str, Collection[str]] = None,
    script_denylist: Union[str, Collection[str]] = None,
    task_denylist: Union[str, Collection[str]] = None,
    framework_denylist: Union[str, Collection[str]] = None,
    model_id_denylist: Union[str, Collection[str]] = None,
    region: str = JUMPSTART_DEFAULT_REGION_NAME,
    accept_unsupported_models: bool = False,
    accept_old_models: bool = False,
    accept_vulnerable_models: bool = True,
    accept_deprecated_models: bool = True,
) -> List[Tuple[str, str]]:
    """List models in JumpStart, and optionally apply filters to result.

    Args:
        script_allowlist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing scripts. All models returned by this function
            must use a script which is specified in this argument. Note: Using this
            filter will result in slow execution speed, as it requires making more
            http calls and parsing many metadata files. To-Do: store script
            information for all models in a single file.
            (Default: None).
        task_allowlist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing tasks. All models returned by this function
            must use a task which is specified in this argument.
            (Default: None).
        framework_allowlist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing frameworks. All models returned by this function
            must use a framework which is specified in this argument.
            (Default: None).
        model_id_allowlist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing model ids. All models returned by this function
            must use a model id which is specified in this argument.
            (Default: None).
        script_denylist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing scripts. All models returned by this function
            must not use a script which is specified in this argument. Note: Using
            this filter will result in slow execution speed, as it requires making
            more http calls and parsing many metadata files. To-Do: store script
            information for all models in a single file.
            (Default: None).
        task_denylist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing tasks. All models returned by this function
            must not use a task which is specified in this argument.
            (Default: None).
        framework_denylist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing frameworks. All models returned by this function
            must not use a framework which is specified in this argument.
            (Default: None).
        model_id_denylist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing model ids. All models returned by this function
            must not use a model id which is specified in this argument.
            (Default: None).
        region (str): Optional. Region to use when fetching JumpStart metadata.
            (Default: ``JUMPSTART_DEFAULT_REGION_NAME``).
        accept_unsupported_models (bool): Optional. Set to True to accept models that
            are not supported with the current SageMaker library version
            (Default: False).
        accept_old_models (bool): Optional. Set to True to accept model and version
            tuples for which a model with the same name and a newer version exists.
            (Default: False).
        accept_vulnerable_models (bool): Optional. Set to False to reject models that
            have a vulnerable inference or training script dependency. Note: accessing
            vulnerability information requires making many http calls and parsing many
            metadata files. To-Do: store vulnerability information for all models in a
            single file, and change default value to False. (Default: True).
        accept_deprecated_models (bool): Optional. Set to False to reject models that
            have been flagged as deprecated. Note: accessing deprecation information
            requires making many http calls and parsing many metadata files. To-Do:
            store deprecation information for all models in a single file, and change
            default value to False. (Default: True).

    Returns:
        List[Tuple[str, str]]: ``(model_id, version)`` tuples, sorted by model id
            ascending and, within a model id, by version descending. The version
            is the string form of the parsed ``packaging`` ``Version``.

    Raises:
        ValueError: If both an allowlist and a denylist are supplied for the same
            filter (script, task, framework, model id).
    """
    filter_pairs = (
        (script_allowlist, script_denylist),
        (task_allowlist, task_denylist),
        (framework_allowlist, framework_denylist),
        (model_id_allowlist, model_id_denylist),
    )
    if any(allow is not None and deny is not None for allow, deny in filter_pairs):
        raise ValueError(
            (
                "Cannot use an allowlist and denylist at the same time "
                "for a filter (script, task, framework, model id)"
            )
        )

    # Normalize every filter to a set once, instead of once per manifest entry.
    script_allowlist = _normalize_jumpstart_filter(script_allowlist)
    task_allowlist = _normalize_jumpstart_filter(task_allowlist)
    framework_allowlist = _normalize_jumpstart_filter(framework_allowlist)
    model_id_allowlist = _normalize_jumpstart_filter(model_id_allowlist)
    script_denylist = _normalize_jumpstart_filter(script_denylist)
    task_denylist = _normalize_jumpstart_filter(task_denylist)
    framework_denylist = _normalize_jumpstart_filter(framework_denylist)
    model_id_denylist = _normalize_jumpstart_filter(model_id_denylist)

    filter_on_scripts = script_allowlist is not None or script_denylist is not None
    # Model specs are expensive to fetch; only do so when a filter needs them.
    needs_model_specs = (
        filter_on_scripts or not accept_vulnerable_models or not accept_deprecated_models
    )

    models_manifest_list = accessors.JumpStartModelsAccessor.get_manifest(region=region)

    # The installed SDK version is loop-invariant; parse it a single time.
    sagemaker_version = (
        None if accept_unsupported_models else Version(get_sagemaker_version())
    )

    model_id_version_dict: Dict[str, Set[Version]] = {}
    for model_manifest in models_manifest_list:
        model_id = model_manifest.model_id
        model_version = model_manifest.version

        if sagemaker_version is not None and sagemaker_version < Version(
            model_manifest.min_version
        ):
            continue

        if model_id_allowlist is not None and model_id not in model_id_allowlist:
            continue
        if model_id_denylist is not None and model_id in model_id_denylist:
            continue

        framework, task, _ = extract_framework_task_model(model_id)

        if task_allowlist is not None and task not in task_allowlist:
            continue
        if task_denylist is not None and task in task_denylist:
            continue

        if framework_allowlist is not None and framework not in framework_allowlist:
            continue
        if framework_denylist is not None and framework in framework_denylist:
            continue

        if needs_model_specs:
            # One fetch serves the script, vulnerability and deprecation checks.
            model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
                region=region, model_id=model_id, version=model_version
            )

            if filter_on_scripts:
                supported_scripts = {JumpStartScriptScope.INFERENCE.value}
                if model_specs.training_supported:
                    supported_scripts.add(JumpStartScriptScope.TRAINING.value)

                if script_allowlist is not None and supported_scripts.isdisjoint(
                    script_allowlist
                ):
                    continue
                if script_denylist is not None and not supported_scripts.isdisjoint(
                    script_denylist
                ):
                    continue

            if not accept_vulnerable_models and (
                model_specs.inference_vulnerable or model_specs.training_vulnerable
            ):
                continue
            if not accept_deprecated_models and model_specs.deprecated:
                continue

        model_id_version_dict.setdefault(model_id, set()).add(Version(model_version))

    if not accept_old_models:
        # Keep only the newest version of each model.
        model_id_version_dict = {
            model: {max(versions)} for model, versions in model_id_version_dict.items()
        }

    model_id_set = {
        (model_id, str(version))
        for model_id, versions in model_id_version_dict.items()
        for version in versions
    }

    return sorted(model_id_set, key=cmp_to_key(_compare_model_version_tuples))

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def test_jumpstart_models_cache_get_fxs():
5151
region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*"
5252
)
5353

54+
assert len(accessors.JumpStartModelsAccessor.get_manifest()) > 0
55+
5456
# necessary because accessors is a static module
5557
reload(accessors)
5658

0 commit comments

Comments
 (0)