24
24
from sagemaker .spark import defaults
25
25
from sagemaker .jumpstart import artifacts
26
26
27
-
28
27
logger = logging .getLogger (__name__ )
29
28
30
29
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
@@ -100,7 +99,6 @@ def retrieve(
100
99
DeprecatedJumpStartModelError: If the version of the model is deprecated.
101
100
"""
102
101
if is_jumpstart_model_input (model_id , model_version ):
103
-
104
102
return artifacts ._retrieve_image_uri (
105
103
model_id ,
106
104
model_version ,
@@ -118,17 +116,22 @@ def retrieve(
118
116
tolerate_vulnerable_model ,
119
117
tolerate_deprecated_model ,
120
118
)
121
-
122
119
if training_compiler_config is None :
123
- config = _config_for_framework_and_scope (framework , image_scope , accelerator_type )
120
+ if framework == HUGGING_FACE_FRAMEWORK and instance_type == "neuron" :
121
+ config = _config_for_framework_and_scope (
122
+ framework + "-neuron" , image_scope , accelerator_type
123
+ )
124
+ else :
125
+ config = _config_for_framework_and_scope (framework , image_scope , accelerator_type )
124
126
elif framework == HUGGING_FACE_FRAMEWORK :
125
127
config = _config_for_framework_and_scope (
126
- framework + "-training-compiler" , image_scope , accelerator_type
128
+ framework + "-training-compiler" , image_scope , accelerator_type , instance_type
127
129
)
128
130
else :
129
131
raise ValueError (
130
132
"Unsupported Configuration: Training Compiler is only supported with HuggingFace"
131
133
)
134
+
132
135
original_version = version
133
136
version = _validate_version_and_set_if_needed (version , config , framework )
134
137
version_config = config ["versions" ][_version_for_config (version , config )]
@@ -169,6 +172,9 @@ def retrieve(
169
172
]:
170
173
_version = version
171
174
if processor == "neuron" :
175
+ sdk_version = _get_latest_versions (version_config ["sdk_versions" ])
176
+ repo_versions = _get_latest_versions (version_config ["repo_versions" ])
177
+ container_version = sdk_version + "-" + container_version + "-" + repo_versions
172
178
repo += "-{0}" .format (processor )
173
179
174
180
tag_prefix = f"{ pt_or_tf_version } -transformers{ _version } "
@@ -208,8 +214,12 @@ def retrieve(
208
214
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
209
215
210
216
211
- def _config_for_framework_and_scope (framework , image_scope , accelerator_type = None ):
217
+ def _config_for_framework_and_scope (
218
+ framework , image_scope , accelerator_type = None , instance_type = None
219
+ ):
212
220
"""Loads the JSON config for the given framework and image scope."""
221
+ if framework == HUGGING_FACE_FRAMEWORK and instance_type == "neuron" :
222
+ framework = framework + "_" + instance_type
213
223
config = config_for_framework (framework )
214
224
215
225
if accelerator_type :
@@ -250,6 +260,11 @@ def config_for_framework(framework):
250
260
return json .load (f )
251
261
252
262
263
+ def _get_latest_versions (list_of_versions ):
264
+ """Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
265
+ return sorted (list_of_versions , reverse = True )[0 ]
266
+
267
+
253
268
def _validate_accelerator_type (accelerator_type ):
254
269
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
255
270
if not accelerator_type .startswith ("ml.eia" ) and accelerator_type != "local_sagemaker_notebook" :
0 commit comments