Skip to content

Commit b691d3d

Browse files
authored
feature: Adding Jumpstart retrieval functions (aws#2789)
1 parent ac57772 commit b691d3d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+3459
-239
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ venv/
2828
.docker/
2929
env/
3030
.vscode/
31+
**/tmp
3132
.python-version

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def read_version():
4444
"packaging>=20.0",
4545
"pandas",
4646
"pathos",
47-
"semantic-version",
4847
]
4948

5049
# Specific use case dependencies

src/sagemaker/estimator.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2426,51 +2426,32 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
24262426

24272427
return init_params
24282428

2429-
def training_image_uri(self):
2429+
def training_image_uri(self, region=None):
24302430
"""Return the Docker image to use for training.
24312431
24322432
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
24332433
the model training, calls this method to find the image to use for model
24342434
training.
24352435
2436+
Args:
2437+
region (str): Optional. AWS region to use for image URI. Default: AWS region associated
2438+
with the SageMaker session.
2439+
24362440
Returns:
24372441
str: The URI of the Docker image.
24382442
"""
2439-
if self.image_uri:
2440-
return self.image_uri
2441-
if hasattr(self, "distribution"):
2442-
distribution = self.distribution # pylint: disable=no-member
2443-
else:
2444-
distribution = None
2445-
compiler_config = getattr(self, "compiler_config", None)
2446-
2447-
if hasattr(self, "tensorflow_version") or hasattr(self, "pytorch_version"):
2448-
processor = image_uris._processor(self.instance_type, ["cpu", "gpu"])
2449-
is_native_huggingface_gpu = processor == "gpu" and not compiler_config
2450-
container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None
2451-
if self.tensorflow_version is not None: # pylint: disable=no-member
2452-
base_framework_version = (
2453-
f"tensorflow{self.tensorflow_version}" # pylint: disable=no-member
2454-
)
2455-
else:
2456-
base_framework_version = (
2457-
f"pytorch{self.pytorch_version}" # pylint: disable=no-member
2458-
)
2459-
else:
2460-
container_version = None
2461-
base_framework_version = None
24622443

