|
17 | 17 | import json
|
18 | 18 | import logging
|
19 | 19 | import os
|
20 |
| -import re |
21 | 20 | import copy
|
22 | 21 | from typing import List, Dict
|
23 | 22 |
|
|
47 | 46 | ["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
|
48 | 47 | )
|
49 | 48 |
|
| 49 | +NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"] |
| 50 | + |
50 | 51 |
|
51 | 52 | class ModelBase(abc.ABC):
|
52 | 53 | """An object that encapsulates a trained model.
|
@@ -728,13 +729,22 @@ def _compilation_job_config(
|
728 | 729 | "Framework": framework.upper(),
|
729 | 730 | }
|
730 | 731 |
|
731 |
| - multiple_version_supported_framework_list = ["pytorch", "tensorflow"] |
732 |
| - if ( |
733 |
| - framework.lower() in multiple_version_supported_framework_list |
734 |
| - and target_instance_type is not None |
735 |
| - and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None |
736 |
| - and framework_version is not None |
| 732 | + def multi_version_compilation_supported( |
| 733 | + target_instance_type: str, framework: str, framework_version: str |
737 | 734 | ):
|
| 735 | + if target_instance_type and framework and framework_version: |
| 736 | + framework = framework.lower() |
| 737 | + multi_version_frameworks_support_mapping = { |
| 738 | + "inferentia": ["pytorch", "tensorflow", "mxnet"], |
| 739 | + "neo_ioc_targets": ["pytorch", "tensorflow"], |
| 740 | + } |
| 741 | + if target_instance_type in NEO_IOC_TARGET_DEVICES: |
| 742 | + return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"] |
| 743 | + if target_instance_type == "ml_inf": |
| 744 | + return framework in multi_version_frameworks_support_mapping["inferentia"] |
| 745 | + return False |
| 746 | + |
| 747 | + if multi_version_compilation_supported(target_instance_type, framework, framework_version): |
738 | 748 | input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)
|
739 | 749 |
|
740 | 750 | role = self.sagemaker_session.expand_role(role)
|
|
0 commit comments