|
17 | 17 | import json
|
18 | 18 | import logging
|
19 | 19 | import os
|
20 |
| -import re |
21 | 20 | import copy
|
22 | 21 | from typing import List, Dict
|
23 | 22 |
|
|
50 | 49 | ["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
|
51 | 50 | )
|
52 | 51 |
|
| 52 | +NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"] |
| 53 | + |
53 | 54 |
|
54 | 55 | class ModelBase(abc.ABC):
|
55 | 56 | """An object that encapsulates a trained model.
|
@@ -763,13 +764,22 @@ def _compilation_job_config(
|
763 | 764 | "Framework": framework.upper(),
|
764 | 765 | }
|
765 | 766 |
|
766 |
| - multiple_version_supported_framework_list = ["pytorch", "tensorflow"] |
767 |
| - if ( |
768 |
| - framework.lower() in multiple_version_supported_framework_list |
769 |
| - and target_instance_type is not None |
770 |
| - and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None |
771 |
| - and framework_version is not None |
| 767 | + def multi_version_compilation_supported( |
| 768 | + target_instance_type: str, framework: str, framework_version: str |
772 | 769 | ):
|
| 770 | + if target_instance_type and framework and framework_version: |
| 771 | + framework = framework.lower() |
| 772 | + multi_version_frameworks_support_mapping = { |
| 773 | + "inferentia": ["pytorch", "tensorflow", "mxnet"], |
| 774 | + "neo_ioc_targets": ["pytorch", "tensorflow"], |
| 775 | + } |
| 776 | + if target_instance_type in NEO_IOC_TARGET_DEVICES: |
| 777 | + return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"] |
| 778 | + if target_instance_type == "ml_inf": |
| 779 | + return framework in multi_version_frameworks_support_mapping["inferentia"] |
| 780 | + return False |
| 781 | + |
| 782 | + if multi_version_compilation_supported(target_instance_type, framework, framework_version): |
773 | 783 | input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)
|
774 | 784 |
|
775 | 785 | role = self.sagemaker_session.expand_role(role)
|
|
0 commit comments