Skip to content

Commit 385d40a

Browse files
authored
Fix bug in local mode (#429)
* Fix bug in localmode Set correct default values for additional_env_vars and additional_volumes and add unit test. * Modify changelog
1 parent 9d3c218 commit 385d40a

File tree

4 files changed

+32
-6
lines changed

4 files changed

+32
-6
lines changed

CHANGELOG.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ CHANGELOG
77

88
* feature: Local Mode: Add support for Batch Inference
99
* feature: Add timestamp to secondary status in training job output
10+
* bug-fix: Local Mode: Set correct default values for additional_volumes and additional_env_vars
1011
* enhancement: Local Mode: support nvidia-docker2 natively
1112

12-
1313
1.11.2
1414
======
1515

src/sagemaker/local/image.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
341341
342342
"""
343343
boto_session = self.sagemaker_session.boto_session
344-
additional_env_vars = additional_env_vars or []
345-
additional_volumes = additional_volumes or {}
344+
additional_volumes = additional_volumes or []
345+
additional_env_vars = additional_env_vars or {}
346346
environment = []
347347
optml_dirs = set()
348348

src/sagemaker/utils.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def secondary_training_status_changed(current_job_description, prev_job_descript
160160

161161

162162
def secondary_training_status_message(job_description, prev_description):
163-
"""Returns a string contains start time and the secondary training job status message.
163+
"""Returns a string contains last modified time and the secondary training job status message.
164164
165165
Args:
166166
job_description: Returned response from DescribeTrainingJob call
@@ -181,8 +181,12 @@ def secondary_training_status_message(job_description, prev_description):
181181
if prev_description_secondary_transitions is not None else 0
182182
current_transitions = job_description['SecondaryStatusTransitions']
183183

184-
transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
185-
current_transitions[prev_transitions_num - len(current_transitions):]
184+
if len(current_transitions) == prev_transitions_num:
185+
# Secondary status is not changed but the message changed.
186+
transitions_to_print = current_transitions[-1:]
187+
else:
188+
# Secondary status is changed we need to print all the entries.
189+
transitions_to_print = current_transitions[prev_transitions_num - len(current_transitions):]
186190

187191
status_strs = []
188192
for transition in transitions_to_print:

tests/unit/test_image.py

+22
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,28 @@ def test_serve_local_code(up, copy, copytree, tmpdir, sagemaker_session):
413413
assert '%s:/opt/ml/code' % '/tmp/code' in volumes
414414

415415

416+
@patch('sagemaker.local.image._HostingContainer.run')
417+
@patch('shutil.copy')
418+
@patch('shutil.copytree')
419+
def test_serve_local_code_no_env(up, copy, copytree, tmpdir, sagemaker_session):
420+
421+
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
422+
return_value=str(tmpdir.mkdir('container-root'))):
423+
424+
image = 'my-image'
425+
sagemaker_container = _SageMakerContainer('local', 1, image, sagemaker_session=sagemaker_session)
426+
sagemaker_container.serve('/some/model/path', {})
427+
docker_compose_file = os.path.join(sagemaker_container.container_root,
428+
'docker-compose.yaml')
429+
430+
with open(docker_compose_file, 'r') as f:
431+
config = yaml.load(f)
432+
433+
for h in sagemaker_container.hosts:
434+
assert config['services'][h]['image'] == image
435+
assert config['services'][h]['command'] == 'serve'
436+
437+
416438
@patch('sagemaker.utils.download_file')
417439
@patch('tarfile.is_tarfile')
418440
@patch('tarfile.open', MagicMock())

0 commit comments

Comments
 (0)