Commit 3d091b4

local mode: improve training input/output (aws#449)
output_path is now honored and can be either file:// or s3://. This also changes the default behavior of local mode to use the SDK-provided default S3 bucket when nothing is passed, which makes it easier for customers to create models in SageMaker too, since their model artifacts will already be a tarfile in S3.

The input channel content_type is now honored in the same way as SageMaker: if it is not provided, it is not passed to the container. Before, we always passed 'application/octet-stream'.
1 parent 507f2cd commit 3d091b4

16 files changed: +382, -227 lines
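
To make the behavior change concrete, here is a minimal local-mode sketch under these semantics. The image URI, role, and data paths are placeholders, and the generic Estimator / LocalSession classes are assumed from the existing SDK rather than introduced by this commit:

from sagemaker.estimator import Estimator
from sagemaker.local import LocalSession

# Placeholder image/role/paths, for illustration only.
estimator = Estimator(
    image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest',
    role='SageMakerRole',
    train_instance_count=1,
    train_instance_type='local',             # train in a local container via docker-compose
    output_path='file:///tmp/model-output',  # now honored; s3:// also works
    sagemaker_session=LocalSession())

# If output_path were omitted, local mode now falls back to the SDK default S3 bucket,
# so the resulting model.tar.gz is already in S3 and usable for creating a SageMaker Model.
estimator.fit({'training': 'file:///tmp/training-data'})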

CHANGELOG.rst (+4, -2)

@@ -2,11 +2,13 @@
 CHANGELOG
 =========
 
-1.13.0
-======
+========
+1.13.dev
+========
 
 * feature: Estimator: add input mode to training channels
 * feature: Estimator: add model_uri and model_channel_name parameters
+* enhancement: Local Mode: support output_path. Can be either file:// or s3://
 
 1.12.0
 ======

src/sagemaker/estimator.py (+3)

@@ -118,6 +118,9 @@ def __init__(self, role, train_instance_count, train_instance_type,
 
         self.base_job_name = base_job_name
         self._current_job_name = None
+        if (not self.sagemaker_session.local_mode
+                and output_path and output_path.startswith('file://')):
+            raise RuntimeError('file:// output paths are only supported in Local Mode')
         self.output_path = output_path
         self.output_kms_key = output_kms_key
         self.latest_training_job = None
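
A quick illustration of the new guard (hypothetical image and role; only the file:// check itself comes from the hunk above):

from sagemaker.estimator import Estimator

# A non-local instance type combined with a file:// output path is now rejected up front:
Estimator(image_name='my-algo', role='SageMakerRole',
          train_instance_count=1, train_instance_type='ml.m4.xlarge',
          output_path='file:///tmp/model-output')
# -> RuntimeError: file:// output paths are only supported in Local Mode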

src/sagemaker/fw_utils.py (+6, -12)

@@ -14,12 +14,10 @@
 
 import os
 import re
-import tarfile
-import tempfile
 from collections import namedtuple
 from six.moves.urllib.parse import urlparse
 
-from sagemaker.utils import name_from_image
+import sagemaker.utils
 
 """This module contains utility functions shared across ``Framework`` components."""
 
@@ -128,14 +126,9 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
     s3 = session.resource('s3')
     key = '{}/{}'.format(s3_key_prefix, 'sourcedir.tar.gz')
 
-    with tempfile.TemporaryFile() as f:
-        with tarfile.open(mode='w:gz', fileobj=f) as t:
-            for sf in source_files:
-                # Add all files from the directory into the root of the directory structure of the tar
-                t.add(sf, arcname=os.path.basename(sf))
-        # Need to reset the file descriptor position after writing to prepare for read
-        f.seek(0)
-        s3.Object(bucket, key).put(Body=f)
+    tar_file = sagemaker.utils.create_tar_file(source_files)
+    s3.Object(bucket, key).upload_file(tar_file)
+    os.remove(tar_file)
 
     return UploadedCode(s3_prefix='s3://{}/{}'.format(bucket, key), script_name=script_name)
 
@@ -226,7 +219,8 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
     Returns:
         str: the key prefix to be used in uploading code
     """
-    return '/'.join(filter(None, [code_location_key_prefix, model_name or name_from_image(image)]))
+    training_job_name = sagemaker.utils.name_from_image(image)
+    return '/'.join(filter(None, [code_location_key_prefix, model_name or training_job_name]))
 
 
 def empty_framework_version_warning(default_version):
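
create_tar_file itself is not part of the visible hunks; judging from how it is called here (a file list, and later in image.py with an explicit target path), it is presumably a small helper along these lines, shown only as a sketch:

import os
import tarfile
import tempfile


def create_tar_file(source_files, target=None):
    """Sketch: gzip-tar the given files and return the archive path."""
    if target:
        filename = target
    else:
        _, filename = tempfile.mkstemp()

    with tarfile.open(filename, mode='w:gz') as t:
        for sf in source_files:
            # Keep each file at the root of the archive, like the old inline code did.
            t.add(sf, arcname=os.path.basename(sf))
    return filename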

src/sagemaker/local/data.py (+7)

@@ -13,6 +13,7 @@
 from __future__ import absolute_import
 
 import os
+import platform
 import sys
 import tempfile
 from abc import ABCMeta
@@ -162,6 +163,12 @@ def __init__(self, bucket, prefix, sagemaker_session):
         root_dir = os.path.abspath(root_dir)
 
         working_dir = tempfile.mkdtemp(dir=root_dir)
+        # Docker cannot mount Mac OS /var folder properly see
+        # https://forums.docker.com/t/var-folders-isnt-mounted-properly/9600
+        # Only apply this workaround if the user didn't provide an alternate storage root dir.
+        if root_dir is None and platform.system() == 'Darwin':
+            working_dir = '/private{}'.format(working_dir)
+
         sagemaker.utils.download_folder(bucket, prefix, working_dir, sagemaker_session)
         self.files = LocalFileDataSource(working_dir)
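
The '/private' prefix relies on the fact that /var is a symlink into /private on macOS, and Docker for Mac's default file sharing covers /private while the raw /var/folders path is not mountable. Roughly:

import os
import platform
import tempfile

working_dir = tempfile.mkdtemp()           # e.g. '/var/folders/.../T/tmpXXXXXX' on macOS
if platform.system() == 'Darwin':
    print(os.path.realpath(working_dir))   # '/private/var/folders/...', which Docker can mount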

src/sagemaker/local/entities.py (+8, -3)

@@ -46,23 +46,28 @@ def __init__(self, container):
         self.start_time = None
         self.end_time = None
 
-    def start(self, input_data_config, hyperparameters, job_name):
+    def start(self, input_data_config, output_data_config, hyperparameters, job_name):
         for channel in input_data_config:
             if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
                 data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
+                data_uri = channel['DataSource']['S3DataSource']['S3Uri']
             elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
                 data_distribution = channel['DataSource']['FileDataSource']['FileDataDistributionType']
+                data_uri = channel['DataSource']['FileDataSource']['FileUri']
             else:
                 raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
 
+            # use a single Data URI - this makes handling S3 and File Data easier down the stack
+            channel['DataUri'] = data_uri
+
            if data_distribution != 'FullyReplicated':
                raise RuntimeError('DataDistribution: %s is not currently supported in Local Mode' %
                                   data_distribution)
 
         self.start = datetime.datetime.now()
         self.state = self._TRAINING
 
-        self.model_artifacts = self.container.train(input_data_config, hyperparameters, job_name)
+        self.model_artifacts = self.container.train(input_data_config, output_data_config, hyperparameters, job_name)
         self.end = datetime.datetime.now()
         self.state = self._COMPLETED
 
@@ -298,7 +303,7 @@ def _perform_batch_inference(self, input_data, output_data, **kwargs):
         if 'AssembleWith' in output_data and output_data['AssembleWith'] == 'Line':
             f.write(b'\n')
 
-        move_to_destination(working_dir, output_data['S3OutputPath'], self.local_session)
+        move_to_destination(working_dir, output_data['S3OutputPath'], self.name, self.local_session)
         self.container.stop_serving()
 
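Concretely, after start() runs, each channel carries its original DataSource plus the new flattened DataUri key (bucket and prefix below are made up):

channel = {
    'ChannelName': 'training',
    'DataSource': {
        'S3DataSource': {
            'S3DataDistributionType': 'FullyReplicated',
            'S3Uri': 's3://my-bucket/my-training-data'
        }
    },
    # Added above so code further down the stack only has to look at one URI field.
    'DataUri': 's3://my-bucket/my-training-data'
}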

src/sagemaker/local/image.py (+75, -66)

@@ -33,6 +33,9 @@
 import yaml
 
 import sagemaker
+import sagemaker.local.data
+import sagemaker.local.utils
+import sagemaker.utils
 
 CONTAINER_PREFIX = 'algo'
 DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
@@ -78,7 +81,7 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
         self.container_root = None
         self.container = None
 
-    def train(self, input_data_config, hyperparameters, job_name):
+    def train(self, input_data_config, output_data_config, hyperparameters, job_name):
         """Run a training job locally using docker-compose.
         Args:
             input_data_config (dict): The Input Data Configuration, this contains data such as the
@@ -126,23 +129,17 @@ def train(self, input_data_config, hyperparameters, job_name):
             msg = "Failed to run: %s, %s" % (compose_command, str(e))
             raise RuntimeError(msg)
 
-        s3_artifacts = self.retrieve_artifacts(compose_data)
+        artifacts = self.retrieve_artifacts(compose_data, output_data_config, job_name)
 
         # free up the training data directory as it may contain
         # lots of data downloaded from S3. This doesn't delete any local
         # data that was just mounted to the container.
-        _delete_tree(data_dir)
-        _delete_tree(shared_dir)
-        # Also free the container config files.
-        for host in self.hosts:
-            container_config_path = os.path.join(self.container_root, host)
-            _delete_tree(container_config_path)
-
-        self._cleanup()
-        # Print our Job Complete line to have a simmilar experience to training on SageMaker where you
+        dirs_to_delete = [data_dir, shared_dir]
+        self._cleanup(dirs_to_delete)
+        # Print our Job Complete line to have a similar experience to training on SageMaker where you
         # see this line at the end.
         print('===== Job Complete =====')
-        return s3_artifacts
+        return artifacts
 
     def serve(self, model_dir, environment):
         """Host a local endpoint using docker-compose.
@@ -188,7 +185,7 @@ def stop_serving(self):
         # for serving we can delete everything in the container root.
         _delete_tree(self.container_root)
 
-    def retrieve_artifacts(self, compose_data):
+    def retrieve_artifacts(self, compose_data, output_data_config, job_name):
         """Get the model artifacts from all the container nodes.
 
         Used after training completes to gather the data from all the individual containers. As the
@@ -201,26 +198,49 @@ def retrieve_artifacts(self, compose_data):
         Returns: Local path to the collected model artifacts.
 
         """
-        # Grab the model artifacts from all the Nodes.
-        s3_artifacts = os.path.join(self.container_root, 's3_artifacts')
-        os.mkdir(s3_artifacts)
+        # We need a directory to store the artifacts from all the nodes
+        # and another one to contain the compressed final artifacts
+        artifacts = os.path.join(self.container_root, 'artifacts')
+        compressed_artifacts = os.path.join(self.container_root, 'compressed_artifacts')
+        os.mkdir(artifacts)
+
+        model_artifacts = os.path.join(artifacts, 'model')
+        output_artifacts = os.path.join(artifacts, 'output')
 
-        s3_model_artifacts = os.path.join(s3_artifacts, 'model')
-        s3_output_artifacts = os.path.join(s3_artifacts, 'output')
-        os.mkdir(s3_model_artifacts)
-        os.mkdir(s3_output_artifacts)
+        artifact_dirs = [model_artifacts, output_artifacts, compressed_artifacts]
+        for d in artifact_dirs:
+            os.mkdir(d)
 
+        # Gather the artifacts from all nodes into artifacts/model and artifacts/output
         for host in self.hosts:
             volumes = compose_data['services'][str(host)]['volumes']
-
             for volume in volumes:
                 host_dir, container_dir = volume.split(':')
                 if container_dir == '/opt/ml/model':
-                    sagemaker.local.utils.recursive_copy(host_dir, s3_model_artifacts)
+                    sagemaker.local.utils.recursive_copy(host_dir, model_artifacts)
                 elif container_dir == '/opt/ml/output':
-                    sagemaker.local.utils.recursive_copy(host_dir, s3_output_artifacts)
+                    sagemaker.local.utils.recursive_copy(host_dir, output_artifacts)
 
-        return s3_model_artifacts
+        # Tar Artifacts -> model.tar.gz and output.tar.gz
+        model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)]
+        output_files = [os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)]
+        sagemaker.utils.create_tar_file(model_files, os.path.join(compressed_artifacts, 'model.tar.gz'))
+        sagemaker.utils.create_tar_file(output_files, os.path.join(compressed_artifacts, 'output.tar.gz'))
+
+        if output_data_config['S3OutputPath'] == '':
+            output_data = 'file://%s' % compressed_artifacts
+        else:
+            # Now we just need to move the compressed artifacts to wherever they are required
+            output_data = sagemaker.local.utils.move_to_destination(
+                compressed_artifacts,
+                output_data_config['S3OutputPath'],
+                job_name,
+                self.sagemaker_session)
+
+        _delete_tree(model_artifacts)
+        _delete_tree(output_artifacts)
+
+        return os.path.join(output_data, 'model.tar.gz')
 
     def write_config_files(self, host, hyperparameters, input_data_config):
         """Write the config files for the training containers.
@@ -235,17 +255,22 @@ def write_config_files(self, host, hyperparameters, input_data_config):
         Returns: None
 
         """
-
         config_path = os.path.join(self.container_root, host, 'input', 'config')
 
         resource_config = {
             'current_host': host,
             'hosts': self.hosts
         }
 
-        json_input_data_config = {
-            c['ChannelName']: {'ContentType': 'application/octet-stream'} for c in input_data_config
-        }
+        print(input_data_config)
+        json_input_data_config = {}
+        for c in input_data_config:
+            channel_name = c['ChannelName']
+            json_input_data_config[channel_name] = {
+                'TrainingInputMode': 'File'
+            }
+            if 'ContentType' in c:
+                json_input_data_config[channel_name]['ContentType'] = c['ContentType']
 
         _write_json_file(os.path.join(config_path, 'hyperparameters.json'), hyperparameters)
         _write_json_file(os.path.join(config_path, 'resourceconfig.json'), resource_config)
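
With this change, the generated inputdataconfig.json only carries a ContentType when the channel declared one. For two hypothetical channels, the written dictionary would look like:

json_input_data_config = {
    'training': {'TrainingInputMode': 'File', 'ContentType': 'text/csv'},  # ContentType was provided
    'validation': {'TrainingInputMode': 'File'}                            # none provided, none defaulted
}
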
@@ -261,29 +286,13 @@ def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters
         # mount the local directory to the container. For S3 Data we will download the S3 data
         # first.
         for channel in input_data_config:
-            if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
-                uri = channel['DataSource']['S3DataSource']['S3Uri']
-            elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
-                uri = channel['DataSource']['FileDataSource']['FileUri']
-            else:
-                raise ValueError('Need channel[\'DataSource\'] to have'
-                                 ' [\'S3DataSource\'] or [\'FileDataSource\']')
-
-            parsed_uri = urlparse(uri)
-            key = parsed_uri.path.lstrip('/')
-
+            uri = channel['DataUri']
             channel_name = channel['ChannelName']
             channel_dir = os.path.join(data_dir, channel_name)
             os.mkdir(channel_dir)
 
-            if parsed_uri.scheme == 's3':
-                bucket_name = parsed_uri.netloc
-                sagemaker.utils.download_folder(bucket_name, key, channel_dir, self.sagemaker_session)
-            elif parsed_uri.scheme == 'file':
-                path = parsed_uri.path
-                volumes.append(_Volume(path, channel=channel_name))
-            else:
-                raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
+            data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
+            volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))
 
         # If there is a training script directory and it is a local directory,
         # mount it to the container.
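
get_data_source_instance is also not shown in the visible hunks. Given the LocalFileDataSource class in local/data.py above and the S3-backed constructor there, it presumably dispatches on the URI scheme roughly like this sketch (class names beyond LocalFileDataSource are assumed):

from six.moves.urllib.parse import urlparse

from sagemaker.local.data import LocalFileDataSource, S3DataSource  # S3DataSource name assumed


def get_data_source_instance(data_source, sagemaker_session):
    """Sketch: wrap a file:// or s3:// URI in the matching DataSource object."""
    parsed_uri = urlparse(data_source)
    if parsed_uri.scheme == 'file':
        return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path)
    elif parsed_uri.scheme == 's3':
        return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
    else:
        raise ValueError('Unsupported URI scheme {}'.format(parsed_uri.scheme))
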
@@ -301,25 +310,20 @@ def _prepare_serving_volumes(self, model_location):
         volumes = []
         host = self.hosts[0]
         # Make the model available to the container. If this is a local file just mount it to
-        # the container as a volume. If it is an S3 location download it and extract the tar file.
+        # the container as a volume. If it is an S3 location, the DataSource will download it, we
+        # just need to extract the tar file.
         host_dir = os.path.join(self.container_root, host)
         os.makedirs(host_dir)
 
-        if model_location.startswith('s3'):
-            container_model_dir = os.path.join(self.container_root, host, 'model')
-            os.makedirs(container_model_dir)
+        model_data_source = sagemaker.local.data.get_data_source_instance(
+            model_location, self.sagemaker_session)
 
-            parsed_uri = urlparse(model_location)
-            filename = os.path.basename(parsed_uri.path)
-            tar_location = os.path.join(container_model_dir, filename)
-            sagemaker.utils.download_file(parsed_uri.netloc, parsed_uri.path, tar_location, self.sagemaker_session)
+        for filename in model_data_source.get_file_list():
+            if tarfile.is_tarfile(filename):
+                with tarfile.open(filename) as tar:
+                    tar.extractall(path=model_data_source.get_root_dir())
 
-            if tarfile.is_tarfile(tar_location):
-                with tarfile.open(tar_location) as tar:
-                    tar.extractall(path=container_model_dir)
-            volumes.append(_Volume(container_model_dir, '/opt/ml/model'))
-        else:
-            volumes.append(_Volume(model_location, '/opt/ml/model'))
+        volumes.append(_Volume(model_data_source.get_root_dir(), '/opt/ml/model'))
 
         return volumes
 
@@ -368,7 +372,6 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
             'networks': {
                 'sagemaker-local': {'name': 'sagemaker-local'}
             }
-
         }
 
         docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)
@@ -469,9 +472,15 @@ def _build_optml_volumes(self, host, subdirs):
 
         return volumes
 
-    def _cleanup(self):
-        # we don't need to cleanup anything at the moment
-        pass
+    def _cleanup(self, dirs_to_delete=None):
+        if dirs_to_delete:
+            for d in dirs_to_delete:
+                _delete_tree(d)
+
+        # Free the container config files.
+        for host in self.hosts:
+            container_config_path = os.path.join(self.container_root, host)
+            _delete_tree(container_config_path)
 
 
 class _HostingContainer(Thread):
@@ -610,7 +619,7 @@ def _aws_credentials(session):
            'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key))
        ]
    elif not _aws_credentials_available_in_metadata_service():
-        logger.warn("Using the short-lived AWS credentials found in session. They might expire while running.")
+        logger.warning("Using the short-lived AWS credentials found in session. They might expire while running.")
        return [
            'AWS_ACCESS_KEY_ID=%s' % (str(access_key)),
            'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)),

src/sagemaker/local/local_session.py (+1, -1)

@@ -71,7 +71,7 @@ def create_training_job(self, TrainingJobName, AlgorithmSpecification, InputData
             AlgorithmSpecification['TrainingImage'], self.sagemaker_session)
         training_job = _LocalTrainingJob(container)
         hyperparameters = kwargs['HyperParameters'] if 'HyperParameters' in kwargs else {}
-        training_job.start(InputDataConfig, hyperparameters, TrainingJobName)
+        training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName)
 
         LocalSagemakerClient._training_jobs[TrainingJobName] = training_job
 