18
18
import pytest
19
19
import numpy
20
20
21
+ from sagemaker .chainer .defaults import CHAINER_VERSION
21
22
from sagemaker .chainer .estimator import Chainer
22
23
from sagemaker .chainer .model import ChainerModel
23
24
from sagemaker .utils import sagemaker_timestamp
26
27
27
28
28
29
@pytest .fixture (scope = 'module' )
29
- def chainer_training_job (sagemaker_session ):
30
- return _run_mnist_training_job (sagemaker_session , "ml.c4.xlarge" , 1 )
30
+ def chainer_training_job (sagemaker_session , chainer_full_version ):
31
+ return _run_mnist_training_job (sagemaker_session , "ml.c4.xlarge" , 1 , chainer_full_version )
31
32
32
33
33
- def test_distributed_cpu_training (sagemaker_session ):
34
- _run_mnist_training_job (sagemaker_session , "ml.c4.xlarge" , 2 )
34
+ def test_distributed_cpu_training (sagemaker_session , chainer_full_version ):
35
+ _run_mnist_training_job (sagemaker_session , "ml.c4.xlarge" , 2 , chainer_full_version )
35
36
36
37
37
- def test_distributed_gpu_training (sagemaker_session ):
38
- _run_mnist_training_job (sagemaker_session , "ml.p2.xlarge" , 2 )
38
+ def test_distributed_gpu_training (sagemaker_session , chainer_full_version ):
39
+ _run_mnist_training_job (sagemaker_session , "ml.p2.xlarge" , 2 , chainer_full_version )
39
40
40
41
41
- def test_training_with_additional_hyperparameters (sagemaker_session ):
42
+ def test_training_with_additional_hyperparameters (sagemaker_session , chainer_full_version ):
42
43
with timeout (minutes = 15 ):
43
44
script_path = os .path .join (DATA_DIR , 'chainer_mnist' , 'mnist.py' )
44
45
data_path = os .path .join (DATA_DIR , 'chainer_mnist' )
45
46
46
47
chainer = Chainer (entry_point = script_path , role = 'SageMakerRole' ,
47
48
train_instance_count = 1 , train_instance_type = "ml.c4.xlarge" ,
49
+ framework_version = chainer_full_version ,
48
50
sagemaker_session = sagemaker_session , hyperparameters = {'epochs' : 1 },
49
51
use_mpi = True ,
50
52
num_processes = 2 ,
@@ -75,8 +77,7 @@ def test_deploy_model(chainer_training_job, sagemaker_session):
75
77
desc = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = chainer_training_job )
76
78
model_data = desc ['ModelArtifacts' ]['S3ModelArtifacts' ]
77
79
script_path = os .path .join (DATA_DIR , 'chainer_mnist' , 'mnist.py' )
78
- model = ChainerModel (model_data , 'SageMakerRole' , entry_point = script_path ,
79
- sagemaker_session = sagemaker_session )
80
+ model = ChainerModel (model_data , 'SageMakerRole' , entry_point = script_path , sagemaker_session = sagemaker_session )
80
81
predictor = model .deploy (1 , "ml.m4.xlarge" , endpoint_name = endpoint_name )
81
82
_predict_and_assert (predictor )
82
83
@@ -85,7 +86,8 @@ def test_async_fit(sagemaker_session):
85
86
endpoint_name = 'test-chainer-attach-deploy-{}' .format (sagemaker_timestamp ())
86
87
87
88
with timeout (minutes = 5 ):
88
- training_job_name = _run_mnist_training_job (sagemaker_session , "ml.c4.xlarge" , 1 , wait = False )
89
+ training_job_name = _run_mnist_training_job (sagemaker_session , "ml.c4.xlarge" , 1 ,
90
+ chainer_full_version = CHAINER_VERSION , wait = False )
89
91
90
92
print ("Waiting to re-attach to the training job: %s" % training_job_name )
91
93
time .sleep (20 )
@@ -97,12 +99,13 @@ def test_async_fit(sagemaker_session):
97
99
_predict_and_assert (predictor )
98
100
99
101
100
- def test_failed_training_job (sagemaker_session ):
102
+ def test_failed_training_job (sagemaker_session , chainer_full_version ):
101
103
with timeout (minutes = 15 ):
102
104
script_path = os .path .join (DATA_DIR , 'chainer_mnist' , 'failure_script.py' )
103
105
data_path = os .path .join (DATA_DIR , 'chainer_mnist' )
104
106
105
107
chainer = Chainer (entry_point = script_path , role = 'SageMakerRole' ,
108
+ framework_version = chainer_full_version ,
106
109
train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
107
110
sagemaker_session = sagemaker_session )
108
111
@@ -113,7 +116,8 @@ def test_failed_training_job(sagemaker_session):
113
116
chainer .fit (train_input )
114
117
115
118
116
- def _run_mnist_training_job (sagemaker_session , instance_type , instance_count , wait = True ):
119
+ def _run_mnist_training_job (sagemaker_session , instance_type , instance_count ,
120
+ chainer_full_version , wait = True ):
117
121
with timeout (minutes = 15 ):
118
122
119
123
script_path = os .path .join (DATA_DIR , 'chainer_mnist' , 'mnist.py' ) if instance_type == 1 else \
@@ -122,6 +126,7 @@ def _run_mnist_training_job(sagemaker_session, instance_type, instance_count, wa
122
126
data_path = os .path .join (DATA_DIR , 'chainer_mnist' )
123
127
124
128
chainer = Chainer (entry_point = script_path , role = 'SageMakerRole' ,
129
+ framework_version = chainer_full_version ,
125
130
train_instance_count = instance_count , train_instance_type = instance_type ,
126
131
sagemaker_session = sagemaker_session , hyperparameters = {'epochs' : 1 })
127
132
0 commit comments