Skip to content

Commit bbd7984

Browse files
committed
Fix unit test
1 parent 4e02b12 commit bbd7984

File tree

6 files changed

+52
-35
lines changed

6 files changed

+52
-35
lines changed

tests/unit/test_chainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919
import sys
2020
from distutils.util import strtobool
21-
from mock import Mock
21+
from mock import MagicMock, Mock
2222
from mock import patch
2323

2424

@@ -282,6 +282,7 @@ def test_create_model_with_custom_image(sagemaker_session):
282282
assert model.image == custom_image
283283

284284

285+
@patch('sagemaker.utils.create_tar_file', MagicMock())
285286
@patch('time.strftime', return_value=TIMESTAMP)
286287
def test_chainer(strftime, sagemaker_session, chainer_version):
287288
chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
@@ -321,6 +322,7 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
321322
assert isinstance(predictor, ChainerPredictor)
322323

323324

325+
@patch('sagemaker.utils.create_tar_file', MagicMock())
324326
def test_model(sagemaker_session):
325327
model = ChainerModel("s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH,
326328
sagemaker_session=sagemaker_session)

tests/unit/test_estimator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from time import sleep
1919

2020
import pytest
21-
from mock import Mock, patch
21+
from mock import MagicMock, Mock, patch
2222

2323
from sagemaker.estimator import Estimator, Framework, _TrainingJob
2424
from sagemaker.model import FrameworkModel
@@ -127,6 +127,12 @@ def prepare_container_def(self, instance_type):
127127
return MODEL_CONTAINER_DEF
128128

129129

130+
@pytest.fixture(autouse=True)
131+
def mock_create_tar_file():
132+
with patch('sagemaker.utils.create_tar_file', MagicMock()) as create_tar_file:
133+
yield create_tar_file
134+
135+
130136
@pytest.fixture()
131137
def sagemaker_session():
132138
boto_mock = Mock(name='boto_session', region_name=REGION)

tests/unit/test_mxnet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import json
1818
import os
1919
import pytest
20-
from mock import Mock
20+
from mock import MagicMock, Mock
2121
from mock import patch
2222

2323
from sagemaker.mxnet import defaults
@@ -101,6 +101,7 @@ def _create_train_job(version):
101101
}
102102

103103

104+
@patch('sagemaker.utils.create_tar_file', MagicMock())
104105
def test_create_model(sagemaker_session, mxnet_version):
105106
container_log_level = '"logging.INFO"'
106107
source_dir = 's3://mybucket/source'
@@ -168,6 +169,7 @@ def test_create_model_with_custom_image(sagemaker_session):
168169
assert model.source_dir == source_dir
169170

170171

172+
@patch('sagemaker.utils.create_tar_file', MagicMock())
171173
@patch('time.strftime', return_value=TIMESTAMP)
172174
def test_mxnet(strftime, sagemaker_session, mxnet_version):
173175
mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
@@ -207,6 +209,7 @@ def test_mxnet(strftime, sagemaker_session, mxnet_version):
207209
assert isinstance(predictor, MXNetPredictor)
208210

209211

212+
@patch('sagemaker.utils.create_tar_file', MagicMock())
210213
def test_model(sagemaker_session):
211214
model = MXNetModel("s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH,
212215
sagemaker_session=sagemaker_session)

tests/unit/test_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os
1818
import pytest
1919
import sys
20-
from mock import Mock
20+
from mock import MagicMock, Mock
2121
from mock import patch
2222

2323
from sagemaker.pytorch import defaults
@@ -184,6 +184,7 @@ def test_create_model_with_custom_image(sagemaker_session):
184184
assert model.source_dir == source_dir
185185

186186

187+
@patch('sagemaker.utils.create_tar_file', MagicMock())
187188
@patch('time.strftime', return_value=TIMESTAMP)
188189
def test_pytorch(strftime, sagemaker_session, pytorch_version):
189190
pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
@@ -223,6 +224,7 @@ def test_pytorch(strftime, sagemaker_session, pytorch_version):
223224
assert isinstance(predictor, PyTorchPredictor)
224225

225226

227+
@patch('sagemaker.utils.create_tar_file', MagicMock())
226228
def test_model(sagemaker_session):
227229
model = PyTorchModel("s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH,
228230
sagemaker_session=sagemaker_session)

