@@ -304,9 +304,13 @@ def test_check_output():
304
304
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
305
305
@patch ('sagemaker.local.image._stream_output' , Mock ())
306
306
@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
307
- @patch ('sagemaker.local.data.get_data_source_instance' , Mock () )
307
+ @patch ('sagemaker.local.data.get_data_source_instance' )
308
308
@patch ('subprocess.Popen' )
309
- def test_train (popen , tmpdir , sagemaker_session ):
309
+ def test_train (popen , get_data_source_instance , tmpdir , sagemaker_session ):
310
+ data_source = Mock ()
311
+ data_source .get_root_dir .return_value = 'foo'
312
+ get_data_source_instance .return_value = data_source
313
+
310
314
directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
311
315
with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
312
316
side_effect = directories ):
@@ -342,8 +346,12 @@ def test_train(popen, tmpdir, sagemaker_session):
342
346
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
343
347
@patch ('sagemaker.local.image._stream_output' , Mock ())
344
348
@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
345
- @patch ('sagemaker.local.data.get_data_source_instance' , Mock ())
346
- def test_train_with_hyperparameters_without_job_name (tmpdir , sagemaker_session ):
349
+ @patch ('sagemaker.local.data.get_data_source_instance' )
350
+ def test_train_with_hyperparameters_without_job_name (get_data_source_instance , tmpdir , sagemaker_session ):
351
+ data_source = Mock ()
352
+ data_source .get_root_dir .return_value = 'foo'
353
+ get_data_source_instance .return_value = data_source
354
+
347
355
directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
348
356
with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
349
357
side_effect = directories ):
@@ -364,11 +372,14 @@ def test_train_with_hyperparameters_without_job_name(tmpdir, sagemaker_session):
364
372
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
365
373
@patch ('sagemaker.local.image._stream_output' , side_effect = RuntimeError ('this is expected' ))
366
374
@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
367
- @patch ('sagemaker.local.data.get_data_source_instance' , Mock () )
375
+ @patch ('sagemaker.local.data.get_data_source_instance' )
368
376
@patch ('subprocess.Popen' , Mock ())
369
- def test_train_error (_stream_output , tmpdir , sagemaker_session ):
370
- directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
377
+ def test_train_error (get_data_source_instance , _stream_output , tmpdir , sagemaker_session ):
378
+ data_source = Mock ()
379
+ data_source .get_root_dir .return_value = 'foo'
380
+ get_data_source_instance .return_value = data_source
371
381
382
+ directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
372
383
with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' , side_effect = directories ):
373
384
instance_count = 2
374
385
image = 'my-image'
@@ -384,9 +395,13 @@ def test_train_error(_stream_output, tmpdir, sagemaker_session):
384
395
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
385
396
@patch ('sagemaker.local.image._stream_output' , Mock ())
386
397
@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
387
- @patch ('sagemaker.local.data.get_data_source_instance' , Mock () )
398
+ @patch ('sagemaker.local.data.get_data_source_instance' )
388
399
@patch ('subprocess.Popen' , Mock ())
389
- def test_train_local_code (tmpdir , sagemaker_session ):
400
+ def test_train_local_code (get_data_source_instance , tmpdir , sagemaker_session ):
401
+ data_source = Mock ()
402
+ data_source .get_root_dir .return_value = 'foo'
403
+ get_data_source_instance .return_value = data_source
404
+
390
405
directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
391
406
with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
392
407
side_effect = directories ):
@@ -422,9 +437,13 @@ def test_train_local_code(tmpdir, sagemaker_session):
422
437
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
423
438
@patch ('sagemaker.local.image._stream_output' , Mock ())
424
439
@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
425
- @patch ('sagemaker.local.data.get_data_source_instance' , Mock () )
440
+ @patch ('sagemaker.local.data.get_data_source_instance' )
426
441
@patch ('subprocess.Popen' , Mock ())
427
- def test_train_local_intermediate_output (tmpdir , sagemaker_session ):
442
+ def test_train_local_intermediate_output (get_data_source_instance , tmpdir , sagemaker_session ):
443
+ data_source = Mock ()
444
+ data_source .get_root_dir .return_value = 'foo'
445
+ get_data_source_instance .return_value = data_source
446
+
428
447
directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
429
448
with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
430
449
side_effect = directories ):
0 commit comments