Skip to content

Commit 0d04909

Browse files
nadiayaPiali Das
authored and
Piali Das
committed
Create output/data directory in local mode. (aws#364)
* Create output/data directory expected by sagemaker containers when running in local mode. * Update changelog. * Add output/data to directories when running training.
1 parent 99e031d commit 0d04909

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.9.3dev
6+
========
7+
8+
* bug-fix: Local Mode: Create output/data directory expected by SageMaker Container.
9+
510
1.9.2
611
=====
712

src/sagemaker/local/image.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def train(self, input_data_config, hyperparameters):
8585
"""
8686
self.container_root = self._create_tmp_folder()
8787
os.mkdir(os.path.join(self.container_root, 'output'))
88+
# create output/data folder since sagemaker-containers 2.0 expects it
89+
os.mkdir(os.path.join(self.container_root, 'output', 'data'))
8890
# A shared directory for all the containers. It is only mounted if the training script is
8991
# Local.
9092
shared_dir = os.path.join(self.container_root, 'shared')
@@ -386,7 +388,7 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
386388
environment.extend(additional_env_vars)
387389

388390
if command == 'train':
389-
optml_dirs = {'output', 'input'}
391+
optml_dirs = {'output', 'output/data', 'input'}
390392

391393
services = {
392394
h: self._create_docker_host(h, environment, optml_dirs,

tests/unit/test_image.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
245245
assert config['services'][h]['image'] == image
246246
assert config['services'][h]['command'] == 'train'
247247

248+
# assert that expected by sagemaker container output directories exist
249+
assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output'))
250+
assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output/data'))
251+
248252

249253
@patch('sagemaker.local.local_session.LocalSession')
250254
@patch('sagemaker.local.image._stream_output', side_effect=RuntimeError('this is expected'))

0 commit comments

Comments
 (0)