Skip to content

feature: Adding Jumpstart retrieval functions #2789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ venv/
*.swp
.docker/
env/
.vscode/
.vscode/
**/tmp
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def read_version():
"packaging>=20.0",
"pandas",
"pathos",
"semantic-version",
]

# Specific use case dependencies
Expand Down
49 changes: 15 additions & 34 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2425,51 +2425,32 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na

return init_params

def training_image_uri(self):
def training_image_uri(self, region=None):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
the model training, calls this method to find the image to use for model
training.

Args:
region (str): Optional. AWS region to use for image URI. Default: AWS region associated
with the SageMaker session.

Returns:
str: The URI of the Docker image.
"""
if self.image_uri:
return self.image_uri
if hasattr(self, "distribution"):
distribution = self.distribution # pylint: disable=no-member
else:
distribution = None
compiler_config = getattr(self, "compiler_config", None)

if hasattr(self, "tensorflow_version") or hasattr(self, "pytorch_version"):
processor = image_uris._processor(self.instance_type, ["cpu", "gpu"])
is_native_huggingface_gpu = processor == "gpu" and not compiler_config
container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None
if self.tensorflow_version is not None: # pylint: disable=no-member
base_framework_version = (
f"tensorflow{self.tensorflow_version}" # pylint: disable=no-member
)
else:
base_framework_version = (
f"pytorch{self.pytorch_version}" # pylint: disable=no-member
)
else:
container_version = None
base_framework_version = None

return image_uris.retrieve(
self._framework_name,
self.sagemaker_session.boto_region_name,
instance_type=self.instance_type,
version=self.framework_version, # pylint: disable=no-member
return image_uris.get_training_image_uri(
region=region or self.sagemaker_session.boto_region_name,
framework=self._framework_name,
framework_version=self.framework_version, # pylint: disable=no-member
py_version=self.py_version, # pylint: disable=no-member
image_scope="training",
distribution=distribution,
base_framework_version=base_framework_version,
container_version=container_version,
training_compiler_config=compiler_config,
image_uri=self.image_uri,
distribution=getattr(self, "distribution", None),
compiler_config=getattr(self, "compiler_config", None),
tensorflow_version=getattr(self, "tensorflow_version", None),
pytorch_version=getattr(self, "pytorch_version", None),
instance_type=self.instance_type,
)

@classmethod
Expand Down
99 changes: 98 additions & 1 deletion src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
import logging
import os
import re
from typing import Optional

from sagemaker import utils
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.spark import defaults
from sagemaker.jumpstart import artifacts


logger = logging.getLogger(__name__)

Expand All @@ -39,7 +43,9 @@ def retrieve(
distribution=None,
base_framework_version=None,
training_compiler_config=None,
):
model_id=None,
model_version=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Ideally this function should not be called directly, rather it should be called from the
Expand Down Expand Up @@ -69,13 +75,39 @@ def retrieve(
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler
(default: None).
model_id (str): JumpStart model ID for which to retrieve image URI
(default: None).
model_version (str): Version of the JumpStart model for which to retrieve the
image URI (default: None).

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
ValueError: If the combination of arguments specified is not supported.
"""
if is_jumpstart_model_input(model_id, model_version):

# adding assert statements to satisfy mypy type checker
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_image_uri(
model_id,
model_version,
image_scope,
framework,
region,
version,
py_version,
instance_type,
accelerator_type,
container_version,
distribution,
base_framework_version,
training_compiler_config,
)

if training_compiler_config is None:
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
elif framework == HUGGING_FACE_FRAMEWORK:
Expand Down Expand Up @@ -347,3 +379,68 @@ def _validate_arg(arg, available_options, arg_name):
def _format_tag(tag_prefix, processor, py_version, container_version):
"""Creates a tag for the image URI."""
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)


def get_training_image_uri(
region,
framework,
framework_version=None,
py_version=None,
image_uri=None,
distribution=None,
compiler_config=None,
tensorflow_version=None,
pytorch_version=None,
instance_type=None,
) -> str:
"""Retrieve image uri for training.

Args:
region (str): AWS region to use for image URI.
framework (str): The framework for which to retrieve an image URI.
framework_version (str): The framework version for which to retrieve an
image URI (default: None).
py_version (str): The python version to use for the image (default: None).
image_uri (str): If an image URI is supplied, it will be returned (default: None).
distribution (dict): A dictionary with information on how to run distributed
training (default: None).
compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler
(default: None).
tensorflow_version (str): Version of tensorflow to use. (default: None)
pytorch_version (str): Version of pytorch to use. (default: None)
instance_type (str): Instance type fo use. (default: None)

Returns:
str: the image URI string.
"""

