@@ -302,7 +302,7 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
302
302
@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
303
303
@patch ("time.time" , return_value = TIME )
304
304
def test_pytorch (
305
- time , name_from_base , sagemaker_session , pytorch_inference_version , pytorch_inference_py_version
305
+ time , name_from_base , sagemaker_session , pytorch_inference_version , pytorch_inference_py_version , gpu_pytorch_instance_type
306
306
):
307
307
pytorch = PyTorch (
308
308
entry_point = SCRIPT_PATH ,
@@ -339,24 +339,24 @@ def test_pytorch(
339
339
REGION ,
340
340
version = pytorch_inference_version ,
341
341
py_version = pytorch_inference_py_version ,
342
- instance_type = GPU ,
342
+ instance_type = gpu_pytorch_instance_type ,
343
343
image_scope = "inference" ,
344
344
)
345
345
346
- actual_environment = model .prepare_container_def (GPU )
346
+ actual_environment = model .prepare_container_def (gpu_pytorch_instance_type )
347
347
submit_directory = actual_environment ["Environment" ]["SAGEMAKER_SUBMIT_DIRECTORY" ]
348
348
model_url = actual_environment ["ModelDataUrl" ]
349
349
expected_environment = _get_environment (submit_directory , model_url , expected_image_uri )
350
350
assert actual_environment == expected_environment
351
351
352
352
assert "cpu" in model .prepare_container_def (CPU )["Image" ]
353
- predictor = pytorch .deploy (1 , GPU )
353
+ predictor = pytorch .deploy (1 , gpu_pytorch_instance_type )
354
354
assert isinstance (predictor , PyTorchPredictor )
355
355
356
356
357
357
@patch ("sagemaker.utils.repack_model" , MagicMock ())
358
358
@patch ("sagemaker.utils.create_tar_file" , MagicMock ())
359
- def test_model (sagemaker_session , pytorch_inference_version , pytorch_inference_py_version ):
359
+ def test_model (sagemaker_session , pytorch_inference_version , pytorch_inference_py_version , gpu_pytorch_instance_type ):
360
360
model = PyTorchModel (
361
361
MODEL_DATA ,
362
362
role = ROLE ,
@@ -365,21 +365,22 @@ def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_p
365
365
py_version = pytorch_inference_py_version ,
366
366
sagemaker_session = sagemaker_session ,
367
367
)
368
- predictor = model .deploy (1 , GPU )
368
+ predictor = model .deploy (1 , gpu_pytorch_instance_type )
369
369
assert isinstance (predictor , PyTorchPredictor )
370
370
371
371
372
372
@patch ("sagemaker.utils.create_tar_file" , MagicMock ())
373
373
@patch ("sagemaker.utils.repack_model" )
374
- def test_mms_model (repack_model , sagemaker_session ):
374
+ @pytest .mark .parametrize ("gpu_pytorch_instance_type" , ["1.2" ], indirect = True )
375
+ def test_mms_model (repack_model , sagemaker_session , gpu_pytorch_instance_type ):
375
376
PyTorchModel (
376
377
MODEL_DATA ,
377
378
role = ROLE ,
378
379
entry_point = SCRIPT_PATH ,
379
380
sagemaker_session = sagemaker_session ,
380
381
framework_version = "1.2" ,
381
382
py_version = "py3" ,
382
- ).deploy (1 , GPU )
383
+ ).deploy (1 , gpu_pytorch_instance_type )
383
384
384
385
repack_model .assert_called_with (
385
386
dependencies = [],
@@ -428,6 +429,7 @@ def test_model_custom_serialization(
428
429
sagemaker_session ,
429
430
pytorch_inference_version ,
430
431
pytorch_inference_py_version ,
432
+ gpu_pytorch_instance_type
431
433
):
432
434
model = PyTorchModel (
433
435
MODEL_DATA ,
@@ -441,7 +443,7 @@ def test_model_custom_serialization(
441
443
custom_deserializer = Mock ()
442
444
predictor = model .deploy (
443
445
1 ,
444
- GPU ,
446
+ gpu_pytorch_instance_type ,
445
447
serializer = custom_serializer ,
446
448
deserializer = custom_deserializer ,
447
449
)
0 commit comments