88
88
89
89
PYTORCH_RENEWED_GPU = "ml.g4dn.xlarge"
90
90
91
+
91
92
def pytest_addoption (parser ):
92
93
parser .addoption ("--sagemaker-client-config" , action = "store" , default = None )
93
94
parser .addoption ("--sagemaker-runtime-config" , action = "store" , default = None )
@@ -257,9 +258,13 @@ def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version
257
258
258
259
259
260
@pytest .fixture (scope = "module" )
260
- def huggingface_training_compiler_pytorch_version (huggingface_training_compiler_version ):
261
+ def huggingface_training_compiler_pytorch_version (
262
+ huggingface_training_compiler_version ,
263
+ ):
261
264
versions = _huggingface_base_fm_version (
262
- huggingface_training_compiler_version , "pytorch" , "huggingface_training_compiler"
265
+ huggingface_training_compiler_version ,
266
+ "pytorch" ,
267
+ "huggingface_training_compiler" ,
263
268
)
264
269
if not versions :
265
270
pytest .skip (
@@ -270,9 +275,13 @@ def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_
270
275
271
276
272
277
@pytest .fixture (scope = "module" )
273
- def huggingface_training_compiler_tensorflow_version (huggingface_training_compiler_version ):
278
+ def huggingface_training_compiler_tensorflow_version (
279
+ huggingface_training_compiler_version ,
280
+ ):
274
281
versions = _huggingface_base_fm_version (
275
- huggingface_training_compiler_version , "tensorflow" , "huggingface_training_compiler"
282
+ huggingface_training_compiler_version ,
283
+ "tensorflow" ,
284
+ "huggingface_training_compiler" ,
276
285
)
277
286
if not versions :
278
287
pytest .skip (
@@ -294,19 +303,25 @@ def huggingface_training_compiler_tensorflow_py_version(
294
303
295
304
296
305
@pytest .fixture (scope = "module" )
297
- def huggingface_training_compiler_pytorch_py_version (huggingface_training_compiler_pytorch_version ):
306
+ def huggingface_training_compiler_pytorch_py_version (
307
+ huggingface_training_compiler_pytorch_version ,
308
+ ):
298
309
return "py38"
299
310
300
311
301
312
@pytest .fixture (scope = "module" )
302
- def huggingface_pytorch_latest_training_py_version (huggingface_training_pytorch_latest_version ):
313
+ def huggingface_pytorch_latest_training_py_version (
314
+ huggingface_training_pytorch_latest_version ,
315
+ ):
303
316
return (
304
317
"py38" if Version (huggingface_training_pytorch_latest_version ) >= Version ("1.9" ) else "py36"
305
318
)
306
319
307
320
308
321
@pytest .fixture (scope = "module" )
309
- def huggingface_pytorch_latest_inference_py_version (huggingface_inference_pytorch_latest_version ):
322
+ def huggingface_pytorch_latest_inference_py_version (
323
+ huggingface_inference_pytorch_latest_version ,
324
+ ):
310
325
return (
311
326
"py38"
312
327
if Version (huggingface_inference_pytorch_latest_version ) >= Version ("1.9" )
@@ -482,7 +497,8 @@ def pytorch_ddp_py_version():
482
497
483
498
484
499
@pytest .fixture (
485
- scope = "module" , params = ["1.10" , "1.10.0" , "1.10.2" , "1.11" , "1.11.0" , "1.12" , "1.12.0" ]
500
+ scope = "module" ,
501
+ params = ["1.10" , "1.10.0" , "1.10.2" , "1.11" , "1.11.0" , "1.12" , "1.12.0" ],
486
502
)
487
503
def pytorch_ddp_framework_version (request ):
488
504
return request .param
@@ -515,6 +531,7 @@ def gpu_instance_type(sagemaker_session, request):
515
531
else :
516
532
return "ml.p3.2xlarge"
517
533
534
+
518
535
@pytest .fixture ()
519
536
def gpu_pytorch_instance_type (sagemaker_session , request ):
520
537
if "pytorch_inference_version" in request .fixturenames :
@@ -531,6 +548,7 @@ def gpu_pytorch_instance_type(sagemaker_session, request):
531
548
else :
532
549
return "ml.p3.2xlarge"
533
550
551
+
534
552
@pytest .fixture (scope = "session" )
535
553
def gpu_instance_type_list (sagemaker_session , request ):
536
554
region = sagemaker_session .boto_session .region_name
0 commit comments