tests/unit/test_tf_estimator.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import os
1818

1919
import pytest
20-
from mock import patch, Mock
20+
from mock import patch, Mock, MagicMock
2121

22-
from sagemaker.fw_utils import create_image_uri, UploadedCode
22+
from sagemaker.fw_utils import create_image_uri
2323
from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
2424
from sagemaker.session import s3_input
2525
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
@@ -162,6 +162,7 @@ def test_tf_support_gpu_instances(sagemaker_session, tf_version):
162162
assert tf.train_image() == _get_full_gpu_image_uri(tf_version)
163163

164164

165+
@patch('sagemaker.utils.create_tar_file', MagicMock())
165166
def test_tf_deploy_model_server_workers(sagemaker_session):
166167
tf = _build_tf(sagemaker_session)
167168
tf.fit(inputs=s3_input('s3://mybucket/train'))
@@ -172,6 +173,7 @@ def test_tf_deploy_model_server_workers(sagemaker_session):
172173
MODEL_SERVER_WORKERS_PARAM_NAME.upper()]
173174

174175

176+
@patch('sagemaker.utils.create_tar_file', MagicMock())
175177
def test_tf_deploy_model_server_workers_unset(sagemaker_session):
176178
tf = _build_tf(sagemaker_session)
177179
tf.fit(inputs=s3_input('s3://mybucket/train'))
@@ -259,20 +261,16 @@ def test_create_model_with_custom_image(sagemaker_session):
259261
assert model.image == custom_image
260262

261263

262-
@patch('time.strftime', return_value=TIMESTAMP)
263-
@patch('time.time', return_value=TIME)
264-
@patch('sagemaker.estimator.tar_and_upload_dir')
265-
@patch('sagemaker.model.tar_and_upload_dir')
266-
def test_tf(m_tar, e_tar, time, strftime, sagemaker_session, tf_version):
264+
@patch('sagemaker.utils.create_tar_file', MagicMock())
265+
@patch('time.strftime', MagicMock(return_value=TIMESTAMP))
266+
@patch('time.time', MagicMock(return_value=TIME))
267+
def test_tf(sagemaker_session, tf_version):
267268
tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, training_steps=1000,
268269
evaluation_steps=10, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
269270
framework_version=tf_version, requirements_file=REQUIREMENTS_FILE, source_dir=DATA_DIR)
270271

271272
inputs = 's3://mybucket/train'
272-
s3_prefix = 's3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)
273-
e_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE)
274-
s3_prefix = 's3://{}/{}/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)
275-
m_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE)
273+
276274
tf.fit(inputs=inputs)
277275

278276
call_names = [c[0] for c in sagemaker_session.method_calls]
@@ -288,7 +286,8 @@ def test_tf(m_tar, e_tar, time, strftime, sagemaker_session, tf_version):
288286

