Skip to content

Commit 2815038

Browse files
fix: fix neo inferentia as compilation target not using framework version
1 parent d573e67 commit 2815038

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/sagemaker/model.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -728,13 +728,22 @@ def _compilation_job_config(
728728
"Framework": framework.upper(),
729729
}
730730

731-
multiple_version_supported_framework_list = ["pytorch", "tensorflow"]
732-
if (
733-
framework.lower() in multiple_version_supported_framework_list
734-
and target_instance_type is not None
735-
and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None
736-
and framework_version is not None
731+
def multi_version_compilation_supported(
732+
target_instance_type: str, framework: str, framework_version: str
737733
):
734+
if target_instance_type and framework and framework_version:
735+
framework = framework.lower()
736+
multi_version_frameworks_support_mapping = {
737+
"ml_inf": ["pytorch", "tensorflow", "mxnet"],
738+
"ml_ioc": ["pytorch", "tensorflow"],
739+
}
740+
if re.match("(?=^ml_)", target_instance_type):
741+
return framework in multi_version_frameworks_support_mapping["ml_ioc"]
742+
if re.match("(?=^ml_inf)", target_instance_type):
743+
return framework in multi_version_frameworks_support_mapping["ml_inf"]
744+
return False
745+
746+
if multi_version_compilation_supported(target_instance_type, framework, framework_version):
738747
input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)
739748

740749
role = self.sagemaker_session.expand_role(role)

0 commit comments

Comments
 (0)