Skip to content

Commit ed7b772

Browse files
committed
change: improve jumpstart retrieve fx impl, cleanup tests, comments, and code
1 parent f85b7f0 commit ed7b772

File tree

13 files changed

+417
-141
lines changed

13 files changed

+417
-141
lines changed

src/sagemaker/image_uris.py

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import re
2020

2121
from sagemaker import utils
22+
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2223
from sagemaker.spark import defaults
23-
from sagemaker.jumpstart import accessors as jumpstart_accessors
24+
from sagemaker.jumpstart import artifacts
25+
2426

2527
logger = logging.getLogger(__name__)
2628

@@ -81,45 +83,23 @@ def retrieve(
8183
Raises:
8284
ValueError: If the combination of arguments specified is not supported.
8385
"""
84-
if model_id is not None or model_version is not None:
85-
if model_id is None or model_version is None:
86-
raise ValueError(
87-
"Must specify `model_id` and `model_version` when getting image uri for "
88-
"JumpStart models. "
89-
)
90-
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
91-
region, model_id, model_version
92-
)
93-
if image_scope is None:
94-
raise ValueError(
95-
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
96-
)
97-
if image_scope == "inference":
98-
ecr_specs = model_specs.hosting_ecr_specs
99-
elif image_scope == "training":
100-
if not model_specs.training_supported:
101-
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
102-
ecr_specs = model_specs.training_ecr_specs
103-
else:
104-
raise ValueError("JumpStart models only support inference and training.")
105-
106-
if framework is not None and framework != ecr_specs.framework:
107-
raise ValueError(
108-
f"Bad value for container framework for JumpStart model: '{framework}'."
109-
)
110-
111-
return retrieve(
112-
framework=ecr_specs.framework,
113-
region=region,
114-
version=ecr_specs.framework_version,
115-
py_version=ecr_specs.py_version,
116-
instance_type=instance_type,
117-
accelerator_type=accelerator_type,
118-
image_scope=image_scope,
119-
container_version=container_version,
120-
distribution=distribution,
121-
base_framework_version=base_framework_version,
122-
training_compiler_config=training_compiler_config,
86+
if is_jumpstart_model_input(model_id, model_version):
87+
assert model_id is not None
88+
assert model_version is not None
89+
return artifacts._retrieve_image_uri(
90+
model_id,
91+
model_version,
92+
framework,
93+
region,
94+
version,
95+
py_version,
96+
instance_type,
97+
accelerator_type,
98+
image_scope,
99+
container_version,
100+
distribution,
101+
base_framework_version,
102+
training_compiler_config,
123103
)
124104

125105
if training_compiler_config is None:

src/sagemaker/jumpstart/accessors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,14 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
103103

104104
@staticmethod
105105
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
106-
"""Sets cache kwargs. Clears the cache.
106+
"""Sets cache kwargs, clear the cache.
107107
108108
Raises:
109109
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
110110
111111
Args:
112112
cache_kwargs (str): cache kwargs to validate.
113-
region (str): The region to validate along with the kwargs.
113+
region (str): Optional. The region to validate along with the kwargs.
114114
"""
115115
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
116116
JumpStartModelsCache._cache_kwargs = cache_kwargs

src/sagemaker/jumpstart/artifacts.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains functions for obtaining JumpStart artifacts."""
14+
from __future__ import absolute_import
15+
from typing import Optional
16+
from sagemaker import image_uris
17+
from sagemaker.jumpstart.constants import (
18+
JUMPSTART_DEFAULT_REGION_NAME,
19+
INFERENCE,
20+
TRAINING,
21+
SUPPORTED_JUMPSTART_SCOPES,
22+
)
23+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
24+
from sagemaker.jumpstart import accessors as jumpstart_accessors
25+
26+
27+
def _retrieve_image_uri(
    model_id: str,
    model_version: str,
    framework: Optional[str],
    region: Optional[str],
    version: Optional[str],
    py_version: Optional[str],
    instance_type: Optional[str],
    accelerator_type: Optional[str],
    image_scope: Optional[str],
    container_version: Optional[str],
    distribution: Optional[str],
    base_framework_version: Optional[str],
    training_compiler_config: Optional[str],
):
    """Retrieves the container image URI for JumpStart models.

    Only `model_id`, `model_version` and `image_scope` must be supplied.
    The framework, framework version and Python version are taken from the
    model specs; when the caller passes them anyway they are only used to
    check consistency with those specs.

    Args:
        model_id (str): JumpStart model id for which to retrieve image URI.
        model_version (str): JumpStart model version for which to retrieve image URI.
        framework (str): Optional. The name of the framework or algorithm. If given,
            it must match the framework recorded in the model specs.
        region (str): Optional. The AWS region. Falls back to
            ``JUMPSTART_DEFAULT_REGION_NAME`` when None.
        version (str): Optional. The framework or algorithm version. If given,
            it must match the framework version recorded in the model specs.
        py_version (str): Optional. The Python version. If given, it must match
            the Python version recorded in the model specs.
        instance_type (str): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing/instance-types. Forwarded
            unchanged to ``image_uris.retrieve``.
        accelerator_type (str): Elastic Inference accelerator type. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. Forwarded
            unchanged to ``image_uris.retrieve``.
        image_scope (str): The image type, i.e. what it is used for.
            Valid values: "training" and "inference".
        container_version (str): the version of docker image.
            Ideally the value of parameter should be created inside the framework.
            For custom use, see the list of supported container versions:
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None).
        distribution (dict): A dictionary with information on how to run distributed training
        base_framework_version (str): Optional. Forwarded to ``image_uris.retrieve``.
            For "huggingface" models it is replaced by the framework version
            recorded in the model specs.
        training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).

    Returns:
        str: the ECR URI for the corresponding SageMaker Docker image.

    Raises:
        ValueError: If `image_scope` is missing or unsupported, if training is
            requested for a model that does not support it, or if `framework`,
            `version` or `py_version` disagree with the model specs.
    """
    if region is None:
        region = JUMPSTART_DEFAULT_REGION_NAME

    # Narrows Optional[str] to str for type checkers; the default may itself be None.
    assert region is not None

    if image_scope is None:
        raise ValueError(
            "Must specify `image_scope` argument to retrieve image uri for JumpStart models."
        )
    if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
        raise ValueError("JumpStart models only support inference and training.")

    model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
        region, model_id, model_version
    )

    if image_scope == INFERENCE:
        ecr_specs = model_specs.hosting_ecr_specs
    elif image_scope == TRAINING:
        if not model_specs.training_supported:
            raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
        assert model_specs.training_ecr_specs is not None
        ecr_specs = model_specs.training_ecr_specs

    if framework is not None and framework != ecr_specs.framework:
        raise ValueError(f"Bad value for container framework for JumpStart model: '{framework}'.")

    if version is not None and version != ecr_specs.framework_version:
        raise ValueError(
            f"Bad value for container framework version for JumpStart model: '{version}'."
        )

    if py_version is not None and py_version != ecr_specs.py_version:
        raise ValueError(
            f"Bad value for container python version for JumpStart model: '{py_version}'."
        )

    # Decide on the model's own framework, not the caller-supplied `framework`:
    # the latter is optional and is usually None when the caller relies on
    # auto-population, which would silently skip this override.
    if ecr_specs.framework == "huggingface":
        base_framework_version = ecr_specs.framework_version

    return image_uris.retrieve(
        framework=ecr_specs.framework,
        region=region,
        version=ecr_specs.framework_version,
        py_version=ecr_specs.py_version,
        instance_type=instance_type,
        accelerator_type=accelerator_type,
        image_scope=image_scope,
        container_version=container_version,
        distribution=distribution,
        base_framework_version=base_framework_version,
        training_compiler_config=training_compiler_config,
    )
134+
135+
136+
def _retrieve_model_uri(
    model_id: str,
    model_version: str,
    model_scope: Optional[str],
    region: Optional[str],
):
    """Builds the S3 URI of the model artifact for a JumpStart model.

    Args:
        model_id (str): JumpStart model id whose artifact URI is wanted.
        model_version (str): JumpStart model version whose artifact URI is wanted.
        model_scope (str): What the artifact is used for.
            Valid values: "training" and "inference".
        region (str): Optional. Region to look the model up in. Falls back to
            ``JUMPSTART_DEFAULT_REGION_NAME`` when None.

    Returns:
        str: ``s3://<content bucket>/<artifact key>`` for the requested scope.

    Raises:
        ValueError: If `model_scope` is missing or unsupported, or if training
            is requested for a model that does not support it.
    """
    region = JUMPSTART_DEFAULT_REGION_NAME if region is None else region

    # Narrows Optional[str] to str for type checkers.
    assert region is not None

    if model_scope is None:
        raise ValueError(
            "Must specify `model_scope` argument to retrieve model "
            "artifact uri for JumpStart models."
        )

    if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
        raise ValueError("JumpStart models only support inference and training.")

    specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
        region, model_id, model_version
    )

    if model_scope == INFERENCE:
        artifact_key = specs.hosting_artifact_key
    elif model_scope == TRAINING:
        if not specs.training_supported:
            raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
        artifact_key = specs.training_artifact_key
        assert artifact_key is not None

    content_bucket = get_jumpstart_content_bucket(region)
    return f"s3://{content_bucket}/{artifact_key}"
186+
187+
188+
def _retrieve_script_uri(
    model_id: str,
    model_version: str,
    script_scope: Optional[str],
    region: Optional[str],
):
    """Builds the S3 URI of the model script for a JumpStart model.

    Args:
        model_id (str): JumpStart model id whose script URI is wanted.
        model_version (str): JumpStart model version whose script URI is wanted.
        script_scope (str): What the script is used for.
            Valid values: "training" and "inference".
        region (str): Optional. Region to look the model up in. Falls back to
            ``JUMPSTART_DEFAULT_REGION_NAME`` when None.

    Returns:
        str: ``s3://<content bucket>/<script key>`` for the requested scope.

    Raises:
        ValueError: If `script_scope` is missing or unsupported, or if training
            is requested for a model that does not support it.
    """
    region = JUMPSTART_DEFAULT_REGION_NAME if region is None else region

    # Narrows Optional[str] to str for type checkers.
    assert region is not None

    if script_scope is None:
        raise ValueError(
            "Must specify `script_scope` argument to retrieve model script uri for "
            "JumpStart models."
        )

    if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
        raise ValueError("JumpStart models only support inference and training.")

    specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
        region, model_id, model_version
    )

    if script_scope == INFERENCE:
        script_key = specs.hosting_script_key
    elif script_scope == TRAINING:
        if not specs.training_supported:
            raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
        script_key = specs.training_script_key
        assert script_key is not None

    content_bucket = get_jumpstart_content_bucket(region)
    return f"s3://{content_bucket}/{script_key}"

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def _select_version(
282282
semantic_version_str: str,
283283
available_versions: List[Version],
284284
) -> Optional[str]:
285-
"""Utility to select appropriate version from available versions.
285+
"""Perform semantic version search on available versions.
286286
287287
Args:
288288
semantic_version_str (str): the semantic version for which to filter

src/sagemaker/jumpstart/constants.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,6 @@
103103
region_name="cn-north-1",
104104
content_bucket="jumpstart-cache-prod-cn-north-1",
105105
),
106-
JumpStartLaunchedRegionInfo(
107-
region_name="cn-northwest-1",
108-
content_bucket="jumpstart-cache-prod-cn-northwest-1",
109-
),
110106
]
111107
)
112108

@@ -118,3 +114,7 @@
118114
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name
119115

120116
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
117+
118+
# Scopes for which JumpStart publishes artifacts (images, model data, scripts).
INFERENCE = "inference"
TRAINING = "training"
# Set literal instead of set([...]): same value, no throwaway list.
SUPPORTED_JUMPSTART_SCOPES = {INFERENCE, TRAINING}

src/sagemaker/jumpstart/utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15-
from typing import Dict, List
15+
from typing import Dict, List, Optional
1616
from packaging.version import Version
1717
import sagemaker
1818
from sagemaker.jumpstart import constants
@@ -113,3 +113,26 @@ def parse_sagemaker_version() -> str:
113113
Version(parsed_version)
114114

115115
return parsed_version
116+
117+
118+
def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) -> bool:
    """Tells whether the `model_id` / `version` pair designates a JumpStart model.

    The two arguments must be supplied together: either both are given
    (JumpStart input) or both are None (not JumpStart input).

    Args:
        model_id (str): Optional. Model id of JumpStart model.
        version (str): Optional. Version for JumpStart model.

    Returns:
        bool: True if both arguments are not None, False if both are None.

    Raises:
        ValueError: If exactly one of the two arguments is None.
    """
    absent = (model_id is None, version is None)

    if all(absent):
        return False

    if any(absent):
        raise ValueError(
            "Must specify `model_id` and `model_version` when getting specs for "
            "JumpStart models."
        )

    return True

0 commit comments

Comments
 (0)