Skip to content

feature: jumpstart notebook utils -- list model ids, scripts, tasks, frameworks #2987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@
### Features

* override jumpstart content bucket
* jumpstart model id suggestions
* jumpstart model ID suggestions
* adding customer metadata support to registermodel step

### Bug Fixes and Other Changes
Expand Down
4 changes: 2 additions & 2 deletions doc/doc_utils/jumpstart_doc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def create_jumpstart_model_table():
file_content.append("==================================\n")
file_content.append(
"""
JumpStart for the SageMaker Python SDK uses model ids and model versions to access the necessary
JumpStart for the SageMaker Python SDK uses model IDs and model versions to access the necessary
utilities. This table serves to provide the core material plus some extra information that can be useful
in selecting the correct model id and corresponding parameters.\n
in selecting the correct model ID and corresponding parameters.\n
"""
)
file_content.append(
Expand Down
2 changes: 1 addition & 1 deletion doc/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ the ``model_id`` and ``model_version`` needed to retrieve the URI.
model. To use the latest version, enter ``"*"``. This is a
required parameter.

To retrieve a model, first select a ``model id`` and ``version`` from
To retrieve a model, first select a ``model ID`` and ``version`` from
the :doc:`available models <./doc_utils/jumpstart>`.

.. code:: python
Expand Down
29 changes: 24 additions & 5 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.
"""This module contains accessors related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart import cache
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
Expand Down Expand Up @@ -84,8 +84,8 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel

Args:
region (str): region for which to retrieve header.
model_id (str): model id to retrieve.
version (str): semantic version to retrieve for the model id.
model_id (str): model ID to retrieve.
version (str): semantic version to retrieve for the model ID.
"""
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
JumpStartModelsAccessor._cache_kwargs, region
Expand All @@ -101,8 +101,8 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS

Args:
region (str): region for which to retrieve header.
model_id (str): model id to retrieve.
version (str): semantic version to retrieve for the model id.
model_id (str): model ID to retrieve.
version (str): semantic version to retrieve for the model ID.
"""
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
JumpStartModelsAccessor._cache_kwargs, region
Expand Down Expand Up @@ -150,3 +150,22 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)

@staticmethod
def get_manifest(
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
) -> List[JumpStartModelHeader]:
"""Return entire JumpStart models manifest.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (Dict[str, Any]): Optional. Cache kwargs to use.
(Default: None).
region (str): Optional. The region to use for the cache.
(Default: None).
"""
cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else cache_kwargs
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore
18 changes: 9 additions & 9 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,16 @@ def _get_manifest_key_from_model_id_semantic_version(
key: JumpStartVersionedModelId,
value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613
) -> JumpStartVersionedModelId:
"""Return model id and version in manifest that matches semantic version/id.
"""Return model ID and version in manifest that matches semantic version/id.

Uses ``packaging.version`` to perform version comparison. The highest model version
matching the semantic version is used, which is compatible with the SageMaker
version.

Args:
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
key (JumpStartVersionedModelId): Key for which to fetch versioned model ID.
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
old cached model id/version.
old cached model ID/version.

Raises:
KeyError: If the semantic version is not found in the manifest, or is found but
Expand Down Expand Up @@ -287,10 +287,10 @@ def get_manifest(self) -> List[JumpStartModelHeader]:
return manifest

def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
"""Return header for a given JumpStart model id and semantic version.
"""Return header for a given JumpStart model ID and semantic version.

Args:
model_id (str): model id for which to get a header.
model_id (str): model ID for which to get a header.
semantic_version_str (str): The semantic version for which to get a
header.
"""
Expand Down Expand Up @@ -331,7 +331,7 @@ def _get_header_impl(
Allows a single retry if the cache is old.

Args:
model_id (str): model id for which to get a header.
model_id (str): model ID for which to get a header.
semantic_version_str (str): The semantic version for which to get a
header.
attempt (int): attempt number at retrieving a header.
Expand All @@ -353,10 +353,10 @@ def _get_header_impl(
return self._get_header_impl(model_id, semantic_version_str, attempt + 1)

def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs:
"""Return specs for a given JumpStart model id and semantic version.
"""Return specs for a given JumpStart model ID and semantic version.

Args:
model_id (str): model id for which to get specs.
model_id (str): model ID for which to get specs.
semantic_version_str (str): The semantic version for which to get
specs.
"""
Expand All @@ -369,6 +369,6 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
return specs # type: ignore

def clear(self) -> None:
"""Clears the model id/version and s3 cache."""
"""Clears the model ID/version and s3 cache."""
self._s3_cache.clear()
self._model_id_semantic_version_manifest_key_cache.clear()
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
"""Instantiates VulnerableJumpStartModelError exception.

Args:
model_id (Optional[str]): model id of vulnerable JumpStart model.
model_id (Optional[str]): model ID of vulnerable JumpStart model.
(Default: None).
version (Optional[str]): version of vulnerable JumpStart model.
(Default: None).
Expand Down
Loading