diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index 2ba00fae83..b3ea66095f 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -312,7 +312,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na framework_version = None else: framework, pt_or_tf = framework.split("-") - tag_pattern = re.compile("^(.*)-transformers(.*)-(cpu|gpu)-(py2|py3[67]?)$") + tag_pattern = re.compile(r"^(.*)-transformers(.*)-(cpu|gpu)-(py2|py3\d*)$") tag_match = tag_pattern.match(tag) pt_or_tf_version = tag_match.group(1) framework_version = tag_match.group(2) diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index d4f366681c..749afcc776 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -253,12 +253,16 @@ def test_huggingface( def test_attach( - sagemaker_session, huggingface_training_version, huggingface_pytorch_training_version + sagemaker_session, + huggingface_training_version, + huggingface_pytorch_training_version, + huggingface_pytorch_training_py_version, ): training_image = ( f"1.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:" f"{huggingface_pytorch_training_version}-" - f"transformers{huggingface_training_version}-gpu-py36-cu110-ubuntu18.04" + f"transformers{huggingface_training_version}-gpu-" + f"{huggingface_pytorch_training_py_version}-cu110-ubuntu20.04" ) returned_job_description = { "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, @@ -290,7 +294,7 @@ def test_attach( estimator = HuggingFace.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert estimator.latest_training_job.job_name == "neo" - assert estimator.py_version == "py36" + assert estimator.py_version == huggingface_pytorch_training_py_version assert estimator.framework_version == huggingface_training_version assert estimator.pytorch_version == huggingface_pytorch_training_version assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"