10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- """This module contains functions for obtainining JumpStart artifacts ."""
13
+ """This module contains functions for obtaining JumpStart ECR and S3 URIs ."""
14
14
from __future__ import absolute_import
15
15
from typing import Optional
16
16
from sagemaker import image_uris
@@ -42,13 +42,14 @@ def _retrieve_image_uri(
42
42
):
43
43
"""Retrieves the container image URI for JumpStart models.
44
44
45
- Only `model_id` and `model_version ` are required to be non-None ;
45
+ Only `model_id`, `model_version`, and `image_scope ` are required;
46
46
the rest of the fields are auto-populated.
47
47
48
48
49
49
Args:
50
- model_id (str): JumpStart model id for which to retrieve image URI.
51
- model_version (str): JumpStart model version for which to retrieve image URI.
50
+ model_id (str): JumpStart model ID for which to retrieve image URI.
51
+ model_version (str): Version of the JumpStart model for which to retrieve
52
+ the image URI (default: None).
52
53
framework (str): The name of the framework or algorithm.
53
54
region (str): The AWS region.
54
55
version (str): The framework or algorithm version. This is required if there is
@@ -89,7 +90,9 @@ def _retrieve_image_uri(
89
90
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
90
91
)
91
92
if image_scope not in SUPPORTED_JUMPSTART_SCOPES :
92
- raise ValueError ("JumpStart models only support inference and training." )
93
+ raise ValueError (
94
+ f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
95
+ )
93
96
94
97
model_specs = jumpstart_accessors .JumpStartModelsCache .get_model_specs (
95
98
region , model_id , model_version
@@ -99,25 +102,33 @@ def _retrieve_image_uri(
99
102
ecr_specs = model_specs .hosting_ecr_specs
100
103
elif image_scope == TRAINING :
101
104
if not model_specs .training_supported :
102
- raise ValueError (f"JumpStart model id '{ model_id } ' does not support training." )
105
+ raise ValueError (
106
+ f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
107
+ "does not support training."
108
+ )
103
109
assert model_specs .training_ecr_specs is not None
104
110
ecr_specs = model_specs .training_ecr_specs
105
111
106
112
if framework is not None and framework != ecr_specs .framework :
107
- raise ValueError (f"Bad value for container framework for JumpStart model: '{ framework } '." )
113
+ raise ValueError (
114
+ f"Incorrect container framework '{ framework } ' for JumpStart model ID '{ model_id } ' "
115
+ "and version {model_version}'."
116
+ )
108
117
109
118
if version is not None and version != ecr_specs .framework_version :
110
119
raise ValueError (
111
- f"Bad value for container framework version for JumpStart model: '{ version } '."
120
+ f"Incorrect container framework version '{ version } ' for JumpStart model ID "
121
+ f"'{ model_id } ' and version { model_version } '."
112
122
)
113
123
114
124
if py_version is not None and py_version != ecr_specs .py_version :
115
125
raise ValueError (
116
- f"Bad value for container python version for JumpStart model: '{ py_version } '."
126
+ f"Incorrect python version '{ py_version } ' for JumpStart model ID '{ model_id } ' "
127
+ "and version {model_version}'."
117
128
)
118
129
119
- base_framework_version_override = None
120
- version_override = None
130
+ base_framework_version_override : Optional [ str ] = None
131
+ version_override : Optional [ str ] = None
121
132
if ecr_specs .framework == ModelFramework .HUGGINGFACE .value :
122
133
base_framework_version_override = ecr_specs .framework_version
123
134
version_override = ecr_specs .huggingface_transformers_version
@@ -162,8 +173,10 @@ def _retrieve_model_uri(
162
173
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
163
174
164
175
Args:
165
- model_id (str): JumpStart model id for which to retrieve model S3 URI.
166
- model_version (str): JumpStart model version for which to retrieve model S3 URI.
176
+ model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
177
+ the model artifact S3 URI.
178
+ model_version (str): Version of the JumpStart model for which to retrieve the model
179
+ artifact S3 URI.
167
180
model_scope (str): The model type, i.e. what it is used for.
168
181
Valid values: "training" and "inference".
169
182
region (str): Region for which to retrieve model S3 URI.
@@ -185,7 +198,9 @@ def _retrieve_model_uri(
185
198
)
186
199
187
200
if model_scope not in SUPPORTED_JUMPSTART_SCOPES :
188
- raise ValueError ("JumpStart models only support inference and training." )
201
+ raise ValueError (
202
+ f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
203
+ )
189
204
190
205
model_specs = jumpstart_accessors .JumpStartModelsCache .get_model_specs (
191
206
region , model_id , model_version
@@ -194,7 +209,10 @@ def _retrieve_model_uri(
194
209
model_artifact_key = model_specs .hosting_artifact_key
195
210
elif model_scope == TRAINING :
196
211
if not model_specs .training_supported :
197
- raise ValueError (f"JumpStart model id '{ model_id } ' does not support training." )
212
+ raise ValueError (
213
+ f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
214
+ "does not support training."
215
+ )
198
216
assert model_specs .training_artifact_key is not None
199
217
model_artifact_key = model_specs .training_artifact_key
200
218
@@ -211,11 +229,13 @@ def _retrieve_script_uri(
211
229
script_scope : Optional [str ],
212
230
region : Optional [str ],
213
231
):
214
- """Retrieves the model script s3 URI for the model matching the given arguments.
232
+ """Retrieves the script S3 URI associated with the model matching the given arguments.
215
233
216
234
Args:
217
- model_id (str): JumpStart model id for which to retrieve model script S3 URI.
218
- model_version (str): JumpStart model version for which to retrieve model script S3 URI.
235
+ model_id (str): JumpStart model ID of the JumpStart model for which to
236
+ retrieve the script S3 URI.
237
+ model_version (str): Version of the JumpStart model for which to
238
+ retrieve the model script S3 URI.
219
239
script_scope (str): The script type, i.e. what it is used for.
220
240
Valid values: "training" and "inference".
221
241
region (str): Region for which to retrieve model script S3 URI.
@@ -237,7 +257,9 @@ def _retrieve_script_uri(
237
257
)
238
258
239
259
if script_scope not in SUPPORTED_JUMPSTART_SCOPES :
240
- raise ValueError ("JumpStart models only support inference and training." )
260
+ raise ValueError (
261
+ f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
262
+ )
241
263
242
264
model_specs = jumpstart_accessors .JumpStartModelsCache .get_model_specs (
243
265
region , model_id , model_version
@@ -246,7 +268,10 @@ def _retrieve_script_uri(
246
268
model_script_key = model_specs .hosting_script_key
247
269
elif script_scope == TRAINING :
248
270
if not model_specs .training_supported :
249
- raise ValueError (f"JumpStart model id '{ model_id } ' does not support training." )
271
+ raise ValueError (
272
+ f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
273
+ "does not support training."
274
+ )
250
275
assert model_specs .training_script_key is not None
251
276
model_script_key = model_specs .training_script_key
252
277
0 commit comments