@@ -728,13 +728,22 @@ def _compilation_job_config(
728
728
"Framework" : framework .upper (),
729
729
}
730
730
731
- multiple_version_supported_framework_list = ["pytorch" , "tensorflow" ]
732
- if (
733
- framework .lower () in multiple_version_supported_framework_list
734
- and target_instance_type is not None
735
- and re .match ("(?=^ml_)(?!ml_inf)" , target_instance_type ) is not None
736
- and framework_version is not None
731
+ def multi_version_compilation_supported (
732
+ target_instance_type : str , framework : str , framework_version : str
737
733
):
734
+ if target_instance_type and framework and framework_version :
735
+ framework = framework .lower ()
736
+ multi_version_frameworks_support_mapping = {
737
+ "ml_inf" : ["pytorch" , "tensorflow" , "mxnet" ],
738
+ "ml_ioc" : ["pytorch" , "tensorflow" ],
739
+ }
740
+ if re .match ("(?=^ml_)" , target_instance_type ):
741
+ return framework in multi_version_frameworks_support_mapping ["ml_ioc" ]
742
+ if re .match ("(?=^ml_inf)" , target_instance_type ):
743
+ return framework in multi_version_frameworks_support_mapping ["ml_inf" ]
744
+ return False
745
+
746
+ if multi_version_compilation_supported (target_instance_type , framework , framework_version ):
738
747
input_model_config ["FrameworkVersion" ] = utils .get_short_version (framework_version )
739
748
740
749
role = self .sagemaker_session .expand_role (role )
0 commit comments