Skip to content

Commit b320ded

Browse files
fix: integs for training compiler in non-PDX regions
1 parent 412b633 commit b320ded

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

tests/integ/test_training_compiler.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,25 @@
2020
from tests import integ
2121
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2222
from tests.integ.timeout import timeout
23-
24-
25-
@pytest.fixture(scope="module")
26-
def gpu_instance_type(request):
27-
return "ml.p3.2xlarge"
23+
from tests.integ.utils import gpu_list, retry_with_instance_list
2824

2925

3026
@pytest.mark.release
3127
@pytest.mark.skipif(
3228
integ.test_region() not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS,
3329
reason="SageMaker Training Compiler is not supported in this region",
3430
)
31+
@pytest.mark.skipif(
32+
integ.test_region() in integ.TRAINING_NO_P2_REGIONS
33+
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
34+
reason="no ml.p2 or ml.p3 instances in this region",
35+
)
36+
@retry_with_instance_list(gpu_list(integ.test_region()))
3537
def test_huggingface_pytorch(
3638
sagemaker_session,
37-
gpu_instance_type,
3839
huggingface_training_compiler_latest_version,
3940
huggingface_training_compiler_pytorch_latest_version,
41+
**kwargs,
4042
):
4143
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
4244
data_path = os.path.join(DATA_DIR, "huggingface")
@@ -48,7 +50,7 @@ def test_huggingface_pytorch(
4850
transformers_version=huggingface_training_compiler_latest_version,
4951
pytorch_version=huggingface_training_compiler_pytorch_latest_version,
5052
instance_count=1,
51-
instance_type=gpu_instance_type,
53+
instance_type=kwargs["instance_type"],
5254
hyperparameters={
5355
"model_name_or_path": "distilbert-base-cased",
5456
"task_name": "wnli",
@@ -78,11 +80,17 @@ def test_huggingface_pytorch(
7880
integ.test_region() not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS,
7981
reason="SageMaker Training Compiler is not supported in this region",
8082
)
83+
@pytest.mark.skipif(
84+
integ.test_region() in integ.TRAINING_NO_P2_REGIONS
85+
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
86+
reason="no ml.p2 or ml.p3 instances in this region",
87+
)
88+
@retry_with_instance_list(gpu_list(integ.test_region()))
8189
def test_huggingface_tensorflow(
8290
sagemaker_session,
83-
gpu_instance_type,
8491
huggingface_training_compiler_latest_version,
8592
huggingface_training_compiler_tensorflow_latest_version,
93+
**kwargs,
8694
):
8795
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
8896
data_path = os.path.join(DATA_DIR, "huggingface")
@@ -94,7 +102,7 @@ def test_huggingface_tensorflow(
94102
transformers_version=huggingface_training_compiler_latest_version,
95103
tensorflow_version=huggingface_training_compiler_tensorflow_latest_version,
96104
instance_count=1,
97-
instance_type=gpu_instance_type,
105+
instance_type=kwargs["instance_type"],
98106
hyperparameters={
99107
"model_name_or_path": "distilbert-base-cased",
100108
"per_device_train_batch_size": 128,

0 commit comments

Comments
 (0)