22
22
23
23
24
24
@pytest .fixture (scope = 'module' , name = 'pytorch_training_job' )
25
- def fixture_training_job (sagemaker_session , pytorch_full_version ):
25
+ def fixture_training_job (sagemaker_session , pytorch_full_version , instance_type ):
26
26
with timeout (minutes = 15 ):
27
27
pytorch = PyTorch (entry_point = MNIST_SCRIPT , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
28
- train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
28
+ train_instance_count = 1 , train_instance_type = instance_type ,
29
29
sagemaker_session = sagemaker_session )
30
30
31
31
pytorch .fit ({'training' : _upload_training_data (pytorch )})
32
32
return pytorch .latest_training_job .name
33
33
34
34
35
- def test_sync_fit (sagemaker_session , pytorch_full_version ):
35
+ def test_sync_fit (sagemaker_session , pytorch_full_version , instance_type ):
36
36
training_job_name = ""
37
37
38
38
with timeout (minutes = 15 ):
39
39
pytorch = PyTorch (entry_point = MNIST_SCRIPT , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
40
- train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
40
+ train_instance_count = 1 , train_instance_type = instance_type ,
41
41
sagemaker_session = sagemaker_session )
42
42
43
43
pytorch .fit ({'training' : _upload_training_data (pytorch )})
@@ -47,12 +47,12 @@ def test_sync_fit(sagemaker_session, pytorch_full_version):
47
47
PyTorch .attach (training_job_name , sagemaker_session = sagemaker_session )
48
48
49
49
50
- def test_async_fit (sagemaker_session , pytorch_full_version ):
50
+ def test_async_fit (sagemaker_session , pytorch_full_version , instance_type ):
51
51
training_job_name = ""
52
52
53
53
with timeout (minutes = 10 ):
54
54
pytorch = PyTorch (entry_point = MNIST_SCRIPT , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
55
- train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
55
+ train_instance_count = 1 , train_instance_type = instance_type ,
56
56
sagemaker_session = sagemaker_session )
57
57
58
58
pytorch .fit ({'training' : _upload_training_data (pytorch )}, wait = False )
@@ -66,12 +66,12 @@ def test_async_fit(sagemaker_session, pytorch_full_version):
66
66
PyTorch .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
67
67
68
68
69
- def test_failed_training_job (sagemaker_session , pytorch_full_version ):
69
+ def test_failed_training_job (sagemaker_session , pytorch_full_version , instance_type ):
70
70
script_path = os .path .join (MNIST_DIR , 'failure_script.py' )
71
71
72
72
with timeout (minutes = 15 ):
73
73
pytorch = PyTorch (entry_point = script_path , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
74
- train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
74
+ train_instance_count = 1 , train_instance_type = instance_type ,
75
75
sagemaker_session = sagemaker_session )
76
76
77
77
with pytest .raises (ValueError ) as e :
0 commit comments