Skip to content

Commit bbaeb73

Browse files
mufaddal-rohawalaknikure
authored andcommitted
fix: huggingface release test (aws#3378)
1 parent bbc62f1 commit bbaeb73

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

tests/integ/test_training_compiler.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def skip_if_incompatible(gpu_instance_type, request):
8181
pytest.skip("no ml.p3 instances in this region")
8282

8383

84-
@pytest.mark.release
8584
@pytest.mark.parametrize(
8685
"gpu_instance_type,instance_count",
8786
[
@@ -130,6 +129,46 @@ def test_huggingface_pytorch(
130129
hf.fit(huggingface_dummy_dataset)
131130

132131

132+
@pytest.mark.release
133+
def test_huggingface_pytorch_release(
134+
sagemaker_session,
135+
gpu_instance_type,
136+
huggingface_training_compiler_latest_version,
137+
huggingface_training_compiler_pytorch_latest_version,
138+
huggingface_dummy_dataset,
139+
):
140+
"""
141+
Test the HuggingFace estimator with PyTorch
142+
"""
143+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
144+
data_path = os.path.join(DATA_DIR, "huggingface")
145+
146+
hf = HuggingFace(
147+
py_version="py38",
148+
entry_point=os.path.join(data_path, "run_glue.py"),
149+
role="SageMakerRole",
150+
transformers_version=huggingface_training_compiler_latest_version,
151+
pytorch_version=huggingface_training_compiler_pytorch_latest_version,
152+
instance_count=1,
153+
instance_type=gpu_instance_type,
154+
hyperparameters={
155+
"model_name_or_path": "distilbert-base-cased",
156+
"task_name": "wnli",
157+
"do_train": True,
158+
"do_eval": True,
159+
"max_seq_length": 128,
160+
"fp16": True,
161+
"per_device_train_batch_size": 128,
162+
"output_dir": "/opt/ml/model",
163+
},
164+
sagemaker_session=sagemaker_session,
165+
disable_profiler=True,
166+
compiler_config=HFTrainingCompilerConfig(),
167+
)
168+
169+
hf.fit(huggingface_dummy_dataset)
170+
171+
133172
@pytest.mark.release
134173
def test_huggingface_tensorflow(
135174
sagemaker_session,

0 commit comments

Comments
 (0)