Skip to content

Commit a427d4b

Browse files
committed
change: improve jumpstart retrieve uri unit tests, fix logic for image uris
1 parent ed7b772 commit a427d4b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1954
-201
lines changed

src/sagemaker/estimator.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,51 +2425,32 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
24252425

24262426
return init_params
24272427

2428-
def training_image_uri(self):
2428+
def training_image_uri(self, region=None):
24292429
"""Return the Docker image to use for training.
24302430
24312431
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
24322432
the model training, calls this method to find the image to use for model
24332433
training.
24342434
2435+
Args:
2436+
region: Region to use for image uri.
2437+
Default: Region associated with SageMaker session.
2438+
24352439
Returns:
24362440
str: The URI of the Docker image.
24372441
"""
2438-
if self.image_uri:
2439-
return self.image_uri
2440-
if hasattr(self, "distribution"):
2441-
distribution = self.distribution # pylint: disable=no-member
2442-
else:
2443-
distribution = None
2444-
compiler_config = getattr(self, "compiler_config", None)
2445-
2446-
if hasattr(self, "tensorflow_version") or hasattr(self, "pytorch_version"):
2447-
processor = image_uris._processor(self.instance_type, ["cpu", "gpu"])
2448-
is_native_huggingface_gpu = processor == "gpu" and not compiler_config
2449-
container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None
2450-
if self.tensorflow_version is not None: # pylint: disable=no-member
2451-
base_framework_version = (
2452-
f"tensorflow{self.tensorflow_version}" # pylint: disable=no-member
2453-
)
2454-
else:
2455-
base_framework_version = (
2456-
f"pytorch{self.pytorch_version}" # pylint: disable=no-member
2457-
)
2458-
else:
2459-
container_version = None
2460-
base_framework_version = None
24612442

2462-
return image_uris.retrieve(
2463-
self._framework_name,
2464-
self.sagemaker_session.boto_region_name,
2465-
instance_type=self.instance_type,
2466-
version=self.framework_version, # pylint: disable=no-member
2443+
return image_uris.get_training_image_uri(
2444+
region=region or self.sagemaker_session.boto_region_name,
2445+
framework=self._framework_name,
2446+
framework_version=self.framework_version, # pylint: disable=no-member
24672447
py_version=self.py_version, # pylint: disable=no-member
2468-
image_scope="training",
2469-
distribution=distribution,
2470-
base_framework_version=base_framework_version,
2471-
container_version=container_version,
2472-
training_compiler_config=compiler_config,
2448+
image_uri=self.image_uri,
2449+
distribution=getattr(self, "distribution", None),
2450+
compiler_config=getattr(self, "compiler_config", None),
2451+
tensorflow_version=getattr(self, "tensorflow_version", None),
2452+
pytorch_version=getattr(self, "pytorch_version", None),
2453+
instance_type=self.instance_type,
24732454
)
24742455

24752456
@classmethod

src/sagemaker/image_uris.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
import os
1919
import re
20+
from typing import Optional
2021

2122
from sagemaker import utils
2223
from sagemaker.jumpstart.utils import is_jumpstart_model_input
@@ -373,3 +374,68 @@ def _validate_arg(arg, available_options, arg_name):
373374
def _format_tag(tag_prefix, processor, py_version, container_version):
374375
"""Creates a tag for the image URI."""
375376
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
377+
378+
379+
def get_training_image_uri(
380+
region,
381+
framework,
382+
framework_version=None,
383+
py_version=None,
384+
image_uri=None,
385+
distribution=None,
386+
compiler_config=None,
387+
tensorflow_version=None,
388+
pytorch_version=None,
389+
instance_type=None,
390+
) -> str:
391+
"""Retrieve image uri for training.
392+
393+
Args:
394+
region (str): AWS region to use for image URI.
395+
framework (str): The framework for which to retrieve an image URI.
396+
framework_version (str): The framework version for which to retrieve an
397+
image URI (default: None).
398+
py_version (str): The python version to use for the image (default: None).
399+
image_uri (str): If an image URI is supplied, it will be returned (default: None).
400+
distribution (dict): A dictionary with information on how to run distributed
401+
training (default: None).
402+
compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
403+
A configuration class for the SageMaker Training Compiler
404+
(default: None).
405+
tensorflow_version (str): Version of tensorflow to use. (default: None)
406+
pytorch_version (str): Version of pytorch to use. (default: None)
407+
instance_type (str): Instance type to use. (default: None)
408+
409+
Returns:
410+
str: the image URI string.
411+
"""
412+
413+
if image_uri:
414+
return image_uri
415+
416+
base_framework_version: Optional[str] = None
417+
418+
if tensorflow_version is not None or pytorch_version is not None:
419+
processor = _processor(instance_type, ["cpu", "gpu"])
420+
is_native_huggingface_gpu = processor == "gpu" and not compiler_config
421+
container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None
422+
if tensorflow_version is not None:
423+
base_framework_version = f"tensorflow{tensorflow_version}"
424+
else:
425+
base_framework_version = f"pytorch{pytorch_version}"
426+
else:
427+
container_version = None
428+
base_framework_version = None
429+
430+
return retrieve(
431+
framework,
432+
region,
433+
instance_type=instance_type,
434+
version=framework_version,
435+
py_version=py_version,
436+
image_scope="training",
437+
distribution=distribution,
438+
base_framework_version=base_framework_version,
439+
container_version=container_version,
440+
training_compiler_config=compiler_config,
441+
)

src/sagemaker/jumpstart/artifacts.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
INFERENCE,
2020
TRAINING,
2121
SUPPORTED_JUMPSTART_SCOPES,
22+
ModelFramework,
2223
)
2324
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2425
from sagemaker.jumpstart import accessors as jumpstart_accessors
@@ -115,20 +116,39 @@ def _retrieve_image_uri(
115116
f"Bad value for container python version for JumpStart model: '{py_version}'."
116117
)
117118

