diff --git a/src/sagemaker/huggingface/training_compiler/config.py b/src/sagemaker/huggingface/training_compiler/config.py index 135ea6edc1..961084b313 100644 --- a/src/sagemaker/huggingface/training_compiler/config.py +++ b/src/sagemaker/huggingface/training_compiler/config.py @@ -26,7 +26,7 @@ class TrainingCompilerConfig(BaseConfig): """The SageMaker Training Compiler configuration class.""" - SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"] + SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"] SUPPORTED_INSTANCE_TYPES_WITH_EFA = [ "ml.g4dn.8xlarge", "ml.g4dn.12xlarge", @@ -87,10 +87,7 @@ def __init__( super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug) @classmethod - def validate( - cls, - estimator, - ): + def validate(cls, estimator): """Checks if SageMaker Training Compiler is configured correctly. Args: diff --git a/src/sagemaker/tensorflow/training_compiler/config.py b/src/sagemaker/tensorflow/training_compiler/config.py index d14cc3359b..16c4b1fe70 100644 --- a/src/sagemaker/tensorflow/training_compiler/config.py +++ b/src/sagemaker/tensorflow/training_compiler/config.py @@ -24,14 +24,10 @@ class TrainingCompilerConfig(BaseConfig): """The SageMaker Training Compiler configuration class.""" - SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4", "g5"] + SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"] MIN_SUPPORTED_VERSION = "2.9" - def __init__( - self, - enabled=True, - debug=False, - ): + def __init__(self, enabled=True, debug=False): """This class initializes a ``TrainingCompilerConfig`` instance. `Amazon SageMaker Training Compiler @@ -79,10 +75,7 @@ def __init__( super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug) @classmethod - def validate( - cls, - estimator, - ): + def validate(cls, estimator): """Checks if SageMaker Training Compiler is configured correctly. Args: diff --git a/src/sagemaker/training_compiler/config.py b/src/sagemaker/training_compiler/config.py index dcfd85471e..1067084441 100644 --- a/src/sagemaker/training_compiler/config.py +++ b/src/sagemaker/training_compiler/config.py @@ -23,16 +23,12 @@ class TrainingCompilerConfig(object): """The SageMaker Training Compiler configuration class.""" DEBUG_PATH = "/opt/ml/output/data/compiler/" - SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4d", "g5"] + SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"] HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled" HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode" - def __init__( - self, - enabled=True, - debug=False, - ): + def __init__(self, enabled=True, debug=False): """This class initializes a ``TrainingCompilerConfig`` instance. `Amazon SageMaker Training Compiler @@ -118,10 +114,7 @@ def _to_hyperparameter_dict(self): return compiler_config_hyperparameters @classmethod - def validate( - cls, - estimator, - ): + def validate(cls, estimator): """Checks if SageMaker Training Compiler is configured correctly. Args: @@ -138,19 +131,20 @@ def validate( warn_msg = ( "Estimator instance_type is a PipelineVariable (%s), " "which has to be interpreted as one of the " - "[p3, g4dn, p4d, g5] classes in execution time." + "%s classes in execution time." + ) + logger.warning( + warn_msg, + type(estimator.instance_type), + str(cls.SUPPORTED_INSTANCE_CLASS_PREFIXES).replace(",", ""), ) - logger.warning(warn_msg, type(estimator.instance_type)) elif estimator.instance_type: if "local" not in estimator.instance_type: requested_instance_class = estimator.instance_type.split(".")[ 1 ] # Expecting ml.class.size if not any( - [ - requested_instance_class.startswith(i) - for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES - ] + [requested_instance_class == i for i in cls.SUPPORTED_INSTANCE_CLASS_PREFIXES] ): error_helper_string = ( "Unsupported Instance class {}."