Skip to content

Commit 890984d

Browse files
fix: fix neo inferentia as compilation target not using framework version
1 parent 46c68a9 commit 890984d

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

src/sagemaker/model.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import json
1818
import logging
1919
import os
20-
import re
2120
import copy
2221
from typing import List, Dict
2322

@@ -47,6 +46,8 @@
4746
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
4847
)
4948

49+
NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
50+
5051

5152
class ModelBase(abc.ABC):
5253
"""An object that encapsulates a trained model.
@@ -728,13 +729,22 @@ def _compilation_job_config(
728729
"Framework": framework.upper(),
729730
}
730731

731-
multiple_version_supported_framework_list = ["pytorch", "tensorflow"]
732-
if (
733-
framework.lower() in multiple_version_supported_framework_list
734-
and target_instance_type is not None
735-
and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None
736-
and framework_version is not None
732+
def multi_version_compilation_supported(
733+
target_instance_type: str, framework: str, framework_version: str
737734
):
735+
if target_instance_type and framework and framework_version:
736+
framework = framework.lower()
737+
multi_version_frameworks_support_mapping = {
738+
"inferentia": ["pytorch", "tensorflow", "mxnet"],
739+
"neo_ioc_targets": ["pytorch", "tensorflow"],
740+
}
741+
if target_instance_type in NEO_IOC_TARGET_DEVICES:
742+
return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"]
743+
if target_instance_type == "ml_inf":
744+
return framework in multi_version_frameworks_support_mapping["inferentia"]
745+
return False
746+
747+
if multi_version_compilation_supported(target_instance_type, framework, framework_version):
738748
input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)
739749

740750
role = self.sagemaker_session.expand_role(role)

0 commit comments

Comments
 (0)