20
20
from tests import integ
21
21
from tests .integ import DATA_DIR , TRAINING_DEFAULT_TIMEOUT_MINUTES
22
22
from tests .integ .timeout import timeout
23
-
24
-
25
- @pytest .fixture (scope = "module" )
26
- def gpu_instance_type (request ):
27
- return "ml.p3.2xlarge"
23
+ from tests .integ .utils import gpu_list , retry_with_instance_list
28
24
29
25
30
26
@pytest .mark .release
31
27
@pytest .mark .skipif (
32
28
integ .test_region () not in integ .TRAINING_COMPILER_SUPPORTED_REGIONS ,
33
29
reason = "SageMaker Training Compiler is not supported in this region" ,
34
30
)
31
+ @pytest .mark .skipif (
32
+ integ .test_region () in integ .TRAINING_NO_P2_REGIONS
33
+ and integ .test_region () in integ .TRAINING_NO_P3_REGIONS ,
34
+ reason = "no ml.p2 or ml.p3 instances in this region" ,
35
+ )
36
+ @retry_with_instance_list (gpu_list (integ .test_region ()))
35
37
def test_huggingface_pytorch (
36
38
sagemaker_session ,
37
- gpu_instance_type ,
38
39
huggingface_training_compiler_latest_version ,
39
40
huggingface_training_compiler_pytorch_latest_version ,
41
+ ** kwargs ,
40
42
):
41
43
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
42
44
data_path = os .path .join (DATA_DIR , "huggingface" )
@@ -48,7 +50,7 @@ def test_huggingface_pytorch(
48
50
transformers_version = huggingface_training_compiler_latest_version ,
49
51
pytorch_version = huggingface_training_compiler_pytorch_latest_version ,
50
52
instance_count = 1 ,
51
- instance_type = gpu_instance_type ,
53
+ instance_type = kwargs [ "instance_type" ] ,
52
54
hyperparameters = {
53
55
"model_name_or_path" : "distilbert-base-cased" ,
54
56
"task_name" : "wnli" ,
@@ -78,11 +80,17 @@ def test_huggingface_pytorch(
78
80
integ .test_region () not in integ .TRAINING_COMPILER_SUPPORTED_REGIONS ,
79
81
reason = "SageMaker Training Compiler is not supported in this region" ,
80
82
)
83
+ @pytest .mark .skipif (
84
+ integ .test_region () in integ .TRAINING_NO_P2_REGIONS
85
+ and integ .test_region () in integ .TRAINING_NO_P3_REGIONS ,
86
+ reason = "no ml.p2 or ml.p3 instances in this region" ,
87
+ )
88
+ @retry_with_instance_list (gpu_list (integ .test_region ()))
81
89
def test_huggingface_tensorflow (
82
90
sagemaker_session ,
83
- gpu_instance_type ,
84
91
huggingface_training_compiler_latest_version ,
85
92
huggingface_training_compiler_tensorflow_latest_version ,
93
+ ** kwargs ,
86
94
):
87
95
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
88
96
data_path = os .path .join (DATA_DIR , "huggingface" )
@@ -94,7 +102,7 @@ def test_huggingface_tensorflow(
94
102
transformers_version = huggingface_training_compiler_latest_version ,
95
103
tensorflow_version = huggingface_training_compiler_tensorflow_latest_version ,
96
104
instance_count = 1 ,
97
- instance_type = gpu_instance_type ,
105
+ instance_type = kwargs [ "instance_type" ] ,
98
106
hyperparameters = {
99
107
"model_name_or_path" : "distilbert-base-cased" ,
100
108
"per_device_train_batch_size" : 128 ,
0 commit comments