Skip to content

feature: Adding Jumpstart retrieval functions #2789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def read_version():
"packaging>=20.0",
"pandas",
"pathos",
"semantic-version",
]

# Specific use case dependencies
Expand Down
46 changes: 46 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from sagemaker import utils
from sagemaker.spark import defaults
from sagemaker.jumpstart import accessors as jumpstart_accessors

logger = logging.getLogger(__name__)

Expand All @@ -39,6 +40,8 @@ def retrieve(
distribution=None,
base_framework_version=None,
training_compiler_config=None,
model_id=None,
model_version=None,
):
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -69,13 +72,56 @@ def retrieve(
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler
(default: None).
model_id (str): JumpStart model id for which to retrieve image URI.
model_version (str): JumpStart model version for which to retrieve image URI.

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
ValueError: If the combination of arguments specified is not supported.
"""
if model_id is not None or model_version is not None:
if model_id is None or model_version is None:
raise ValueError(
"Must specify `model_id` and `model_version` when getting image uri for "
"JumpStart models. "
)
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
region, model_id, model_version
)
if image_scope is None:
raise ValueError(
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
)
if image_scope == "inference":
ecr_specs = model_specs.hosting_ecr_specs
elif image_scope == "training":
if not model_specs.training_supported:
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
ecr_specs = model_specs.training_ecr_specs
else:
raise ValueError("JumpStart models only support inference and training.")

if framework is not None and framework != ecr_specs.framework:
raise ValueError(
f"Bad value for container framework for JumpStart model: '{framework}'."
)

return retrieve(
framework=ecr_specs.framework,
region=region,
version=ecr_specs.framework_version,
py_version=ecr_specs.py_version,
instance_type=instance_type,
accelerator_type=accelerator_type,
image_scope=image_scope,
container_version=container_version,
distribution=distribution,
base_framework_version=base_framework_version,
training_compiler_config=training_compiler_config,
)

if training_compiler_config is None:
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
elif framework == HUGGING_FACE_FRAMEWORK:
Expand Down
139 changes: 139 additions & 0 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains accessors related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import Any, Dict, Optional
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart import cache
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME


class SageMakerSettings(object):
"""Static class for storing the SageMaker settings."""

_PARSED_SAGEMAKER_VERSION = ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-blocking: technically, this static attribute isn't a constant, so lower case might be more appropriate.


@staticmethod
def set_sagemaker_version(version: str) -> None:
"""Set SageMaker version."""
SageMakerSettings._PARSED_SAGEMAKER_VERSION = version

@staticmethod
def get_sagemaker_version() -> str:
"""Return SageMaker version."""
return SageMakerSettings._PARSED_SAGEMAKER_VERSION


class JumpStartModelsCache(object):
"""Static class for storing the JumpStart models cache."""

_cache: Optional[cache.JumpStartModelsCache] = None
_curr_region = JUMPSTART_DEFAULT_REGION_NAME

_cache_kwargs: Dict[str, Any] = {}

@staticmethod
def _validate_region_cache_kwargs(
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
) -> Dict[str, Any]:
"""Returns cache_kwargs with region argument removed if present.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
region (str): The region to validate along with the kwargs.
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
assert isinstance(cache_kwargs_dict, dict)
if region is not None and "region" in cache_kwargs_dict:
if region != cache_kwargs_dict["region"]:
raise ValueError(
f"Inconsistent region definitions: {region}, {cache_kwargs_dict['region']}"
)
del cache_kwargs_dict["region"]
return cache_kwargs_dict

@staticmethod
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
"""Returns model header from JumpStart models cache.

Args:
region (str): region for which to retrieve header.
model_id (str): model id to retrieve.
version (str): semantic version to retrieve for the model id.
"""
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
JumpStartModelsCache._cache_kwargs, region
)
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
JumpStartModelsCache._curr_region = region
assert JumpStartModelsCache._cache is not None
return JumpStartModelsCache._cache.get_header(model_id, version)

@staticmethod
def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
"""Returns model specs from JumpStart models cache.

Args:
region (str): region for which to retrieve header.
model_id (str): model id to retrieve.
version (str): semantic version to retrieve for the model id.
"""
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
JumpStartModelsCache._cache_kwargs, region
)
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
JumpStartModelsCache._curr_region = region
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: @evakravi This section is same as lines 80-82 above, can we refactor this to avoid duplicacy?

assert JumpStartModelsCache._cache is not None
return JumpStartModelsCache._cache.get_specs(model_id, version)

@staticmethod
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
"""Sets cache kwargs. Clears the cache.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (str): cache kwargs to validate.
region (str): The region to validate along with the kwargs.
"""
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
JumpStartModelsCache._cache_kwargs = cache_kwargs
if region is None:
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
**JumpStartModelsCache._cache_kwargs
)
else:
JumpStartModelsCache._curr_region = region
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
region=region, **JumpStartModelsCache._cache_kwargs
)

@staticmethod
def reset_cache(cache_kwargs: Dict[str, Any] = None, region: str = None) -> None:
"""Resets cache, optionally allowing cache kwargs to be passed to the new cache.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (str): cache kwargs to validate.
region (str): The region to validate along with the kwargs.
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
JumpStartModelsCache.set_cache_kwargs(cache_kwargs_dict, region)
Loading