if image_uri:
return image_uri

base_framework_version: Optional[str] = None

if tensorflow_version is not None or pytorch_version is not None:
processor = _processor(instance_type, ["cpu", "gpu"])
is_native_huggingface_gpu = processor == "gpu" and not compiler_config
container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None
if tensorflow_version is not None:
base_framework_version = f"tensorflow{tensorflow_version}"
else:
base_framework_version = f"pytorch{pytorch_version}"
else:
container_version = None
base_framework_version = None

return retrieve(
framework,
region,
instance_type=instance_type,
version=framework_version,
py_version=py_version,
image_scope="training",
distribution=distribution,
base_framework_version=base_framework_version,
container_version=container_version,
training_compiler_config=compiler_config,
)
151 changes: 151 additions & 0 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains accessors related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import Any, Dict, Optional
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart import cache
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME


class SageMakerSettings(object):
"""Static class for storing the SageMaker settings."""

_parsed_sagemaker_version = ""

@staticmethod
def set_sagemaker_version(version: str) -> None:
"""Set SageMaker version."""
SageMakerSettings._parsed_sagemaker_version = version

@staticmethod
def get_sagemaker_version() -> str:
"""Return SageMaker version."""
return SageMakerSettings._parsed_sagemaker_version


class JumpStartModelsAccessor(object):
"""Static class for storing the JumpStart models cache."""

_cache: Optional[cache.JumpStartModelsCache] = None
_curr_region = JUMPSTART_DEFAULT_REGION_NAME

_cache_kwargs: Dict[str, Any] = {}

@staticmethod
def _validate_and_mutate_region_cache_kwargs(
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
) -> Dict[str, Any]:
"""Returns cache_kwargs with region argument removed if present.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
region (str): The region to validate along with the kwargs.
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
assert isinstance(cache_kwargs_dict, dict)
if region is not None and "region" in cache_kwargs_dict:
if region != cache_kwargs_dict["region"]:
raise ValueError(
f"Inconsistent region definitions: {region}, {cache_kwargs_dict['region']}"
)
del cache_kwargs_dict["region"]
return cache_kwargs_dict

@staticmethod
def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
"""Sets ``JumpStartModelsAccessor._cache`` and ``JumpStartModelsAccessor._curr_region``.

Args:
region (str): region for which to retrieve header/spec.
cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``.
"""
if JumpStartModelsAccessor._cache is None or region != JumpStartModelsAccessor._curr_region:
JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
region=region, **cache_kwargs
)
JumpStartModelsAccessor._curr_region = region

@staticmethod
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
"""Returns model header from JumpStart models cache.

Args:
region (str): region for which to retrieve header.
model_id (str): model id to retrieve.
version (str): semantic version to retrieve for the model id.
"""
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
assert JumpStartModelsAccessor._cache is not None
return JumpStartModelsAccessor._cache.get_header(model_id, version)

@staticmethod
def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
"""Returns model specs from JumpStart models cache.

Args:
region (str): region for which to retrieve header.
model_id (str): model id to retrieve.
version (str): semantic version to retrieve for the model id.
"""
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
assert JumpStartModelsAccessor._cache is not None
return JumpStartModelsAccessor._cache.get_specs(model_id, version)

@staticmethod
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
"""Sets cache kwargs, clears the cache.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (str): cache kwargs to validate.
region (str): Optional. The region to validate along with the kwargs.
"""
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
cache_kwargs, region
)
JumpStartModelsAccessor._cache_kwargs = cache_kwargs
if region is None:
JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
**JumpStartModelsAccessor._cache_kwargs
)
else:
JumpStartModelsAccessor._curr_region = region
JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
region=region, **JumpStartModelsAccessor._cache_kwargs
)

@staticmethod
def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = None) -> None:
"""Resets cache, optionally allowing cache kwargs to be passed to the new cache.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (str): cache kwargs to validate.
region (str): The region to validate along with the kwargs.
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
Loading