Skip to content

Commit f212971

Browse files
Lokiiiiiishreyapandit
authored andcommitted
fix: HF estimator attach modified to work with py38 (aws#2697)
Co-authored-by: Shreya Pandit <[email protected]>
1 parent 1382f15 commit f212971

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/sagemaker/huggingface/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
312312
framework_version = None
313313
else:
314314
framework, pt_or_tf = framework.split("-")
315-
tag_pattern = re.compile("^(.*)-transformers(.*)-(cpu|gpu)-(py2|py3[67]?)$")
315+
tag_pattern = re.compile(r"^(.*)-transformers(.*)-(cpu|gpu)-(py2|py3\d*)$")
316316
tag_match = tag_pattern.match(tag)
317317
pt_or_tf_version = tag_match.group(1)
318318
framework_version = tag_match.group(2)

tests/unit/sagemaker/huggingface/test_estimator.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,16 @@ def test_huggingface(
253253

254254

255255
def test_attach(
256-
sagemaker_session, huggingface_training_version, huggingface_pytorch_training_version
256+
sagemaker_session,
257+
huggingface_training_version,
258+
huggingface_pytorch_training_version,
259+
huggingface_pytorch_training_py_version,
257260
):
258261
training_image = (
259262
f"1.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
260263
f"{huggingface_pytorch_training_version}-"
261-
f"transformers{huggingface_training_version}-gpu-py36-cu110-ubuntu18.04"
264+
f"transformers{huggingface_training_version}-gpu-"
265+
f"{huggingface_pytorch_training_py_version}-cu110-ubuntu20.04"
262266
)
263267
returned_job_description = {
264268
"AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
@@ -290,7 +294,7 @@ def test_attach(
290294

291295
estimator = HuggingFace.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
292296
assert estimator.latest_training_job.job_name == "neo"
293-
assert estimator.py_version == "py36"
297+
assert estimator.py_version == huggingface_pytorch_training_py_version
294298
assert estimator.framework_version == huggingface_training_version
295299
assert estimator.pytorch_version == huggingface_pytorch_training_version
296300
assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"

0 commit comments

Comments
 (0)