@@ -266,47 +266,26 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
266
266
267
267
@patch ('sagemaker.local.local_session.LocalSession' )
268
268
@patch ('sagemaker.local.image._stream_output' )
269
- @patch ('subprocess.Popen' )
270
269
@patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
271
270
@patch ('sagemaker.local.image._SageMakerContainer._download_folder' )
272
- def test_train_with_hyperparameters_without_job_name (_download_folder , _cleanup , popen , _stream_output , LocalSession ,
273
- tmpdir , sagemaker_session ):
271
+ def test_train_with_hyperparameters_without_job_name (_download_folder , _cleanup , _stream_output , LocalSession , tmpdir ):
274
272
275
273
directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
276
274
with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
277
275
side_effect = directories ):
278
276
279
277
instance_count = 2
280
278
image = 'my-image'
281
- sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = sagemaker_session )
279
+ sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = LocalSession )
282
280
sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS_WITHOUT_JOB_NAME , TRAINING_JOB_NAME )
283
281
284
- channel_dir = os .path .join (directories [1 ], 'b' )
285
- download_folder_calls = [call ('my-own-bucket' , 'prefix' , channel_dir )]
286
- _download_folder .assert_has_calls (download_folder_calls )
287
-
288
282
docker_compose_file = os .path .join (sagemaker_container .container_root , 'docker-compose.yaml' )
289
283
290
- call_args = popen .call_args [0 ][0 ]
291
- assert call_args is not None
292
-
293
- expected = ['docker-compose' , '-f' , docker_compose_file , 'up' , '--build' , '--abort-on-container-exit' ]
294
- for i , v in enumerate (expected ):
295
- assert call_args [i ] == v
296
-
297
284
with open (docker_compose_file , 'r' ) as f :
298
285
config = yaml .load (f )
299
- assert len (config ['services' ]) == instance_count
300
286
for h in sagemaker_container .hosts :
301
- assert config ['services' ][h ]['image' ] == image
302
- assert config ['services' ][h ]['command' ] == 'train'
303
- assert 'AWS_REGION={}' .format (REGION ) in config ['services' ][h ]['environment' ]
304
287
assert 'TRAINING_JOB_NAME={}' .format (TRAINING_JOB_NAME ) in config ['services' ][h ]['environment' ]
305
288
306
- # assert that expected by sagemaker container output directories exist
307
- assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output' ))
308
- assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output/data' ))
309
-
310
289
311
290
@patch ('sagemaker.local.local_session.LocalSession' )
312
291
@patch ('sagemaker.local.image._stream_output' , side_effect = RuntimeError ('this is expected' ))
0 commit comments