2463-
return image_uris.retrieve(
2464-
self._framework_name,
2465-
self.sagemaker_session.boto_region_name,
2466-
instance_type=self.instance_type,
2467-
version=self.framework_version, # pylint: disable=no-member
2444+
return image_uris.get_training_image_uri(
2445+
region=region or self.sagemaker_session.boto_region_name,
2446+
framework=self._framework_name,
2447+
framework_version=self.framework_version, # pylint: disable=no-member
24682448
py_version=self.py_version, # pylint: disable=no-member
2469-
image_scope="training",
2470-
distribution=distribution,
2471-
base_framework_version=base_framework_version,
2472-
container_version=container_version,
2473-
training_compiler_config=compiler_config,
2449+
image_uri=self.image_uri,
2450+
distribution=getattr(self, "distribution", None),
2451+
compiler_config=getattr(self, "compiler_config", None),
2452+
tensorflow_version=getattr(self, "tensorflow_version", None),
2453+
pytorch_version=getattr(self, "pytorch_version", None),
2454+
instance_type=self.instance_type,
24742455
)
24752456

24762457
@classmethod

src/sagemaker/image_uris.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
import logging
1818
import os
1919
import re
20+
from typing import Optional
2021

2122
from sagemaker import utils
23+
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2224
from sagemaker.spark import defaults
25+
from sagemaker.jumpstart import artifacts
26+
2327

2428
logger = logging.getLogger(__name__)
2529

@@ -39,7 +43,9 @@ def retrieve(
3943
distribution=None,
4044
base_framework_version=None,
4145
training_compiler_config=None,
42-
):
46+
model_id=None,
47+
model_version=None,
48+
) -> str:
4349
"""Retrieves the ECR URI for the Docker image matching the given arguments.
4450
4551
Ideally this function should not be called directly, rather it should be called from the
@@ -69,13 +75,39 @@ def retrieve(
6975
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7076
A configuration class for the SageMaker Training Compiler
7177
(default: None).
78+
model_id (str): JumpStart model ID for which to retrieve image URI
79+
(default: None).
80+
model_version (str): Version of the JumpStart model for which to retrieve the
81+
image URI (default: None).
7282
7383
Returns:
7484
str: the ECR URI for the corresponding SageMaker Docker image.
7585
7686
Raises:
7787
ValueError: If the combination of arguments specified is not supported.
7888
"""
89+
if is_jumpstart_model_input(model_id, model_version):
90+
91+
# adding assert statements to satisfy mypy type checker
92+
assert model_id is not None
93+
assert model_version is not None
94+
95+
return artifacts._retrieve_image_uri(
96+
model_id,
97+
model_version,
98+
image_scope,
99+
framework,
100+
region,
101+
version,
102+
py_version,
103+
instance_type,
104+
accelerator_type,
105+
container_version,
106+
distribution,
107+
base_framework_version,
108+
training_compiler_config,
109+
)
110+
79111
if training_compiler_config is None:
80112
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
81113
elif framework == HUGGING_FACE_FRAMEWORK:
@@ -347,3 +379,68 @@ def _validate_arg(arg, available_options, arg_name):
347379
def _format_tag(tag_prefix, processor, py_version, container_version):
348380
"""Creates a tag for the image URI."""
349381
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
382+
383+
384+
def get_training_image_uri(
    region,
    framework,
    framework_version=None,
    py_version=None,
    image_uri=None,
    distribution=None,
    compiler_config=None,
    tensorflow_version=None,
    pytorch_version=None,
    instance_type=None,
) -> str:
    """Resolve the Docker image URI to use for a training job.

    A truthy ``image_uri`` is returned untouched. Otherwise the arguments are
    forwarded to :func:`retrieve` with ``image_scope="training"``; when a
    ``tensorflow_version`` or ``pytorch_version`` is given, a base framework
    version (and, on GPU without a compiler config, a pinned container
    version) is derived first.

    Args:
        region (str): AWS region to use for image URI.
        framework (str): The framework for which to retrieve an image URI.
        framework_version (str): The framework version for which to retrieve an
            image URI (default: None).
        py_version (str): The python version to use for the image (default: None).
        image_uri (str): If an image URI is supplied, it will be returned (default: None).
        distribution (dict): A dictionary with information on how to run distributed
            training (default: None).
        compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        tensorflow_version (str): Version of tensorflow to use. (default: None)
        pytorch_version (str): Version of pytorch to use. (default: None)
        instance_type (str): Instance type to use. (default: None)

    Returns:
        str: the image URI string.
    """
    # An explicitly supplied image always wins; no other argument is inspected.
    if image_uri:
        return image_uri

    container_version = None
    base_framework_version: Optional[str] = None

    has_base_framework = not (tensorflow_version is None and pytorch_version is None)
    if has_base_framework:
        # The processor lookup only happens on this path, exactly as before.
        processor = _processor(instance_type, ["cpu", "gpu"])
        if processor == "gpu" and not compiler_config:
            # GPU image without a training compiler config: pin the container version.
            container_version = "cu110-ubuntu18.04"

        # tensorflow_version takes precedence when both versions are supplied.
        if tensorflow_version is None:
            base_framework_version = f"pytorch{pytorch_version}"
        else:
            base_framework_version = f"tensorflow{tensorflow_version}"

    return retrieve(
        framework,
        region,
        instance_type=instance_type,
        version=framework_version,
        py_version=py_version,
        image_scope="training",
        distribution=distribution,
        base_framework_version=base_framework_version,
        container_version=container_version,
        training_compiler_config=compiler_config,
    )

src/sagemaker/jumpstart/accessors.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains accessors related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
from typing import Any, Dict, Optional
16+
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
17+
from sagemaker.jumpstart import cache
18+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
19+
20+
21+
class SageMakerSettings(object):
    """Static class holding the SageMaker version string.

    The value is kept in a class attribute, so it is shared by every caller
    that goes through this class.
    """

    # Most recently stored version; the empty string is the initial value.
    _parsed_sagemaker_version: str = ""

    @staticmethod
    def get_sagemaker_version() -> str:
        """Return the version string last stored with ``set_sagemaker_version``."""
        return SageMakerSettings._parsed_sagemaker_version

    @staticmethod
    def set_sagemaker_version(version: str) -> None:
        """Store the SageMaker version on the class.

        Args:
            version (str): version string to remember for later lookups.
        """
        SageMakerSettings._parsed_sagemaker_version = version
