Skip to content

Commit 7baad15

Browse files
committed
fix unit test
1 parent c05b46b commit 7baad15

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

src/sagemaker/fw_utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -913,9 +913,10 @@ def _instance_type_supports_profiler(instance_type):
913913
Returns:
914914
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
915915
"""
916-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
917-
if match and match[1].startswith("trn"):
918-
return False
916+
if isinstance(instance_type, str):
917+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
918+
if match and match[1].startswith("trn"):
919+
return False
919920
return True
920921

921922

src/sagemaker/image_uris.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def retrieve(
152152
inference_tool = _get_inference_tool(inference_tool, instance_type)
153153
if inference_tool == "neuron":
154154
_framework = f"{framework}-{inference_tool}"
155+
_validate_for_suppported_frameworks_and_instance_type(framework, instance_type)
155156
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
156157

157158
original_version = version
@@ -283,13 +284,20 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
283284
)
284285
image_scope = available_scopes[0]
285286

286-
if image_scope is None and framework not in TRAINIUM_ALLOWED_FRAMEWORKS:
287-
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework")
288-
289287
_validate_arg(image_scope, available_scopes, "image scope")
290288
return config if "scope" in config else config[image_scope]
291289

292290

291+
def _validate_for_suppported_frameworks_and_instance_type(framework, instace_type):
292+
"""Validate if framework is supported for the instance_type"""
293+
if (
294+
instace_type is not None
295+
and "trn" in instace_type
296+
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
297+
):
298+
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework")
299+
300+
293301
def config_for_framework(framework):
294302
"""Loads the JSON config for the given framework."""
295303
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))

0 commit comments

Comments
 (0)