118-
if framework == "huggingface":
119-
base_framework_version = ecr_specs.framework_version
119+
base_framework_version_override = None
120+
version_override = None
121+
if ecr_specs.framework == ModelFramework.HUGGINGFACE.value:
122+
base_framework_version_override = ecr_specs.framework_version
123+
version_override = ecr_specs.huggingface_transformers_version
124+
125+
if image_scope == TRAINING:
126+
return image_uris.get_training_image_uri(
127+
region=region,
128+
framework=ecr_specs.framework,
129+
framework_version=version_override or ecr_specs.framework_version,
130+
py_version=ecr_specs.py_version,
131+
image_uri=None,
132+
distribution=None,
133+
compiler_config=None,
134+
tensorflow_version=None,
135+
pytorch_version=base_framework_version_override or base_framework_version,
136+
instance_type=instance_type,
137+
)
138+
if base_framework_version_override is not None:
139+
base_framework_version_override = f"pytorch{base_framework_version_override}"
120140

121141
return image_uris.retrieve(
122142
framework=ecr_specs.framework,
123143
region=region,
124-
version=ecr_specs.framework_version,
144+
version=version_override or ecr_specs.framework_version,
125145
py_version=ecr_specs.py_version,
126146
instance_type=instance_type,
127147
accelerator_type=accelerator_type,
128148
image_scope=image_scope,
129149
container_version=container_version,
130150
distribution=distribution,
131-
base_framework_version=base_framework_version,
151+
base_framework_version=base_framework_version_override or base_framework_version,
132152
training_compiler_config=training_compiler_config,
133153
)
134154

src/sagemaker/jumpstart/constants.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module stores constants related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
from typing import Set
16+
from enum import Enum
1617
import boto3
1718
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
1819

