Skip to content

feature: Adding Jumpstart retrieval functions #2789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ venv/
*.swp
.docker/
env/
.vscode/
.vscode/
**/tmp
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def read_version():
"fabric>=2.0",
"requests>=2.20.0, <3",
"sagemaker-experiments",
"regex",
],
)

Expand Down
49 changes: 15 additions & 34 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2425,51 +2425,32 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na

return init_params

def training_image_uri(self):
def training_image_uri(self, region=None):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
the model training, calls this method to find the image to use for model
training.

Args:
region: Region to use for image uri.
Default: Region associated with SageMaker session.

Returns:
str: The URI of the Docker image.
"""
if self.image_uri:
return self.image_uri
if hasattr(self, "distribution"):
distribution = self.distribution # pylint: disable=no-member
else:
distribution = None
compiler_config = getattr(self, "compiler_config", None)

if hasattr(self, "tensorflow_version") or hasattr(self, "pytorch_version"):
processor = image_uris._processor(self.instance_type, ["cpu", "gpu"])
is_native_huggingface_gpu = processor == "gpu" and not compiler_config
container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None
if self.tensorflow_version is not None: # pylint: disable=no-member
base_framework_version = (
f"tensorflow{self.tensorflow_version}" # pylint: disable=no-member
)
else:
base_framework_version = (
f"pytorch{self.pytorch_version}" # pylint: disable=no-member
)
else:
container_version = None
base_framework_version = None

return image_uris.retrieve(
self._framework_name,
self.sagemaker_session.boto_region_name,
instance_type=self.instance_type,
version=self.framework_version, # pylint: disable=no-member
return image_uris.get_training_image_uri(
region=region or self.sagemaker_session.boto_region_name,
framework=self._framework_name,
framework_version=self.framework_version, # pylint: disable=no-member
py_version=self.py_version, # pylint: disable=no-member
image_scope="training",
distribution=distribution,
base_framework_version=base_framework_version,
container_version=container_version,
training_compiler_config=compiler_config,
image_uri=self.image_uri,
distribution=getattr(self, "distribution", None),
compiler_config=getattr(self, "compiler_config", None),
tensorflow_version=getattr(self, "tensorflow_version", None),
pytorch_version=getattr(self, "pytorch_version", None),
instance_type=self.instance_type,
)

@classmethod
Expand Down
126 changes: 86 additions & 40 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import logging
import os
import re
from typing import Optional

from sagemaker import utils
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.spark import defaults
from sagemaker.jumpstart import accessors as jumpstart_accessors
from sagemaker.jumpstart import artifacts


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,45 +84,23 @@ def retrieve(
Raises:
ValueError: If the combination of arguments specified is not supported.
"""
if model_id is not None or model_version is not None:
if model_id is None or model_version is None:
raise ValueError(
"Must specify `model_id` and `model_version` when getting image uri for "
"JumpStart models. "
)
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
region, model_id, model_version
)
if image_scope is None:
raise ValueError(
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
)
if image_scope == "inference":
ecr_specs = model_specs.hosting_ecr_specs
elif image_scope == "training":
if not model_specs.training_supported:
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
ecr_specs = model_specs.training_ecr_specs
else:
raise ValueError("JumpStart models only support inference and training.")

if framework is not None and framework != ecr_specs.framework:
raise ValueError(
f"Bad value for container framework for JumpStart model: '{framework}'."
)

return retrieve(
framework=ecr_specs.framework,
region=region,
version=ecr_specs.framework_version,
py_version=ecr_specs.py_version,
instance_type=instance_type,
accelerator_type=accelerator_type,
image_scope=image_scope,
container_version=container_version,
distribution=distribution,
base_framework_version=base_framework_version,
training_compiler_config=training_compiler_config,
if is_jumpstart_model_input(model_id, model_version):
assert model_id is not None
assert model_version is not None
return artifacts._retrieve_image_uri(
model_id,
model_version,
framework,
region,
version,
py_version,
instance_type,
accelerator_type,
image_scope,
container_version,
distribution,
base_framework_version,
training_compiler_config,
)

if training_compiler_config is None:
Expand Down Expand Up @@ -393,3 +374,68 @@ def _validate_arg(arg, available_options, arg_name):
def _format_tag(tag_prefix, processor, py_version, container_version):
"""Creates a tag for the image URI."""
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)


def get_training_image_uri(
    region,
    framework,
    framework_version=None,
    py_version=None,
    image_uri=None,
    distribution=None,
    compiler_config=None,
    tensorflow_version=None,
    pytorch_version=None,
    instance_type=None,
) -> str:
    """Retrieve image uri for training.

    If ``image_uri`` is supplied it is returned unchanged and no lookup is
    performed. Otherwise the arguments are forwarded to :func:`retrieve`
    with ``image_scope="training"``.

    Args:
        region (str): AWS region to use for image URI.
        framework (str): The framework for which to retrieve an image URI.
        framework_version (str): The framework version for which to retrieve an
            image URI (default: None).
        py_version (str): The python version to use for the image (default: None).
        image_uri (str): If an image URI is supplied, it will be returned (default: None).
        distribution (dict): A dictionary with information on how to run distributed
            training (default: None).
        compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        tensorflow_version (str): Version of tensorflow to use. (default: None)
        pytorch_version (str): Version of pytorch to use. (default: None)
        instance_type (str): Instance type to use. (default: None)

    Returns:
        str: the image URI string.
    """
    if image_uri:
        return image_uri

    # Both stay None unless a base framework (tensorflow/pytorch) version is given.
    base_framework_version: Optional[str] = None
    container_version: Optional[str] = None

    if tensorflow_version is not None or pytorch_version is not None:
        processor = _processor(instance_type, ["cpu", "gpu"])
        # A GPU instance without a training compiler config gets a fixed
        # container version; every other combination leaves it unset.
        is_native_huggingface_gpu = processor == "gpu" and not compiler_config
        if is_native_huggingface_gpu:
            container_version = "cu110-ubuntu18.04"

        # tensorflow_version takes precedence when both versions are supplied.
        if tensorflow_version is not None:
            base_framework_version = f"tensorflow{tensorflow_version}"
        else:
            base_framework_version = f"pytorch{pytorch_version}"

    return retrieve(
        framework,
        region,
        instance_type=instance_type,
        version=framework_version,
        py_version=py_version,
        image_scope="training",
        distribution=distribution,
        base_framework_version=base_framework_version,
        container_version=container_version,
        training_compiler_config=compiler_config,
    )
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS

@staticmethod
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
"""Sets cache kwargs. Clears the cache.
"""Sets cache kwargs, clear the cache.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (Dict[str, Any]): cache kwargs to validate.
region (str): The region to validate along with the kwargs.
region (str): Optional. The region to validate along with the kwargs.
"""
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
JumpStartModelsCache._cache_kwargs = cache_kwargs
Expand Down
Loading