Skip to content

Commit 24bbbdf

Browse files
HappyAmazonianJoseJuan98
authored andcommitted
fix: fix: neo inferentia as compilation target not using framework ver (aws#3183)
1 parent 1164c60 commit 24bbbdf

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

src/sagemaker/model.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import json
1818
import logging
1919
import os
20-
import re
2120
import copy
2221
from typing import List, Dict
2322

@@ -50,6 +49,8 @@
5049
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
5150
)
5251

52+
NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
53+
5354

5455
class ModelBase(abc.ABC):
5556
"""An object that encapsulates a trained model.
@@ -763,13 +764,22 @@ def _compilation_job_config(
763764
"Framework": framework.upper(),
764765
}
765766

766-
multiple_version_supported_framework_list = ["pytorch", "tensorflow"]
767-
if (
768-
framework.lower() in multiple_version_supported_framework_list
769-
and target_instance_type is not None
770-
and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None
771-
and framework_version is not None
767+
def multi_version_compilation_supported(
768+
target_instance_type: str, framework: str, framework_version: str
772769
):
770+
if target_instance_type and framework and framework_version:
771+
framework = framework.lower()
772+
multi_version_frameworks_support_mapping = {
773+
"inferentia": ["pytorch", "tensorflow", "mxnet"],
774+
"neo_ioc_targets": ["pytorch", "tensorflow"],
775+
}
776+
if target_instance_type in NEO_IOC_TARGET_DEVICES:
777+
return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"]
778+
if target_instance_type == "ml_inf":
779+
return framework in multi_version_frameworks_support_mapping["inferentia"]
780+
return False
781+
782+
if multi_version_compilation_supported(target_instance_type, framework, framework_version):
773783
input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)
774784

775785
role = self.sagemaker_session.expand_role(role)

0 commit comments

Comments
 (0)