@@ -152,6 +152,7 @@ def retrieve(
152
152
inference_tool = _get_inference_tool (inference_tool , instance_type )
153
153
if inference_tool == "neuron" :
154
154
_framework = f"{ framework } -{ inference_tool } "
155
+ _validate_for_suppported_frameworks_and_instance_type (framework , instance_type )
155
156
config = _config_for_framework_and_scope (_framework , image_scope , accelerator_type )
156
157
157
158
original_version = version
@@ -283,13 +284,20 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
283
284
)
284
285
image_scope = available_scopes [0 ]
285
286
286
- if image_scope is None and framework not in TRAINIUM_ALLOWED_FRAMEWORKS :
287
- _validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" )
288
-
289
287
_validate_arg (image_scope , available_scopes , "image scope" )
290
288
return config if "scope" in config else config [image_scope ]
291
289
292
290
291
+ def _validate_for_suppported_frameworks_and_instance_type (framework , instace_type ):
292
+ """Validate if framework is supported for the instance_type"""
293
+ if (
294
+ instace_type is not None
295
+ and "trn" in instace_type
296
+ and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
297
+ ):
298
+ _validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" )
299
+
300
+
293
301
def config_for_framework (framework ):
294
302
"""Loads the JSON config for the given framework."""
295
303
fname = os .path .join (os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework ))
0 commit comments