@@ -79,13 +79,17 @@ def fixture_sagemaker_session():
79
79
80
80
81
81
def _get_full_gpu_image_uri (
82
- version , base_framework_version , instance_type , training_compiler_config
82
+ version ,
83
+ base_framework_version ,
84
+ instance_type ,
85
+ training_compiler_config ,
86
+ py_version
83
87
):
84
88
return image_uris .retrieve (
85
89
"huggingface" ,
86
90
REGION ,
87
91
version = version ,
88
- py_version = "py38" ,
92
+ py_version = py_version ,
89
93
instance_type = instance_type ,
90
94
image_scope = "training" ,
91
95
base_framework_version = base_framework_version ,
@@ -94,10 +98,10 @@ def _get_full_gpu_image_uri(
94
98
)
95
99
96
100
97
- def _create_train_job (version , base_framework_version , instance_type , training_compiler_config ):
101
+ def _create_train_job (version , base_framework_version , instance_type , training_compiler_config , py_version ):
98
102
return {
99
103
"image_uri" : _get_full_gpu_image_uri (
100
- version , base_framework_version , instance_type , training_compiler_config
104
+ version , base_framework_version , instance_type , training_compiler_config , py_version
101
105
),
102
106
"input_mode" : "File" ,
103
107
"input_config" : [
@@ -155,17 +159,18 @@ def _create_train_job(version, base_framework_version, instance_type, training_c
155
159
def test_unsupported_BYOC (
156
160
huggingface_training_compiler_version ,
157
161
huggingface_training_compiler_tensorflow_version ,
162
+ huggingface_training_compiler_py_version
158
163
):
159
164
byoc = (
160
- "1.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-trcomp-training:"
161
- "2.6.3-"
162
- "transformers4.17.0-gpu-"
163
- "py38 -cu112-ubuntu20.04"
165
+ f "1.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-trcomp-training:"
166
+ f "2.6.3-"
167
+ f "transformers4.17.0-gpu-"
168
+ f" { huggingface_training_compiler_py_version } -cu112-ubuntu20.04"
164
169
)
165
170
with pytest .raises (ValueError ):
166
171
HuggingFace (
167
172
image_uri = byoc ,
168
- py_version = "py38" ,
173
+ py_version = huggingface_training_compiler_py_version ,
169
174
entry_point = SCRIPT_PATH ,
170
175
role = ROLE ,
171
176
instance_count = INSTANCE_COUNT ,
@@ -181,10 +186,11 @@ def test_unsupported_cpu_instance(
181
186
cpu_instance_type ,
182
187
huggingface_training_compiler_version ,
183
188
huggingface_training_compiler_tensorflow_version ,
189
+ huggingface_training_compiler_py_version
184
190
):
185
191
with pytest .raises (ValueError ):
186
192
HuggingFace (
187
- py_version = "py38" ,
193
+ py_version = huggingface_training_compiler_py_version ,
188
194
entry_point = SCRIPT_PATH ,
189
195
role = ROLE ,
190
196
instance_count = INSTANCE_COUNT ,
@@ -201,10 +207,11 @@ def test_unsupported_gpu_instance(
201
207
unsupported_gpu_instance_class ,
202
208
huggingface_training_compiler_version ,
203
209
huggingface_training_compiler_tensorflow_version ,
210
+ huggingface_training_compiler_py_version
204
211
):
205
212
with pytest .raises (ValueError ):
206
213
HuggingFace (
207
- py_version = "py38" ,
214
+ py_version = huggingface_training_compiler_py_version ,
208
215
entry_point = SCRIPT_PATH ,
209
216
role = ROLE ,
210
217
instance_count = INSTANCE_COUNT ,
@@ -218,10 +225,11 @@ def test_unsupported_gpu_instance(
218
225
219
226
def test_unsupported_framework_version (
220
227
huggingface_training_compiler_version ,
228
+ huggingface_training_compiler_py_version
221
229
):
222
230
with pytest .raises (ValueError ):
223
231
HuggingFace (
224
- py_version = "py38" ,
232
+ py_version = huggingface_training_compiler_py_version ,
225
233
entry_point = SCRIPT_PATH ,
226
234
role = ROLE ,
227
235
instance_count = INSTANCE_COUNT ,
@@ -237,10 +245,11 @@ def test_unsupported_framework_version(
237
245
238
246
def test_unsupported_framework_mxnet (
239
247
huggingface_training_compiler_version ,
248
+ huggingface_training_compiler_py_version
240
249
):
241
250
with pytest .raises (ValueError ):
242
251
HuggingFace (
243
- py_version = "py38" ,
252
+ py_version = huggingface_training_compiler_py_version ,
244
253
entry_point = SCRIPT_PATH ,
245
254
role = ROLE ,
246
255
instance_count = INSTANCE_COUNT ,
@@ -254,7 +263,7 @@ def test_unsupported_framework_mxnet(
254
263
255
264
def test_unsupported_python_2 (
256
265
huggingface_training_compiler_version ,
257
- huggingface_training_compiler_tensorflow_version ,
266
+ huggingface_training_compiler_tensorflow_version
258
267
):
259
268
with pytest .raises (ValueError ):
260
269
HuggingFace (
@@ -282,12 +291,13 @@ def test_default_compiler_config(
282
291
huggingface_training_compiler_version ,
283
292
huggingface_training_compiler_tensorflow_version ,
284
293
instance_class ,
294
+ huggingface_training_compiler_py_version
285
295
):
286
296
compiler_config = TrainingCompilerConfig ()
287
297
instance_type = f"ml.{ instance_class } .xlarge"
288
298
289
299
hf = HuggingFace (
290
- py_version = "py38" ,
300
+ py_version = huggingface_training_compiler_py_version ,
291
301
entry_point = SCRIPT_PATH ,
292
302
role = ROLE ,
293
303
sagemaker_session = sagemaker_session ,
@@ -313,6 +323,7 @@ def test_default_compiler_config(
313
323
f"tensorflow{ huggingface_training_compiler_tensorflow_version } " ,
314
324
instance_type ,
315
325
compiler_config ,
326
+ huggingface_training_compiler_py_version
316
327
)
317
328
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
318
329
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -339,11 +350,12 @@ def test_debug_compiler_config(
339
350
sagemaker_session ,
340
351
huggingface_training_compiler_version ,
341
352
huggingface_training_compiler_tensorflow_version ,
353
+ huggingface_training_compiler_py_version
342
354
):
343
355
compiler_config = TrainingCompilerConfig (debug = True )
344
356
345
357
hf = HuggingFace (
346
- py_version = "py38" ,
358
+ py_version = huggingface_training_compiler_py_version ,
347
359
entry_point = SCRIPT_PATH ,
348
360
role = ROLE ,
349
361
sagemaker_session = sagemaker_session ,
@@ -369,6 +381,7 @@ def test_debug_compiler_config(
369
381
f"tensorflow{ huggingface_training_compiler_tensorflow_version } " ,
370
382
INSTANCE_TYPE ,
371
383
compiler_config ,
384
+ huggingface_training_compiler_py_version
372
385
)
373
386
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
374
387
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -395,11 +408,12 @@ def test_disable_compiler_config(
395
408
sagemaker_session ,
396
409
huggingface_training_compiler_version ,
397
410
huggingface_training_compiler_tensorflow_version ,
411
+ huggingface_training_compiler_py_version
398
412
):
399
413
compiler_config = TrainingCompilerConfig (enabled = False )
400
414
401
415
hf = HuggingFace (
402
- py_version = "py38" ,
416
+ py_version = huggingface_training_compiler_py_version ,
403
417
entry_point = SCRIPT_PATH ,
404
418
role = ROLE ,
405
419
sagemaker_session = sagemaker_session ,
@@ -425,6 +439,7 @@ def test_disable_compiler_config(
425
439
f"tensorflow{ huggingface_training_compiler_tensorflow_version } " ,
426
440
INSTANCE_TYPE ,
427
441
compiler_config ,
442
+ huggingface_training_compiler_py_version
428
443
)
429
444
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
430
445
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -448,12 +463,13 @@ def test_attach(
448
463
sagemaker_session ,
449
464
compiler_enabled ,
450
465
debug_enabled ,
466
+ huggingface_training_compiler_py_version
451
467
):
452
468
training_image = (
453
- "1.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-trcomp-training:"
454
- "2.6.3-"
455
- "transformers4.17.0-gpu-"
456
- "py38 -cu112-ubuntu20.04"
469
+ f "1.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-trcomp-training:"
470
+ f "2.6.3-"
471
+ f "transformers4.17.0-gpu-"
472
+ f" { huggingface_training_compiler_py_version } -cu112-ubuntu20.04"
457
473
)
458
474
returned_job_description = {
459
475
"AlgorithmSpecification" : {"TrainingInputMode" : "File" , "TrainingImage" : training_image },
0 commit comments