289287
environment = {
290288
'Environment': {
291-
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://{}/{}/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME),
289+
'SAGEMAKER_SUBMIT_DIRECTORY':
290+
's3://mybucket/sagemaker-tensorflow-2017-11-06-14:14:15.673/source/sourcedir.tar.gz',
292291
'SAGEMAKER_PROGRAM': 'dummy_script.py', 'SAGEMAKER_REQUIREMENTS': 'dummy_requirements.txt',
293292
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', 'SAGEMAKER_REGION': 'us-west-2',
294293
'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'
@@ -318,6 +317,7 @@ def test_run_tensorboard_locally_without_tensorboard_binary(time, strftime, pope
318317
'following command: \n pip install tensorboard'
319318

320319

320+
@patch('sagemaker.utils.create_tar_file', MagicMock())
321321
def test_model(sagemaker_session, tf_version):
322322
model = TensorFlowModel("s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH,
323323
sagemaker_session=sagemaker_session)
@@ -340,6 +340,7 @@ def test_run_tensorboard_locally_without_awscli_binary(time, strftime, popen, ca
340340
'following command: \n pip install awscli'
341341

342342

343+
@patch('sagemaker.utils.create_tar_file', MagicMock())
343344
@patch('sagemaker.tensorflow.estimator.Tensorboard._sync_directories')
344345
@patch('tempfile.mkdtemp', return_value='/my/temp/folder')
345346
@patch('shutil.rmtree')
@@ -362,6 +363,7 @@ def test_run_tensorboard_locally(sleep, time, strftime, popen, call, access, rmt
362363
stdout=-1)
363364

364365

366+
@patch('sagemaker.utils.create_tar_file', MagicMock())
365367
@patch('sagemaker.tensorflow.estimator.Tensorboard._sync_directories')
366368
@patch('tempfile.mkdtemp', return_value='/my/temp/folder')
367369
@patch('shutil.rmtree')
@@ -388,6 +390,7 @@ def test_run_tensorboard_locally_port_in_use(sleep, time, strftime, popen, call,
388390
stderr=-1, stdout=-1)
389391

390392

393+
@patch('sagemaker.utils.create_tar_file', MagicMock())
391394
def test_tf_checkpoint_not_set(sagemaker_session):
392395
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
393396
tf = _build_tf(sagemaker_session, checkpoint_path=None, base_job_name=job_name,
@@ -398,6 +401,7 @@ def test_tf_checkpoint_not_set(sagemaker_session):
398401
assert tf.hyperparameters()['checkpoint_path'] == expected_result
399402

400403

404+
@patch('sagemaker.utils.create_tar_file', MagicMock())
401405
def test_tf_training_and_evaluation_steps_not_set(sagemaker_session):
402406
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
403407
output_path = "s3://{}/output/{}/".format(sagemaker_session.default_bucket(), job_name)
@@ -408,6 +412,7 @@ def test_tf_training_and_evaluation_steps_not_set(sagemaker_session):
408412
assert tf.hyperparameters()['evaluation_steps'] == 'null'
409413

410414

415+
@patch('sagemaker.utils.create_tar_file', MagicMock())
411416
def test_tf_training_and_evaluation_steps(sagemaker_session):
412417
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
413418
output_path = "s3://{}/output/{}/".format(sagemaker_session.default_bucket(), job_name)
@@ -418,11 +423,13 @@ def test_tf_training_and_evaluation_steps(sagemaker_session):
418423
assert tf.hyperparameters()['evaluation_steps'] == '456'
419424

420425

426+
@patch('sagemaker.utils.create_tar_file', MagicMock())
421427
def test_tf_checkpoint_set(sagemaker_session):
422428
tf = _build_tf(sagemaker_session, checkpoint_path='s3://my_checkpoint_bucket')
423429
assert tf.hyperparameters()['checkpoint_path'] == json.dumps("s3://my_checkpoint_bucket")
424430

425431

432+
@patch('sagemaker.utils.create_tar_file', MagicMock())
426433
def test_train_image_default(sagemaker_session):
427434
tf = TensorFlow(entry_point=SCRIPT_PATH,
428435
role=ROLE,
@@ -433,6 +440,7 @@ def test_train_image_default(sagemaker_session):
433440
assert _get_full_cpu_image_uri(defaults.TF_VERSION) in tf.train_image()
434441

435442

443+
@patch('sagemaker.utils.create_tar_file', MagicMock())
436444
def test_attach(sagemaker_session, tf_version):
437445
training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:{}-cpu-py2'.format(tf_version)
438446
rjd = {
@@ -483,6 +491,7 @@ def test_attach(sagemaker_session, tf_version):
483491
assert estimator.checkpoint_path == 's3://other/1508872349'
484492

485493

494+
@patch('sagemaker.utils.create_tar_file', MagicMock())
486495
def test_attach_new_repo_name(sagemaker_session, tf_version):
487496
training_image = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:{}-cpu-py2'.format(tf_version)
488497
rjd = {
@@ -531,6 +540,7 @@ def test_attach_new_repo_name(sagemaker_session, tf_version):
531540
assert estimator.train_image() == training_image
532541

533542

543+
@patch('sagemaker.utils.create_tar_file', MagicMock())
534544
def test_attach_old_container(sagemaker_session):
535545
training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:1.0'
536546
rjd = {
@@ -707,18 +717,16 @@ def test_script_mode_create_model(create_tfs_model, sagemaker_session):
707717
create_tfs_model.assert_called_once()
708718

709719

720+
@patch('sagemaker.utils.create_tar_file', MagicMock())
710721
@patch('sagemaker.tensorflow.estimator.Tensorboard._sync_directories')
711722
@patch('sagemaker.tensorflow.estimator.Tensorboard.start')
712-
@patch('tempfile.mkdtemp', return_value='/my/temp/folder')
713-
@patch('shutil.rmtree')
714723
@patch('os.access', return_value=True)
715724
@patch('subprocess.call')
716725
@patch('subprocess.Popen')
717726
@patch('time.strftime', return_value=TIMESTAMP)
718727
@patch('time.time', return_value=TIME)
719728
@patch('time.sleep')
720-
def test_script_mode_tensorboard(sleep, time, strftime, popen, call, access, rmtree, mkdtemp,
721-
start, sync, sagemaker_session):
729+
def test_script_mode_tensorboard(sleep, time, strftime, popen, call, access, start, sync, sagemaker_session):
722730
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
723731
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
724732
framework_version='some_version', script_mode=True)
@@ -729,18 +737,13 @@ def test_script_mode_tensorboard(sleep, time, strftime, popen, call, access, rmt
729737

730738
@patch('time.strftime', return_value=TIMESTAMP)
731739
@patch('time.time', return_value=TIME)
732-
@patch('sagemaker.estimator.tar_and_upload_dir')
733-
@patch('sagemaker.model.tar_and_upload_dir')
734-
def test_tf_script_mode(m_tar, e_tar, time, strftime, sagemaker_session):
740+
@patch('sagemaker.utils.create_tar_file', MagicMock())
741+
def test_tf_script_mode(time, strftime, sagemaker_session):
735742
tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, py_version='py3',
736743
train_instance_type=INSTANCE_TYPE, train_instance_count=1, framework_version='1.11',
737744
source_dir=DATA_DIR)
738745

739746
inputs = 's3://mybucket/train'
740-
s3_prefix = 's3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, SM_JOB_NAME)
741-
e_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE)
742-
s3_prefix = 's3://{}/{}/sourcedir.tar.gz'.format(BUCKET_NAME, SM_JOB_NAME)
743-
m_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE)
744747
tf.fit(inputs=inputs)
745748

746749
call_names = [c[0] for c in sagemaker_session.method_calls]
@@ -755,18 +758,13 @@ def test_tf_script_mode(m_tar, e_tar, time, strftime, sagemaker_session):
755758

756759
@patch('time.strftime', return_value=TIMESTAMP)
757760
@patch('time.time', return_value=TIME)
758-
@patch('sagemaker.estimator.tar_and_upload_dir')
759-
@patch('sagemaker.model.tar_and_upload_dir')
760-
def test_tf_script_mode_ps(m_tar, e_tar, time, strftime, sagemaker_session):
761+
@patch('sagemaker.utils.create_tar_file', MagicMock())
762+
def test_tf_script_mode_ps(time, strftime, sagemaker_session):
761763
tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, py_version='py3',
762764
train_instance_type=INSTANCE_TYPE, train_instance_count=1, framework_version='1.11',
763765
source_dir=DATA_DIR, distributions=DISTRIBUTION_ENABLED)
764766

765767
inputs = 's3://mybucket/train'
766-
s3_prefix = 's3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, SM_JOB_NAME)
767-
e_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE)
768-
s3_prefix = 's3://{}/{}/sourcedir.tar.gz'.format(BUCKET_NAME, SM_JOB_NAME)
769-
m_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE)
770768
tf.fit(inputs=inputs)
771769

772770
call_names = [c[0] for c in sagemaker_session.method_calls]

tests/unit/test_transformer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16-
from mock import Mock, patch
16+
from mock import MagicMock, Mock, patch
1717

1818
from sagemaker.transformer import Transformer, _TransformJob
1919

@@ -40,6 +40,12 @@
4040
}
4141

4242

43+
@pytest.fixture(autouse=True)
44+
def mock_create_tar_file():
45+
with patch('sagemaker.utils.create_tar_file', MagicMock()) as create_tar_file:
46+
yield create_tar_file
47+
48+
4349
@pytest.fixture()
4450
def sagemaker_session():
4551
boto_mock = Mock(name='boto_session')

0 commit comments

Comments
 (0)