25
25
26
26
27
27
@pytest .fixture (scope = "module" )
28
- def chainer_local_training_job (sagemaker_local_session , chainer_full_version ):
29
- return _run_mnist_training_job (sagemaker_local_session , "local" , 1 , chainer_full_version )
28
+ def chainer_local_training_job (
29
+ sagemaker_local_session , chainer_full_version , chainer_full_py_version
30
+ ):
31
+ return _run_mnist_training_job (
32
+ sagemaker_local_session , "local" , 1 , chainer_full_version , chainer_full_py_version
33
+ )
30
34
31
35
32
36
@pytest .mark .local_mode
33
- def test_distributed_cpu_training (sagemaker_local_session , chainer_full_version ):
34
- _run_mnist_training_job (sagemaker_local_session , "local" , 2 , chainer_full_version )
37
+ def test_distributed_cpu_training (
38
+ sagemaker_local_session , chainer_full_version , chainer_full_py_version
39
+ ):
40
+ _run_mnist_training_job (
41
+ sagemaker_local_session , "local" , 2 , chainer_full_version , chainer_full_py_version
42
+ )
35
43
36
44
37
45
@pytest .mark .local_mode
@@ -129,7 +137,7 @@ def test_deploy_model(
129
137
130
138
131
139
def _run_mnist_training_job (
132
- sagemaker_session , instance_type , instance_count , chainer_version , py_version , wait = True
140
+ sagemaker_session , instance_type , instance_count , chainer_version , py_version
133
141
):
134
142
script_path = (
135
143
os .path .join (DATA_DIR , "chainer_mnist" , "mnist.py" )
@@ -156,7 +164,7 @@ def _run_mnist_training_job(
156
164
test_input = "file://" + os .path .join (data_path , "test" )
157
165
158
166
job_name = unique_name_from_base ("test-chainer-training" )
159
- chainer .fit ({"train" : train_input , "test" : test_input }, wait = wait , job_name = job_name )
167
+ chainer .fit ({"train" : train_input , "test" : test_input }, job_name = job_name )
160
168
return chainer
161
169
162
170
0 commit comments