@@ -118,3 +119,20 @@
118119
INFERENCE = "inference"
119120
TRAINING = "training"
120121
SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING])
122+
123+
124+
class ModelFramework(str, Enum):
125+
"""Enum class for JumpStart model framework.
126+
127+
The ML framework as referenced in the prefix of the model ID.
128+
This value does not necessarily correspond to the container name.
129+
"""
130+
131+
PYTORCH = "pytorch"
132+
TENSORFLOW = "tensorflow"
133+
MXNET = "mxnet"
134+
HUGGINGFACE = "huggingface"
135+
LIGHTGBM = "lightgbm"
136+
CATBOOST = "catboost"
137+
XGBOOST = "xgboost"
138+
SKLEARN = "sklearn"

src/sagemaker/jumpstart/types.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,13 @@ def __eq__(self, other: Any) -> bool:
4141
if self.__slots__ != other.__slots__:
4242
return False
4343
for attribute in self.__slots__:
44-
if getattr(self, attribute) != getattr(other, attribute):
44+
if (hasattr(self, attribute) and not hasattr(other, attribute)) or (
45+
hasattr(other, attribute) and not hasattr(self, attribute)
46+
):
4547
return False
48+
if hasattr(self, attribute) and hasattr(other, attribute):
49+
if getattr(self, attribute) != getattr(other, attribute):
50+
return False
4651
return True
4752

4853
def __hash__(self) -> int:
@@ -112,7 +117,7 @@ def __init__(self, header: Dict[str, str]):
112117

113118
def to_json(self) -> Dict[str, str]:
114119
"""Returns json representation of JumpStartModelHeader object."""
115-
json_obj = {att: getattr(self, att) for att in self.__slots__}
120+
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
116121
return json_obj
117122

118123
def from_json(self, json_obj: Dict[str, str]) -> None:
@@ -134,6 +139,7 @@ class JumpStartECRSpecs(JumpStartDataHolderType):
134139
"framework",
135140
"framework_version",
136141
"py_version",
142+
"huggingface_transformers_version",
137143
}
138144

139145
def __init__(self, spec: Dict[str, Any]):
@@ -154,10 +160,13 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
154160
self.framework = json_obj["framework"]
155161
self.framework_version = json_obj["framework_version"]
156162
self.py_version = json_obj["py_version"]
163+
huggingface_transformers_version = json_obj.get("huggingface_transformers_version")
164+
if huggingface_transformers_version is not None:
165+
self.huggingface_transformers_version = huggingface_transformers_version
157166

158167
def to_json(self) -> Dict[str, Any]:
159168
"""Returns json representation of JumpStartECRSpecs object."""
160-
json_obj = {att: getattr(self, att) for att in self.__slots__}
169+
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
161170
return json_obj
162171

163172

@@ -202,26 +211,23 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
202211
self.hosting_script_key: str = json_obj["hosting_script_key"]
203212
self.training_supported: bool = bool(json_obj["training_supported"])
204213
if self.training_supported:
205-
self.training_ecr_specs: Optional[JumpStartECRSpecs] = JumpStartECRSpecs(
214+
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
206215
json_obj["training_ecr_specs"]
207216
)
208-
self.training_artifact_key: Optional[str] = json_obj["training_artifact_key"]
209-
self.training_script_key: Optional[str] = json_obj["training_script_key"]
210-
self.hyperparameters: Optional[Dict[str, Any]] = json_obj.get("hyperparameters")
211-
else:
212-
self.training_ecr_specs = (
213-
self.training_artifact_key
214-
) = self.training_script_key = self.hyperparameters = None
217+
self.training_artifact_key: str = json_obj["training_artifact_key"]
218+
self.training_script_key: str = json_obj["training_script_key"]
219+
self.hyperparameters: Dict[str, Any] = json_obj.get("hyperparameters", {})
215220

216221
def to_json(self) -> Dict[str, Any]:
217222
"""Returns json representation of JumpStartModelSpecs object."""
218223
json_obj = {}
219224
for att in self.__slots__:
220-
cur_val = getattr(self, att)
221-
if isinstance(cur_val, JumpStartECRSpecs):
222-
json_obj[att] = cur_val.to_json()
223-
else:
224-
json_obj[att] = cur_val
225+
if hasattr(self, att):
226+
cur_val = getattr(self, att)
227+
if isinstance(cur_val, JumpStartECRSpecs):
228+
json_obj[att] = cur_val.to_json()
229+
else:
230+
json_obj[att] = cur_val
225231
return json_obj
226232

