Skip to content

Commit cb2e2b9

Browse files
authored
Fix local image unit tests (aws#689)
1 parent 9c76287 commit cb2e2b9

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

tests/unit/test_image.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,13 @@ def test_check_output():
304304
@patch('sagemaker.local.local_session.LocalSession', Mock())
305305
@patch('sagemaker.local.image._stream_output', Mock())
306306
@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
307-
@patch('sagemaker.local.data.get_data_source_instance', Mock())
307+
@patch('sagemaker.local.data.get_data_source_instance')
308308
@patch('subprocess.Popen')
309-
def test_train(popen, tmpdir, sagemaker_session):
309+
def test_train(popen, get_data_source_instance, tmpdir, sagemaker_session):
310+
data_source = Mock()
311+
data_source.get_root_dir.return_value = 'foo'
312+
get_data_source_instance.return_value = data_source
313+
310314
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
311315
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
312316
side_effect=directories):
@@ -342,8 +346,12 @@ def test_train(popen, tmpdir, sagemaker_session):
342346
@patch('sagemaker.local.local_session.LocalSession', Mock())
343347
@patch('sagemaker.local.image._stream_output', Mock())
344348
@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
345-
@patch('sagemaker.local.data.get_data_source_instance', Mock())
346-
def test_train_with_hyperparameters_without_job_name(tmpdir, sagemaker_session):
349+
@patch('sagemaker.local.data.get_data_source_instance')
350+
def test_train_with_hyperparameters_without_job_name(get_data_source_instance, tmpdir, sagemaker_session):
351+
data_source = Mock()
352+
data_source.get_root_dir.return_value = 'foo'
353+
get_data_source_instance.return_value = data_source
354+
347355
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
348356
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
349357
side_effect=directories):
@@ -364,11 +372,14 @@ def test_train_with_hyperparameters_without_job_name(tmpdir, sagemaker_session):
364372
@patch('sagemaker.local.local_session.LocalSession', Mock())
365373
@patch('sagemaker.local.image._stream_output', side_effect=RuntimeError('this is expected'))
366374
@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
367-
@patch('sagemaker.local.data.get_data_source_instance', Mock())
375+
@patch('sagemaker.local.data.get_data_source_instance')
368376
@patch('subprocess.Popen', Mock())
369-
def test_train_error(_stream_output, tmpdir, sagemaker_session):
370-
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
377+
def test_train_error(get_data_source_instance, _stream_output, tmpdir, sagemaker_session):
378+
data_source = Mock()
379+
data_source.get_root_dir.return_value = 'foo'
380+
get_data_source_instance.return_value = data_source
371381

382+
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
372383
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder', side_effect=directories):
373384
instance_count = 2
374385
image = 'my-image'
@@ -384,9 +395,13 @@ def test_train_error(_stream_output, tmpdir, sagemaker_session):
384395
@patch('sagemaker.local.local_session.LocalSession', Mock())
385396
@patch('sagemaker.local.image._stream_output', Mock())
386397
@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
387-
@patch('sagemaker.local.data.get_data_source_instance', Mock())
398+
@patch('sagemaker.local.data.get_data_source_instance')
388399
@patch('subprocess.Popen', Mock())
389-
def test_train_local_code(tmpdir, sagemaker_session):
400+
def test_train_local_code(get_data_source_instance, tmpdir, sagemaker_session):
401+
data_source = Mock()
402+
data_source.get_root_dir.return_value = 'foo'
403+
get_data_source_instance.return_value = data_source
404+
390405
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
391406
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
392407
side_effect=directories):
@@ -422,9 +437,13 @@ def test_train_local_code(tmpdir, sagemaker_session):
422437
@patch('sagemaker.local.local_session.LocalSession', Mock())
423438
@patch('sagemaker.local.image._stream_output', Mock())
424439
@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
425-
@patch('sagemaker.local.data.get_data_source_instance', Mock())
440+
@patch('sagemaker.local.data.get_data_source_instance')
426441
@patch('subprocess.Popen', Mock())
427-
def test_train_local_intermediate_output(tmpdir, sagemaker_session):
442+
def test_train_local_intermediate_output(get_data_source_instance, tmpdir, sagemaker_session):
443+
data_source = Mock()
444+
data_source.get_root_dir.return_value = 'foo'
445+
get_data_source_instance.return_value = data_source
446+
428447
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
429448
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
430449
side_effect=directories):

0 commit comments

Comments
 (0)