@@ -79,11 +79,7 @@ def fixture_sagemaker_session():
79
79
80
80
81
81
def _get_full_gpu_image_uri (
82
- version ,
83
- base_framework_version ,
84
- instance_type ,
85
- training_compiler_config ,
86
- py_version
82
+ version , base_framework_version , instance_type , training_compiler_config , py_version
87
83
):
88
84
return image_uris .retrieve (
89
85
"huggingface" ,
@@ -98,7 +94,9 @@ def _get_full_gpu_image_uri(
98
94
)
99
95
100
96
101
- def _create_train_job (version , base_framework_version , instance_type , training_compiler_config , py_version ):
97
+ def _create_train_job (
98
+ version , base_framework_version , instance_type , training_compiler_config , py_version
99
+ ):
102
100
return {
103
101
"image_uri" : _get_full_gpu_image_uri (
104
102
version , base_framework_version , instance_type , training_compiler_config , py_version
@@ -159,7 +157,7 @@ def _create_train_job(version, base_framework_version, instance_type, training_c
159
157
def test_unsupported_BYOC (
160
158
huggingface_training_compiler_version ,
161
159
huggingface_training_compiler_tensorflow_version ,
162
- huggingface_training_compiler_py_version
160
+ huggingface_training_compiler_py_version ,
163
161
):
164
162
byoc = (
165
163
f"1.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-trcomp-training:"
@@ -186,7 +184,7 @@ def test_unsupported_cpu_instance(
186
184
cpu_instance_type ,
187
185
huggingface_training_compiler_version ,
188
186
huggingface_training_compiler_tensorflow_version ,
189
- huggingface_training_compiler_py_version
187
+ huggingface_training_compiler_py_version ,
190
188
):
191
189
with pytest .raises (ValueError ):
192
190
HuggingFace (
@@ -207,7 +205,7 @@ def test_unsupported_gpu_instance(
207
205
unsupported_gpu_instance_class ,
208
206
huggingface_training_compiler_version ,
209
207
huggingface_training_compiler_tensorflow_version ,
210
- huggingface_training_compiler_py_version
208
+ huggingface_training_compiler_py_version ,
211
209
):
212
210
with pytest .raises (ValueError ):
213
211
HuggingFace (
@@ -225,7 +223,7 @@ def test_unsupported_gpu_instance(
225
223
226
224
def test_unsupported_framework_version (
227
225
huggingface_training_compiler_version ,
228
- huggingface_training_compiler_py_version
226
+ huggingface_training_compiler_py_version ,
229
227
):
230
228
with pytest .raises (ValueError ):
231
229
HuggingFace (
@@ -245,7 +243,7 @@ def test_unsupported_framework_version(
245
243
246
244
def test_unsupported_framework_mxnet (
247
245
huggingface_training_compiler_version ,
248
- huggingface_training_compiler_py_version
246
+ huggingface_training_compiler_py_version ,
249
247
):
250
248
with pytest .raises (ValueError ):
251
249
HuggingFace (
@@ -263,7 +261,7 @@ def test_unsupported_framework_mxnet(
263
261
264
262
def test_unsupported_python_2 (
265
263
huggingface_training_compiler_version ,
266
- huggingface_training_compiler_tensorflow_version
264
+ huggingface_training_compiler_tensorflow_version ,
267
265
):
268
266
with pytest .raises (ValueError ):
269
267
HuggingFace (
@@ -291,7 +289,7 @@ def test_default_compiler_config(
291
289
huggingface_training_compiler_version ,
292
290
huggingface_training_compiler_tensorflow_version ,
293
291
instance_class ,
294
- huggingface_training_compiler_py_version
292
+ huggingface_training_compiler_py_version ,
295
293
):
296
294
compiler_config = TrainingCompilerConfig ()
297
295
instance_type = f"ml.{ instance_class } .xlarge"
@@ -323,7 +321,7 @@ def test_default_compiler_config(
323
321
f"tensorflow{ huggingface_training_compiler_tensorflow_version } " ,
324
322
instance_type ,
325
323
compiler_config ,
326
- huggingface_training_compiler_py_version
324
+ huggingface_training_compiler_py_version ,
327
325
)
328
326
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
329
327
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -350,7 +348,7 @@ def test_debug_compiler_config(
350
348
sagemaker_session ,
351
349
huggingface_training_compiler_version ,
352
350
huggingface_training_compiler_tensorflow_version ,
353
- huggingface_training_compiler_py_version
351
+ huggingface_training_compiler_py_version ,
354
352
):
355
353
compiler_config = TrainingCompilerConfig (debug = True )
356
354
@@ -381,7 +379,7 @@ def test_debug_compiler_config(
381
379
f"tensorflow{ huggingface_training_compiler_tensorflow_version } " ,
382
380
INSTANCE_TYPE ,
383
381
compiler_config ,
384
- huggingface_training_compiler_py_version
382
+ huggingface_training_compiler_py_version ,
385
383
)
386
384
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
387
385
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -408,7 +406,7 @@ def test_disable_compiler_config(
408
406
sagemaker_session ,
409
407
huggingface_training_compiler_version ,
410
408
huggingface_training_compiler_tensorflow_version ,
411
- huggingface_training_compiler_py_version
409
+ huggingface_training_compiler_py_version ,
412
410
):
413
411
compiler_config = TrainingCompilerConfig (enabled = False )
414
412
@@ -439,7 +437,7 @@ def test_disable_compiler_config(
439
437
f"tensorflow{ huggingface_training_compiler_tensorflow_version } " ,
440
438
INSTANCE_TYPE ,
441
439
compiler_config ,
442
- huggingface_training_compiler_py_version
440
+ huggingface_training_compiler_py_version ,
443
441
)
444
442
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
445
443
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -460,10 +458,7 @@ def test_disable_compiler_config(
460
458
["compiler_enabled" , "debug_enabled" ], [(True , False ), (True , True ), (False , False )]
461
459
)
462
460
def test_attach (
463
- sagemaker_session ,
464
- compiler_enabled ,
465
- debug_enabled ,
466
- huggingface_training_compiler_py_version
461
+ sagemaker_session , compiler_enabled , debug_enabled , huggingface_training_compiler_py_version
467
462
):
468
463
training_image = (
469
464
f"1.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-trcomp-training:"
@@ -503,7 +498,7 @@ def test_attach(
503
498
504
499
estimator = HuggingFace .attach (training_job_name = "hopper" , sagemaker_session = sagemaker_session )
505
500
assert estimator .latest_training_job .job_name == "hopper"
506
- assert estimator .py_version == "py38"
501
+ assert estimator .py_version == huggingface_training_compiler_py_version
507
502
assert estimator .framework_version == "4.17.0"
508
503
assert estimator .tensorflow_version == "2.6.3"
509
504
assert estimator .role == "arn:aws:iam::366:role/SageMakerRole"
0 commit comments