227233

src/sagemaker/script_uris.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ def retrieve(
4343
ValueError: If the combination of arguments specified is not supported.
4444
"""
4545
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
46-
raise ValueError(
47-
"Must specify `model_id` and `model_version` when retrieving script URIs."
48-
)
46+
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
4947

5048
assert model_id is not None
5149
assert model_version is not None

src/sagemaker/sklearn/estimator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
source_dir=None,
4444
hyperparameters=None,
4545
image_uri=None,
46+
image_uri_region=None,
4647
**kwargs
4748
):
4849
"""Creates a SKLearn Estimator for Scikit-learn environment.
@@ -99,6 +100,9 @@ def __init__(
99100
If ``framework_version`` or ``py_version`` are ``None``, then
100101
``image_uri`` is required. If also ``None``, then a ``ValueError``
101102
will be raised.
103+
image_uri_region (str): If ``image_uri`` argument is None, the image uri
104+
associated with this object will be in this region.
105+
Default: region associated with SageMaker session.
102106
**kwargs: Additional kwargs passed to the
103107
:class:`~sagemaker.estimator.Framework` constructor.
104108
@@ -144,7 +148,7 @@ def __init__(
144148
if image_uri is None:
145149
self.image_uri = image_uris.retrieve(
146150
SKLearn._framework_name,
147-
self.sagemaker_session.boto_region_name,
151+
image_uri_region or self.sagemaker_session.boto_region_name,
148152
version=self.framework_version,
149153
py_version=self.py_version,
150154
instance_type=instance_type,

src/sagemaker/tensorflow/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,14 @@ def _get_container_env(self):
352352
env[self.LOG_LEVEL_PARAM_NAME] = self.LOG_LEVEL_MAP[self._container_log_level]
353353
return env
354354

355-
def _get_image_uri(self, instance_type, accelerator_type=None):
355+
def _get_image_uri(self, instance_type, accelerator_type=None, region_name=None):
356356
"""Placeholder docstring."""
357357
if self.image_uri:
358358
return self.image_uri
359359

360360
return image_uris.retrieve(
361361
self._framework_name,
362-
self.sagemaker_session.boto_region_name,
362+
region_name or self.sagemaker_session.boto_region_name,
363363
version=self.framework_version,
364364
instance_type=instance_type,
365365
accelerator_type=accelerator_type,
@@ -383,4 +383,6 @@ def serving_image_uri(
383383
str: The appropriate image URI based on the given parameters.
384384
385385
"""
386-
return self._get_image_uri(instance_type=instance_type, accelerator_type=accelerator_type)
386+
return self._get_image_uri(
387+
instance_type=instance_type, accelerator_type=accelerator_type, region_name=region_name
388+
)

src/sagemaker/xgboost/estimator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
hyperparameters=None,
4949
py_version="py3",
5050
image_uri=None,
51+
image_uri_region=None,
5152
**kwargs
5253
):
5354
"""An estimator that executes an XGBoost-based SageMaker Training Job.
@@ -89,6 +90,9 @@ def __init__(
8990
Examples:
9091
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
9192
custom-image:latest.
93+
image_uri_region (str): If ``image_uri`` argument is None, the image uri
94+
associated with this object will be in this region.
95+
Default: region associated with SageMaker session.
9296
**kwargs: Additional kwargs passed to the
9397
:class:`~sagemaker.estimator.Framework` constructor.
9498
@@ -114,7 +118,7 @@ def __init__(
114118
if image_uri is None:
115119
self.image_uri = image_uris.retrieve(
116120
self._framework_name,
117-
self.sagemaker_session.boto_region_name,
121+
image_uri_region or self.sagemaker_session.boto_region_name,
118122
version=framework_version,
119123
py_version=self.py_version,
120124
instance_type=instance_type,

tests/unit/sagemaker/image_uris/jumpstart/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)