35+
36+
37+
class JumpStartModelsAccessor(object):
    """Static class for storing the JumpStart models cache.

    All state lives in class attributes, so the cache object, the region it was
    built for, and the kwargs used to build it are shared by every caller.
    """

    # Lazily created cache; stays ``None`` until a header/spec is requested or
    # the cache kwargs are set.
    _cache: Optional[cache.JumpStartModelsCache] = None
    # Region that ``_cache`` serves; compared against each requested region to
    # decide whether the cache has to be rebuilt.
    _curr_region = JUMPSTART_DEFAULT_REGION_NAME

    # Keyword arguments forwarded to ``JumpStartModelsCache`` when it is built.
    _cache_kwargs: Dict[str, Any] = {}

    @staticmethod
    def _validate_and_mutate_region_cache_kwargs(
        cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
    ) -> Dict[str, Any]:
        """Returns cache_kwargs with region argument removed if present.

        The ``"region"`` key is only checked and removed when ``region`` is not
        ``None``. The removal happens in place on the dictionary that was passed
        in, and that same dictionary object is returned.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

        Args:
            cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
            region (Optional[str]): The region to validate along with the kwargs.

        Returns:
            Dict[str, Any]: the validated kwargs (an empty dict if ``cache_kwargs`` is None).
        """
        cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
        assert isinstance(cache_kwargs_dict, dict)
        if region is not None and "region" in cache_kwargs_dict:
            if region != cache_kwargs_dict["region"]:
                raise ValueError(
                    f"Inconsistent region definitions: {region}, {cache_kwargs_dict['region']}"
                )
            # The region is passed to the cache explicitly by the callers, so it
            # must not also appear in the forwarded kwargs.
            del cache_kwargs_dict["region"]
        return cache_kwargs_dict

    @staticmethod
    def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
        """Sets ``JumpStartModelsAccessor._cache`` and ``JumpStartModelsAccessor._curr_region``.

        The cache is only rebuilt when none exists yet or when ``region`` differs
        from the region the current cache was built for.

        Args:
            region (str): region for which to retrieve header/spec.
            cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``.
        """
        if JumpStartModelsAccessor._cache is None or region != JumpStartModelsAccessor._curr_region:
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                region=region, **cache_kwargs
            )
            JumpStartModelsAccessor._curr_region = region

    @staticmethod
    def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
        """Returns model header from JumpStart models cache.

        Raises:
            ValueError: If the stored cache kwargs contain a region that is
                inconsistent with `region`.

        Args:
            region (str): region for which to retrieve header.
            model_id (str): model id to retrieve.
            version (str): semantic version to retrieve for the model id.
        """
        cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            JumpStartModelsAccessor._cache_kwargs, region
        )
        JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
        # Narrowing for the type checker; the call above always leaves a cache in place.
        assert JumpStartModelsAccessor._cache is not None
        return JumpStartModelsAccessor._cache.get_header(model_id, version)

    @staticmethod
    def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
        """Returns model specs from JumpStart models cache.

        Raises:
            ValueError: If the stored cache kwargs contain a region that is
                inconsistent with `region`.

        Args:
            region (str): region for which to retrieve specs.
            model_id (str): model id to retrieve.
            version (str): semantic version to retrieve for the model id.
        """
        cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            JumpStartModelsAccessor._cache_kwargs, region
        )
        JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
        # Narrowing for the type checker; the call above always leaves a cache in place.
        assert JumpStartModelsAccessor._cache is not None
        return JumpStartModelsAccessor._cache.get_specs(model_id, version)

    @staticmethod
    def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: Optional[str] = None) -> None:
        """Sets cache kwargs, clears the cache.

        A new ``JumpStartModelsCache`` is built immediately and the region it
        serves is recorded, so a later lookup for a different region rebuilds it.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

        Args:
            cache_kwargs (Dict[str, Any]): cache kwargs to validate.
            region (Optional[str]): Optional. The region to validate along with the kwargs.
        """
        cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            cache_kwargs, region
        )
        JumpStartModelsAccessor._cache_kwargs = cache_kwargs
        if region is None:
            # Keep ``_curr_region`` in step with the cache built below. Leaving the
            # previous value here would let a later lookup for that stale region
            # reuse a cache that was built for a different one.
            # NOTE(review): assumes ``JumpStartModelsCache`` falls back to
            # JUMPSTART_DEFAULT_REGION_NAME when no region is given, which is the
            # same assumption the initial ``_curr_region`` value makes -- confirm.
            JumpStartModelsAccessor._curr_region = cache_kwargs.get(
                "region", JUMPSTART_DEFAULT_REGION_NAME
            )
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                **JumpStartModelsAccessor._cache_kwargs
            )
        else:
            JumpStartModelsAccessor._curr_region = region
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                region=region, **JumpStartModelsAccessor._cache_kwargs
            )

    @staticmethod
    def reset_cache(
        cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
    ) -> None:
        """Resets cache, optionally allowing cache kwargs to be passed to the new cache.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

        Args:
            cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
                ``None`` is treated as an empty dict.
            region (Optional[str]): The region to validate along with the kwargs.
        """
        cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
        JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)

0 commit comments

Comments
 (0)