
Commit eb1c5ac

Author: Ignacio Quintero (committed)

Fix unit tests

Parent: 5edbbcb

File tree: 4 files changed, +11 / -31 lines


tests/unit/test_chainer.py

Lines changed: 3 additions & 10 deletions

@@ -66,15 +66,14 @@ def _get_full_gpu_image_uri(version):


 def _chainer_estimator(sagemaker_session, framework_version=defaults.CHAINER_VERSION, train_instance_type=None,
-                       enable_cloudwatch_metrics=False, base_job_name=None, use_mpi=None, num_processes=None,
+                       base_job_name=None, use_mpi=None, num_processes=None,
                        process_slots_per_host=None, additional_mpi_options=None, **kwargs):
     return Chainer(entry_point=SCRIPT_PATH,
                    framework_version=framework_version,
                    role=ROLE,
                    sagemaker_session=sagemaker_session,
                    train_instance_count=INSTANCE_COUNT,
                    train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
-                   enable_cloudwatch_metrics=enable_cloudwatch_metrics,
                    base_job_name=base_job_name,
                    use_mpi=use_mpi,
                    num_processes=num_processes,
@@ -152,7 +151,6 @@ def _create_train_job_with_additional_hyperparameters(version):
         },
         'hyperparameters': {
             'sagemaker_program': json.dumps('dummy_script.py'),
-            'sagemaker_enable_cloudwatch_metrics': 'false',
             'sagemaker_container_log_level': str(logging.INFO),
             'sagemaker_job_name': json.dumps(JOB_NAME),
             'sagemaker_submit_directory':
@@ -225,12 +223,10 @@ def test_attach_with_additional_hyperparameters(sagemaker_session, chainer_version):
 def test_create_model(sagemaker_session, chainer_version):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
     chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                       train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                       framework_version=chainer_version, container_log_level=container_log_level,
-                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir,
-                      enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir)

     job_name = 'new_name'
     chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
@@ -244,19 +240,16 @@ def test_create_model(sagemaker_session, chainer_version):
     assert model.name == job_name
     assert model.container_log_level == container_log_level
     assert model.source_dir == source_dir
-    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


 def test_create_model_with_custom_image(sagemaker_session):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
     custom_image = 'ubuntu:latest'
     chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                       train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                       image_name=custom_image, container_log_level=container_log_level,
-                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir,
-                      enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir)

     chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
     model = chainer.create_model()
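
For reference, here is a sketch of the _chainer_estimator helper as it reads after this change, reassembled from the diff context above. SCRIPT_PATH, ROLE, INSTANCE_COUNT, and INSTANCE_TYPE are module-level constants defined elsewhere in test_chainer.py, and the trailing process_slots_per_host / additional_mpi_options pass-through is an assumption, since those lines fall outside the hunk:

def _chainer_estimator(sagemaker_session, framework_version=defaults.CHAINER_VERSION, train_instance_type=None,
                       base_job_name=None, use_mpi=None, num_processes=None,
                       process_slots_per_host=None, additional_mpi_options=None, **kwargs):
    # enable_cloudwatch_metrics is no longer accepted here or forwarded to Chainer.
    return Chainer(entry_point=SCRIPT_PATH,
                   framework_version=framework_version,
                   role=ROLE,
                   sagemaker_session=sagemaker_session,
                   train_instance_count=INSTANCE_COUNT,
                   train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
                   base_job_name=base_job_name,
                   use_mpi=use_mpi,
                   num_processes=num_processes,
                   process_slots_per_host=process_slots_per_host,        # assumed pass-through
                   additional_mpi_options=additional_mpi_options,        # assumed pass-through
                   **kwargs)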

tests/unit/test_mxnet.py

Lines changed: 2 additions & 6 deletions

@@ -101,11 +101,10 @@ def _create_train_job(version):
 def test_create_model(sagemaker_session, mxnet_version):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
     mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                framework_version=mxnet_version, container_log_level=container_log_level,
-               base_job_name='job', source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+               base_job_name='job', source_dir=source_dir)

     job_name = 'new_name'
     mx.fit(inputs='s3://mybucket/train', job_name='new_name')
@@ -119,18 +118,16 @@ def test_create_model(sagemaker_session, mxnet_version):
     assert model.name == job_name
     assert model.container_log_level == container_log_level
     assert model.source_dir == source_dir
-    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


 def test_create_model_with_custom_image(sagemaker_session):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
     custom_image = 'mxnet:2.0'
     mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                image_name=custom_image, container_log_level=container_log_level,
-               base_job_name='job', source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+               base_job_name='job', source_dir=source_dir)

     job_name = 'new_name'
     mx.fit(inputs='s3://mybucket/train', job_name='new_name')
@@ -143,7 +140,6 @@ def test_create_model_with_custom_image(sagemaker_session):
     assert model.name == job_name
     assert model.container_log_level == container_log_level
     assert model.source_dir == source_dir
-    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


 @patch('time.strftime', return_value=TIMESTAMP)

tests/unit/test_pytorch.py

Lines changed: 3 additions & 8 deletions

@@ -64,15 +64,14 @@ def _get_full_gpu_image_uri(version, py_version=PYTHON_VERSION):


 def _pytorch_estimator(sagemaker_session, framework_version=defaults.PYTORCH_VERSION, train_instance_type=None,
-                       enable_cloudwatch_metrics=False, base_job_name=None, **kwargs):
+                       base_job_name=None, **kwargs):
     return PyTorch(entry_point=SCRIPT_PATH,
                    framework_version=framework_version,
                    py_version=PYTHON_VERSION,
                    role=ROLE,
                    sagemaker_session=sagemaker_session,
                    train_instance_count=INSTANCE_COUNT,
                    train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
-                   enable_cloudwatch_metrics=enable_cloudwatch_metrics,
                    base_job_name=base_job_name,
                    **kwargs)

@@ -119,11 +118,10 @@ def _create_train_job(version):
 def test_create_model(sagemaker_session, pytorch_version):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
     pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                       train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                       framework_version=pytorch_version, container_log_level=container_log_level,
-                      base_job_name='job', source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+                      base_job_name='job', source_dir=source_dir)

     job_name = 'new_name'
     pytorch.fit(inputs='s3://mybucket/train', job_name='new_name')
@@ -137,18 +135,16 @@ def test_create_model(sagemaker_session, pytorch_version):
     assert model.name == job_name
     assert model.container_log_level == container_log_level
     assert model.source_dir == source_dir
-    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


 def test_create_model_with_custom_image(sagemaker_session):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
     image = 'pytorch:9000'
     pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                       train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                       container_log_level=container_log_level, image_name=image,
-                      base_job_name='job', source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+                      base_job_name='job', source_dir=source_dir)

     job_name = 'new_name'
     pytorch.fit(inputs='s3://mybucket/train', job_name='new_name')
@@ -161,7 +157,6 @@ def test_create_model_with_custom_image(sagemaker_session):
     assert model.name == job_name
     assert model.container_log_level == container_log_level
     assert model.source_dir == source_dir
-    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


 @patch('time.strftime', return_value=TIMESTAMP)

tests/unit/test_tf_estimator.py

Lines changed: 3 additions & 7 deletions

@@ -107,7 +107,7 @@ def _create_train_job(tf_version):


 def _build_tf(sagemaker_session, framework_version=defaults.TF_VERSION, train_instance_type=None,
-              checkpoint_path=None, enable_cloudwatch_metrics=False, base_job_name=None,
+              checkpoint_path=None, base_job_name=None,
               training_steps=None, evaluation_steps=None, **kwargs):
     return TensorFlow(entry_point=SCRIPT_PATH,
                       training_steps=training_steps,
@@ -118,7 +118,6 @@ def _build_tf(sagemaker_session, framework_version=defaults.TF_VERSION, train_instance_type=None,
                       train_instance_count=INSTANCE_COUNT,
                       train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
                       checkpoint_path=checkpoint_path,
-                      enable_cloudwatch_metrics=enable_cloudwatch_metrics,
                       base_job_name=base_job_name,
                       **kwargs)

@@ -183,12 +182,11 @@ def test_tf_nonexistent_requirements_path(sagemaker_session):
 def test_create_model(sagemaker_session, tf_version):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
     tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                     training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
                     train_instance_type=INSTANCE_TYPE, framework_version=tf_version,
                     container_log_level=container_log_level, base_job_name='job',
-                    source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+                    source_dir=source_dir)

     job_name = 'doing something'
     tf.fit(inputs='s3://mybucket/train', job_name=job_name)
@@ -202,19 +200,17 @@ def test_create_model(sagemaker_session, tf_version):
     assert model.name == job_name
     assert model.container_log_level == container_log_level
     assert model.source_dir == source_dir
-    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


 def test_create_model_with_custom_image(sagemaker_session):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
     custom_image = 'tensorflow:1.0'
     tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                     training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
                     train_instance_type=INSTANCE_TYPE, image_name=custom_image,
                     container_log_level=container_log_level, base_job_name='job',
-                    source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+                    source_dir=source_dir)

     job_name = 'doing something'
     tf.fit(inputs='s3://mybucket/train', job_name=job_name)
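
Similarly, a sketch of _build_tf as it reads after this change, reassembled from the two hunks above; the keyword arguments between training_steps and train_instance_count lie outside the diff context and are left elided rather than guessed:

def _build_tf(sagemaker_session, framework_version=defaults.TF_VERSION, train_instance_type=None,
              checkpoint_path=None, base_job_name=None,
              training_steps=None, evaluation_steps=None, **kwargs):
    # enable_cloudwatch_metrics is removed from both the signature and the TensorFlow(...) call.
    return TensorFlow(entry_point=SCRIPT_PATH,
                      training_steps=training_steps,
                      # ... arguments outside the diff context elided ...
                      train_instance_count=INSTANCE_COUNT,
                      train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
                      checkpoint_path=checkpoint_path,
                      base_job_name=base_job_name,
                      **kwargs)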
