Skip to content

fix: fix: neo inferentia as compilation target not using framework ver #3183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import json
import logging
import os
import re
import copy
from typing import List, Dict

Expand Down Expand Up @@ -50,6 +49,8 @@
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
)

NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]


class ModelBase(abc.ABC):
"""An object that encapsulates a trained model.
Expand Down Expand Up @@ -763,13 +764,22 @@ def _compilation_job_config(
"Framework": framework.upper(),
}

multiple_version_supported_framework_list = ["pytorch", "tensorflow"]
if (
framework.lower() in multiple_version_supported_framework_list
and target_instance_type is not None
and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None
and framework_version is not None
def multi_version_compilation_supported(
target_instance_type: str, framework: str, framework_version: str
):
if target_instance_type and framework and framework_version:
framework = framework.lower()
multi_version_frameworks_support_mapping = {
"inferentia": ["pytorch", "tensorflow", "mxnet"],
"neo_ioc_targets": ["pytorch", "tensorflow"],
}
if target_instance_type in NEO_IOC_TARGET_DEVICES:
return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"]
if target_instance_type == "ml_inf":
return framework in multi_version_frameworks_support_mapping["inferentia"]
return False

if multi_version_compilation_supported(target_instance_type, framework, framework_version):
input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)

role = self.sagemaker_session.expand_role(role)
Expand Down