Skip to content

feature: Adding Jumpstart retrieval functions #2789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def read_version():
"fabric>=2.0",
"requests>=2.20.0, <3",
"sagemaker-experiments",
"regex",
],
)

Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2433,8 +2433,8 @@ def training_image_uri(self, region=None):
training.

Args:
region: Region to use for image uri.
Default: Region associated with SageMaker session.
region (str): Optional. AWS region to use for image URI. Default: AWS region associated
with the SageMaker session.

Returns:
str: The URI of the Docker image.
Expand Down
11 changes: 8 additions & 3 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
):
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Ideally this function should not be called directly, rather it should be called from the
Expand Down Expand Up @@ -75,8 +75,10 @@ def retrieve(
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler
(default: None).
model_id (str): JumpStart model id for which to retrieve image URI.
model_version (str): JumpStart model version for which to retrieve image URI.
model_id (str): JumpStart model ID for which to retrieve image URI
(default: None).
model_version (str): Version of the JumpStart model for which to retrieve the
image URI (default: None).

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.
Expand All @@ -85,8 +87,11 @@ def retrieve(
ValueError: If the combination of arguments specified is not supported.
"""
if is_jumpstart_model_input(model_id, model_version):

# adding assert statements to satisfy mypy type checker
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_image_uri(
model_id,
model_version,
Expand Down
20 changes: 11 additions & 9 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
class SageMakerSettings(object):
"""Static class for storing the SageMaker settings."""

_PARSED_SAGEMAKER_VERSION = ""
_parsed_sagemaker_version = ""

@staticmethod
def set_sagemaker_version(version: str) -> None:
"""Set SageMaker version."""
SageMakerSettings._PARSED_SAGEMAKER_VERSION = version
SageMakerSettings._parsed_sagemaker_version = version

@staticmethod
def get_sagemaker_version() -> str:
"""Return SageMaker version."""
return SageMakerSettings._PARSED_SAGEMAKER_VERSION
return SageMakerSettings._parsed_sagemaker_version


class JumpStartModelsCache(object):
Expand All @@ -43,7 +43,7 @@ class JumpStartModelsCache(object):
_cache_kwargs: Dict[str, Any] = {}

@staticmethod
def _validate_region_cache_kwargs(
def _validate_and_mutate_region_cache_kwargs(
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
) -> Dict[str, Any]:
"""Returns cache_kwargs with region argument removed if present.
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
model_id (str): model id to retrieve.
version (str): semantic version to retrieve for the model id.
"""
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
JumpStartModelsCache._cache_kwargs, region
)
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
Expand All @@ -92,7 +92,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
model_id (str): model id to retrieve.
version (str): semantic version to retrieve for the model id.
"""
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
JumpStartModelsCache._cache_kwargs, region
)
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
Expand All @@ -103,7 +103,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS

@staticmethod
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
"""Sets cache kwargs, clear the cache.
"""Sets cache kwargs, clears the cache.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
Expand All @@ -112,7 +112,9 @@ def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
cache_kwargs (str): cache kwargs to validate.
region (str): Optional. The region to validate along with the kwargs.
"""
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
cache_kwargs, region
)
JumpStartModelsCache._cache_kwargs = cache_kwargs
if region is None:
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
Expand All @@ -125,7 +127,7 @@ def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
)

@staticmethod
def reset_cache(cache_kwargs: Dict[str, Any] = None, region: str = None) -> None:
def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = None) -> None:
"""Resets cache, optionally allowing cache kwargs to be passed to the new cache.

Raises:
Expand Down
65 changes: 45 additions & 20 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functions for obtainining JumpStart artifacts."""
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
from __future__ import absolute_import
from typing import Optional
from sagemaker import image_uris
Expand Down Expand Up @@ -42,13 +42,14 @@ def _retrieve_image_uri(
):
"""Retrieves the container image URI for JumpStart models.

Only `model_id` and `model_version` are required to be non-None;
Only `model_id`, `model_version`, and `image_scope` are required;
the rest of the fields are auto-populated.


Args:
model_id (str): JumpStart model id for which to retrieve image URI.
model_version (str): JumpStart model version for which to retrieve image URI.
model_id (str): JumpStart model ID for which to retrieve image URI.
model_version (str): Version of the JumpStart model for which to retrieve
the image URI (default: None).
framework (str): The name of the framework or algorithm.
region (str): The AWS region.
version (str): The framework or algorithm version. This is required if there is
Expand Down Expand Up @@ -89,7 +90,9 @@ def _retrieve_image_uri(
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
)
if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError("JumpStart models only support inference and training.")
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
region, model_id, model_version
Expand All @@ -99,25 +102,33 @@ def _retrieve_image_uri(
ecr_specs = model_specs.hosting_ecr_specs
elif image_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
assert model_specs.training_ecr_specs is not None
ecr_specs = model_specs.training_ecr_specs

if framework is not None and framework != ecr_specs.framework:
raise ValueError(f"Bad value for container framework for JumpStart model: '{framework}'.")
raise ValueError(
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "
"and version {model_version}'."
)

if version is not None and version != ecr_specs.framework_version:
raise ValueError(
f"Bad value for container framework version for JumpStart model: '{version}'."
f"Incorrect container framework version '{version}' for JumpStart model ID "
f"'{model_id}' and version {model_version}'."
)

if py_version is not None and py_version != ecr_specs.py_version:
raise ValueError(
f"Bad value for container python version for JumpStart model: '{py_version}'."
f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' "
"and version {model_version}'."
)

base_framework_version_override = None
version_override = None
base_framework_version_override: Optional[str] = None
version_override: Optional[str] = None
if ecr_specs.framework == ModelFramework.HUGGINGFACE.value:
base_framework_version_override = ecr_specs.framework_version
version_override = ecr_specs.huggingface_transformers_version
Expand Down Expand Up @@ -162,8 +173,10 @@ def _retrieve_model_uri(
"""Retrieves the model artifact S3 URI for the model matching the given arguments.

Args:
model_id (str): JumpStart model id for which to retrieve model S3 URI.
model_version (str): JumpStart model version for which to retrieve model S3 URI.
model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
the model artifact S3 URI.
model_version (str): Version of the JumpStart model for which to retrieve the model
artifact S3 URI.
model_scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model S3 URI.
Expand All @@ -185,7 +198,9 @@ def _retrieve_model_uri(
)

if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError("JumpStart models only support inference and training.")
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
region, model_id, model_version
Expand All @@ -194,7 +209,10 @@ def _retrieve_model_uri(
model_artifact_key = model_specs.hosting_artifact_key
elif model_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
assert model_specs.training_artifact_key is not None
model_artifact_key = model_specs.training_artifact_key

Expand All @@ -211,11 +229,13 @@ def _retrieve_script_uri(
script_scope: Optional[str],
region: Optional[str],
):
"""Retrieves the model script s3 URI for the model matching the given arguments.
"""Retrieves the script S3 URI associated with the model matching the given arguments.

Args:
model_id (str): JumpStart model id for which to retrieve model script S3 URI.
model_version (str): JumpStart model version for which to retrieve model script S3 URI.
model_id (str): JumpStart model ID of the JumpStart model for which to
retrieve the script S3 URI.
model_version (str): Version of the JumpStart model for which to
retrieve the model script S3 URI.
script_scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model script S3 URI.
Expand All @@ -237,7 +257,9 @@ def _retrieve_script_uri(
)

if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError("JumpStart models only support inference and training.")
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
region, model_id, model_version
Expand All @@ -246,7 +268,10 @@ def _retrieve_script_uri(
model_script_key = model_specs.hosting_script_key
elif script_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
assert model_specs.training_script_key is not None
model_script_key = model_specs.training_script_key

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _select_version(
spec = SpecifierSet(f"=={semantic_version_str}")
available_versions_filtered = list(spec.filter(available_versions))
return (
str(available_versions_filtered[0]) if available_versions_filtered != [] else None
str(max(available_versions_filtered)) if available_versions_filtered != [] else None
)

def _get_header_impl(
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) ->
are None, and raises an exception if one argument is None but the other isn't.

Args:
model_id (str): Optional. Model id of JumpStart model.
version (str): Optional. Version for JumpStart model.
model_id (str): Optional. Model ID of the JumpStart model.
version (str): Optional. Version of the JumpStart model.

Raises:
ValueError: If only one of the two arguments is None.
Expand Down
12 changes: 8 additions & 4 deletions src/sagemaker/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Functions for generating S3 model artifact URIs for pre-built SageMaker models."""
"""Accessors to retrieve the model artifact S3 URI of pretrained ML models."""
from __future__ import absolute_import

import logging
Expand All @@ -29,13 +29,15 @@ def retrieve(
model_id=None,
model_version: Optional[str] = None,
model_scope: Optional[str] = None,
):
) -> str:
"""Retrieves the model artifact S3 URI for the model matching the given arguments.

Args:
region (str): Region for which to retrieve model S3 URI.
model_id (str): JumpStart model id for which to retrieve model S3 URI.
model_version (str): JumpStart model version for which to retrieve model S3 URI.
model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
the model artifact S3 URI.
model_version (str): Version of the JumpStart model for which to retrieve
the model artifact S3 URI.
model_scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
Returns:
Expand All @@ -47,6 +49,8 @@ def retrieve(
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")

# mypy type checking require these assertions
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_model_uri(model_id, model_version, model_scope, region)
16 changes: 11 additions & 5 deletions src/sagemaker/script_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Functions for generating S3 model script URIs for pre-built SageMaker models."""
"""Accessors to retrieve the script S3 URI to be run pretrained ML models
in SageMaker containers.
"""
from __future__ import absolute_import

import logging
Expand All @@ -27,13 +29,15 @@ def retrieve(
model_id=None,
model_version=None,
script_scope=None,
):
"""Retrieves the model script s3 URI for the model matching the given arguments.
) -> str:
"""Retrieves the script S3 URI associated with the model matching the given arguments.

Args:
region (str): Region for which to retrieve model script S3 URI.
model_id (str): JumpStart model id for which to retrieve model script S3 URI.
model_version (str): JumpStart model version for which to retrieve model script S3 URI.
model_id (str): JumpStart model ID of the JumpStart model for which to
retrieve the script S3 URI.
model_version (str): Version of the JumpStart model for which to retrieve the
model script S3 URI.
script_scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
Returns:
Expand All @@ -45,6 +49,8 @@ def retrieve(
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")

# mypy type checking require these assertions
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_script_uri(model_id, model_version, script_scope, region)
Loading