# Patch for src/sagemaker/model.py.
#
# Summary of the three hunks below:
#   * removes the `import re` line;
#   * adds the module-level list NEO_IOC_TARGET_DEVICES;
#   * in _compilation_job_config, replaces the inline condition (framework is
#     pytorch/tensorflow, target starts with "ml_" but not "ml_inf", framework
#     version is not None) with a nested helper,
#     multi_version_compilation_supported, whose result decides whether
#     "FrameworkVersion" is written into input_model_config.
#
# NOTE(review): the removed regex accepted any target beginning with "ml_"
#   other than "ml_inf..."; the helper accepts only members of
#   NEO_IOC_TARGET_DEVICES or the exact string "ml_inf". Other "ml_*" targets
#   therefore no longer get "FrameworkVersion" - confirm this is intended.
# NOTE(review): the Inferentia branch compares against the literal "ml_inf";
#   confirm callers pass exactly that value (and not e.g. "ml_inf1").
# NOTE(review): the removed code tested `is not None`; the helper tests
#   truthiness, so empty strings are now rejected as well.
# NOTE(review): indentation inside the hunks follows Python syntax; the column
#   of `return False` (helper body level) should be checked against the
#   repository before applying.
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index 60c766379b..fa30e4a27c 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -17,7 +17,6 @@
 import json
 import logging
 import os
-import re
 import copy
 from typing import List, Dict
 
@@ -50,6 +49,8 @@
     ["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
 )
 
+NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
+
 
 class ModelBase(abc.ABC):
     """An object that encapsulates a trained model.
@@ -763,13 +764,22 @@ def _compilation_job_config(
             "Framework": framework.upper(),
         }
 
-        multiple_version_supported_framework_list = ["pytorch", "tensorflow"]
-        if (
-            framework.lower() in multiple_version_supported_framework_list
-            and target_instance_type is not None
-            and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None
-            and framework_version is not None
+        def multi_version_compilation_supported(
+            target_instance_type: str, framework: str, framework_version: str
         ):
+            if target_instance_type and framework and framework_version:
+                framework = framework.lower()
+                multi_version_frameworks_support_mapping = {
+                    "inferentia": ["pytorch", "tensorflow", "mxnet"],
+                    "neo_ioc_targets": ["pytorch", "tensorflow"],
+                }
+                if target_instance_type in NEO_IOC_TARGET_DEVICES:
+                    return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"]
+                if target_instance_type == "ml_inf":
+                    return framework in multi_version_frameworks_support_mapping["inferentia"]
+            return False
+
+        if multi_version_compilation_supported(target_instance_type, framework, framework_version):
             input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)
 
         role = self.sagemaker_session.expand_role(role)