Skip to content

Commit 4a94ace

Browse files
authored
feature: jumpstart notebook utils -- list model ids, scripts, tasks, frameworks (aws#2987)
1 parent d378f5d commit 4a94ace

15 files changed

+1968
-39
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@
179179
### Features
180180

181181
* override jumpstart content bucket
182-
* jumpstart model id suggestions
182+
* jumpstart model ID suggestions
183183
* adding customer metadata support to registermodel step
184184

185185
### Bug Fixes and Other Changes

doc/doc_utils/jumpstart_doc_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ def create_jumpstart_model_table():
122122
file_content.append("==================================\n")
123123
file_content.append(
124124
"""
125-
JumpStart for the SageMaker Python SDK uses model ids and model versions to access the necessary
125+
JumpStart for the SageMaker Python SDK uses model IDs and model versions to access the necessary
126126
utilities. This table serves to provide the core material plus some extra information that can be useful
127-
in selecting the correct model id and corresponding parameters.\n
127+
in selecting the correct model ID and corresponding parameters.\n
128128
"""
129129
)
130130
file_content.append(

doc/overview.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ the ``model_id`` and ``model_version`` needed to retrieve the URI.
670670
model. To use the latest version, enter ``"*"``. This is a
671671
required parameter.
672672
673-
To retrieve a model, first select a ``model id`` and ``version`` from
673+
To retrieve a model, first select a ``model ID`` and ``version`` from
674674
the :doc:`available models <./doc_utils/jumpstart>`.
675675

676676
.. code:: python

src/sagemaker/jumpstart/accessors.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains accessors related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15-
from typing import Any, Dict, Optional
15+
from typing import Any, Dict, List, Optional
1616
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
1717
from sagemaker.jumpstart import cache
1818
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
@@ -84,8 +84,8 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
8484
8585
Args:
8686
region (str): region for which to retrieve header.
87-
model_id (str): model id to retrieve.
88-
version (str): semantic version to retrieve for the model id.
87+
model_id (str): model ID to retrieve.
88+
version (str): semantic version to retrieve for the model ID.
8989
"""
9090
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
9191
JumpStartModelsAccessor._cache_kwargs, region
@@ -101,8 +101,8 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
101101
102102
Args:
103103
region (str): region for which to retrieve header.
104-
model_id (str): model id to retrieve.
105-
version (str): semantic version to retrieve for the model id.
104+
model_id (str): model ID to retrieve.
105+
version (str): semantic version to retrieve for the model ID.
106106
"""
107107
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
108108
JumpStartModelsAccessor._cache_kwargs, region
@@ -150,3 +150,22 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
150150
"""
151151
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
152152
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
153+
154+
@staticmethod
155+
def get_manifest(
156+
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
157+
) -> List[JumpStartModelHeader]:
158+
"""Return entire JumpStart models manifest.
159+
160+
Raises:
161+
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
162+
163+
Args:
164+
cache_kwargs (Dict[str, Any]): Optional. Cache kwargs to use.
165+
(Default: None).
166+
region (str): Optional. The region to use for the cache.
167+
(Default: None).
168+
"""
169+
cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else cache_kwargs
170+
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
171+
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore

src/sagemaker/jumpstart/cache.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -146,16 +146,16 @@ def _get_manifest_key_from_model_id_semantic_version(
146146
key: JumpStartVersionedModelId,
147147
value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613
148148
) -> JumpStartVersionedModelId:
149-
"""Return model id and version in manifest that matches semantic version/id.
149+
"""Return model ID and version in manifest that matches semantic version/id.
150150
151151
Uses ``packaging.version`` to perform version comparison. The highest model version
152152
matching the semantic version is used, which is compatible with the SageMaker
153153
version.
154154
155155
Args:
156-
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
156+
key (JumpStartVersionedModelId): Key for which to fetch versioned model ID.
157157
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
158-
old cached model id/version.
158+
old cached model ID/version.
159159
160160
Raises:
161161
KeyError: If the semantic version is not found in the manifest, or is found but
@@ -287,10 +287,10 @@ def get_manifest(self) -> List[JumpStartModelHeader]:
287287
return manifest
288288

289289
def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
290-
"""Return header for a given JumpStart model id and semantic version.
290+
"""Return header for a given JumpStart model ID and semantic version.
291291
292292
Args:
293-
model_id (str): model id for which to get a header.
293+
model_id (str): model ID for which to get a header.
294294
semantic_version_str (str): The semantic version for which to get a
295295
header.
296296
"""
@@ -331,7 +331,7 @@ def _get_header_impl(
331331
Allows a single retry if the cache is old.
332332
333333
Args:
334-
model_id (str): model id for which to get a header.
334+
model_id (str): model ID for which to get a header.
335335
semantic_version_str (str): The semantic version for which to get a
336336
header.
337337
attempt (int): attempt number at retrieving a header.
@@ -353,10 +353,10 @@ def _get_header_impl(
353353
return self._get_header_impl(model_id, semantic_version_str, attempt + 1)
354354

355355
def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs:
356-
"""Return specs for a given JumpStart model id and semantic version.
356+
"""Return specs for a given JumpStart model ID and semantic version.
357357
358358
Args:
359-
model_id (str): model id for which to get specs.
359+
model_id (str): model ID for which to get specs.
360360
semantic_version_str (str): The semantic version for which to get
361361
specs.
362362
"""
@@ -369,6 +369,6 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
369369
return specs # type: ignore
370370

371371
def clear(self) -> None:
372-
"""Clears the model id/version and s3 cache."""
372+
"""Clears the model ID/version and s3 cache."""
373373
self._s3_cache.clear()
374374
self._model_id_semantic_version_manifest_key_cache.clear()

src/sagemaker/jumpstart/exceptions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
"""Instantiates VulnerableJumpStartModelError exception.
5050
5151
Args:
52-
model_id (Optional[str]): model id of vulnerable JumpStart model.
52+
model_id (Optional[str]): model ID of vulnerable JumpStart model.
5353
(Default: None).
5454
version (Optional[str]): version of vulnerable JumpStart model.
5555
(Default: None).

0 commit comments

Comments
 (0)