Skip to content

Commit c1f1ab9

Browse files
Valorumiquintero
authored andcommitted
Add support for file:// URI as the input for LocalMode training data (#168)
* Local mode support for file:// URI as the input for training data, bypassing uploading to/downloading from S3.
1 parent d2f018e commit c1f1ab9

File tree

8 files changed

+193
-23
lines changed

8 files changed

+193
-23
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
CHANGELOG
33
=========
44

5+
56
1.2.dev5
67
========
78

89
* bug-fix: Change module names to string type in __all__
10+
* feature: Local Mode: add support for local training data using file://
911

1012
1.2.4
1113
=====

README.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ Local Mode
148148
~~~~~~~~~~
149149

150150
The SageMaker Python SDK now supports local mode, which allows you to create TensorFlow, MXNet and BYO estimators and
151-
deploy to your local environment.  This is a great way to test your deep learning script before running in
151+
deploy to your local environment. This is a great way to test your deep learning script before running in
152152
SageMaker's managed training or hosting environments.
153153

154154
We can take the example in `Estimator Usage <#estimator-usage>`__ , and use either ``local`` or ``local_gpu`` as the
@@ -166,6 +166,9 @@ instance type.
166166
# In Local Mode, fit will pull the MXNet container docker image and run it locally
167167
mxnet_estimator.fit('s3://my_bucket/my_training_data/')
168168
169+
# Alternatively, you can train using data in your local file system. This is only supported in Local mode.
170+
mxnet_estimator.fit('file:///tmp/my_training_data')
171+
169172
# Deploys the model that was generated by fit() to local endpoint in a container
170173
mxnet_predictor = mxnet_estimator.deploy(initial_instance_count=1, instance_type='local')
171174
@@ -184,7 +187,7 @@ For detailed examples of running docker in local mode, see:
184187
A few important notes:
185188

186189
- Only one local mode endpoint can be running at a time
187-
- Since the data are pulled from S3 to your local environment, please ensure you have sufficient space.
190+
- If you are using s3 data as input, it will be pulled from S3 to your local environment, please ensure you have sufficient space.
188191
- If you run into problems, this is often due to different docker containers conflicting.  Killing these containers and re-running often solves your problems.
189192
- Local Mode requires docker-compose and `nvidia-docker2 <https://github.com/NVIDIA/nvidia-docker>`__ for ``local_gpu``.
190193
- Distributed training is not yet supported for ``local_gpu``.

src/sagemaker/estimator.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
from sagemaker.fw_utils import tar_and_upload_dir
2222
from sagemaker.fw_utils import parse_s3_url
2323
from sagemaker.fw_utils import UploadedCode
24-
from sagemaker.local.local_session import LocalSession
24+
from sagemaker.local.local_session import LocalSession, file_input
25+
2526
from sagemaker.model import Model
2627
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
2728
CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME)
29+
2830
from sagemaker.predictor import RealTimePredictor
2931
from sagemaker.session import Session
3032
from sagemaker.session import s3_input
@@ -321,6 +323,13 @@ def start_new(cls, estimator, inputs):
321323
sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
322324
"""
323325

326+
local_mode = estimator.local_mode
327+
328+
# Allow file:// input only in local mode
329+
if isinstance(inputs, str) and inputs.startswith('file://'):
330+
if not local_mode:
331+
raise ValueError('File URIs are supported in local mode only. Please use a S3 URI instead.')
332+
324333
input_config = _TrainingJob._format_inputs_to_input_config(inputs)
325334
role = estimator.sagemaker_session.expand_role(estimator.role)
326335
output_config = _TrainingJob._prepare_output_config(estimator.output_path, estimator.output_kms_key)
@@ -343,12 +352,14 @@ def start_new(cls, estimator, inputs):
343352
def _format_inputs_to_input_config(inputs):
344353
input_dict = {}
345354
if isinstance(inputs, string_types):
346-
input_dict['training'] = _TrainingJob._format_s3_uri_input(inputs)
355+
input_dict['training'] = _TrainingJob._format_string_uri_input(inputs)
347356
elif isinstance(inputs, s3_input):
348357
input_dict['training'] = inputs
358+
elif isinstance(input, file_input):
359+
input_dict['training'] = inputs
349360
elif isinstance(inputs, dict):
350361
for k, v in inputs.items():
351-
input_dict[k] = _TrainingJob._format_s3_uri_input(v)
362+
input_dict[k] = _TrainingJob._format_string_uri_input(v)
352363
else:
353364
raise ValueError('Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))
354365

@@ -360,15 +371,21 @@ def _format_inputs_to_input_config(inputs):
360371
return channels
361372

362373
@staticmethod
363-
def _format_s3_uri_input(input):
374+
def _format_string_uri_input(input):
364375
if isinstance(input, str):
365-
if not input.startswith('s3://'):
366-
raise ValueError('Training input data must be a valid S3 URI and must start with "s3://"')
367-
return s3_input(input)
368-
if isinstance(input, s3_input):
376+
if input.startswith('s3://'):
377+
return s3_input(input)
378+
elif input.startswith('file://'):
379+
return file_input(input)
380+
else:
381+
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
382+
'"file://"')
383+
elif isinstance(input, s3_input):
384+
return input
385+
elif isinstance(input, file_input):
369386
return input
370387
else:
371-
raise ValueError('Cannot format input {}. Expecting one of str or s3_input'.format(input))
388+
raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(input))
372389

373390
@staticmethod
374391
def _prepare_output_config(s3_path, kms_key_id):

src/sagemaker/local/image.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,13 @@ def train(self, input_data_config, hyperparameters):
9393
# mount the local directory to the container. For S3 Data we will download the S3 data
9494
# first.
9595
for channel in input_data_config:
96-
uri = channel['DataSource']['S3DataSource']['S3Uri']
96+
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
97+
uri = channel['DataSource']['S3DataSource']['S3Uri']
98+
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
99+
uri = channel['DataSource']['FileDataSource']['FileUri']
100+
else:
101+
raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
102+
97103
parsed_uri = urlparse(uri)
98104
key = parsed_uri.path.lstrip('/')
99105

@@ -104,8 +110,11 @@ def train(self, input_data_config, hyperparameters):
104110
if parsed_uri.scheme == 's3':
105111
bucket_name = parsed_uri.netloc
106112
self._download_folder(bucket_name, key, channel_dir)
113+
elif parsed_uri.scheme == 'file':
114+
path = parsed_uri.path
115+
volumes.append(_Volume(path, channel=channel_name))
107116
else:
108-
volumes.append(_Volume(uri, channel=channel_name))
117+
raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
109118

110119
# Create the configuration files for each container that we will create
111120
# Each container will map the additional local volumes (if any).

src/sagemaker/local/local_session.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,22 @@ def create_training_job(self, TrainingJobName, AlgorithmSpecification, RoleArn,
5656
AlgorithmSpecification['TrainingImage'], self.sagemaker_session)
5757

5858
for channel in InputDataConfig:
59-
data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
59+
60+
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
61+
data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
62+
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
63+
data_distribution = channel['DataSource']['FileDataSource']['FileDataDistributionType']
64+
else:
65+
raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
66+
6067
if data_distribution != 'FullyReplicated':
6168
raise RuntimeError("DataDistribution: %s is not currently supported in Local Mode" %
6269
data_distribution)
6370

6471
self.s3_model_artifacts = self.train_container.train(InputDataConfig, HyperParameters)
6572

6673
def describe_training_job(self, TrainingJobName):
67-
"""Describe a local traininig job.
74+
"""Describe a local training job.
6875
6976
Args:
7077
TrainingJobName (str): Not used in this implmentation.
@@ -171,3 +178,26 @@ def logs_for_job(self, job_name, wait=False, poll=5):
171178
# override logs_for_job() as it doesn't need to perform any action
172179
# on local mode.
173180
pass
181+
182+
183+
class file_input(object):
184+
"""Amazon SageMaker channel configuration for FILE data sources, used in local mode.
185+
186+
Attributes:
187+
config (dict[str, dict]): A SageMaker ``DataSource`` referencing a SageMaker ``FileDataSource``.
188+
"""
189+
190+
def __init__(self, fileUri, content_type=None):
191+
"""Create a definition for input data used by an SageMaker training job in local mode.
192+
"""
193+
self.config = {
194+
'DataSource': {
195+
'FileDataSource': {
196+
'FileDataDistributionType': 'FullyReplicated',
197+
'FileUri': fileUri
198+
}
199+
}
200+
}
201+
202+
if content_type is not None:
203+
self.config['ContentType'] = content_type

tests/unit/test_estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_sagemaker_s3_uri_invalid(sagemaker_session):
101101
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
102102
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
103103
t.fit('thisdoesntstartwiths3')
104-
assert 'must be a valid S3 URI' in str(error)
104+
assert 'must be a valid S3 or FILE URI' in str(error)
105105

106106

107107
@patch('time.strftime', return_value=TIMESTAMP)
@@ -427,9 +427,8 @@ def test_unsupported_type():
427427

428428

429429
def test_unsupported_type_in_dict():
430-
with pytest.raises(ValueError) as error:
430+
with pytest.raises(ValueError):
431431
_TrainingJob._format_inputs_to_input_config({'a': 66})
432-
assert 'Expecting one of str or s3_input' in str(error)
433432

434433

435434
#################################################################################

tests/unit/test_image.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
{
2828
'ChannelName': 'a',
2929
'DataSource': {
30-
'S3DataSource': {
31-
'S3DataDistributionType': 'FullyReplicated',
32-
'S3DataType': 'S3Prefix',
33-
'S3Uri': '/tmp/source1'
30+
'FileDataSource': {
31+
'FileDataDistributionType': 'FullyReplicated',
32+
'FileUri': 'file:///tmp/source1'
3433
}
3534
}
3635
},

tests/unit/test_local_session.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,26 @@ def test_create_training_job(train, LocalSession):
3535
image = "my-docker-image:1.0"
3636

3737
algo_spec = {'TrainingImage': image}
38-
input_data_config = {}
38+
input_data_config = [
39+
{
40+
'ChannelName': 'a',
41+
'DataSource': {
42+
'S3DataSource': {
43+
'S3DataDistributionType': 'FullyReplicated',
44+
'S3Uri': 's3://my_bucket/tmp/source1'
45+
}
46+
}
47+
},
48+
{
49+
'ChannelName': 'b',
50+
'DataSource': {
51+
'FileDataSource': {
52+
'FileDataDistributionType': 'FullyReplicated',
53+
'FileUri': 'file:///tmp/source1'
54+
}
55+
}
56+
}
57+
]
3958
output_data_config = {}
4059
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
4160
hyperparameters = {'a': 1, 'b': 'bee'}
@@ -61,6 +80,67 @@ def test_create_training_job(train, LocalSession):
6180
assert response['ModelArtifacts']['S3ModelArtifacts'] == expected['ModelArtifacts']['S3ModelArtifacts']
6281

6382

83+
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
84+
@patch('sagemaker.local.local_session.LocalSession')
85+
def test_create_training_job_invalid_data_source(train, LocalSession):
86+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
87+
88+
instance_count = 2
89+
image = "my-docker-image:1.0"
90+
91+
algo_spec = {'TrainingImage': image}
92+
93+
# InvalidDataSource is not supported. S3DataSource and FileDataSource are currently the only
94+
# valid Data Sources. We expect a ValueError if we pass this input data config.
95+
input_data_config = [{
96+
'ChannelName': 'a',
97+
'DataSource': {
98+
'InvalidDataSource': {
99+
'FileDataDistributionType': 'FullyReplicated',
100+
'FileUri': 'ftp://myserver.com/tmp/source1'
101+
}
102+
}
103+
}]
104+
105+
output_data_config = {}
106+
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
107+
hyperparameters = {'a': 1, 'b': 'bee'}
108+
109+
with pytest.raises(ValueError):
110+
local_sagemaker_client.create_training_job("my-training-job", algo_spec, 'arn:my-role', input_data_config,
111+
output_data_config, resource_config, None, hyperparameters)
112+
113+
114+
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
115+
@patch('sagemaker.local.local_session.LocalSession')
116+
def test_create_training_job_not_fully_replicated(train, LocalSession):
117+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
118+
119+
instance_count = 2
120+
image = "my-docker-image:1.0"
121+
122+
algo_spec = {'TrainingImage': image}
123+
124+
# Local Mode only supports FullyReplicated as Data Distribution type.
125+
input_data_config = [{
126+
'ChannelName': 'a',
127+
'DataSource': {
128+
'S3DataSource': {
129+
'S3DataDistributionType': 'ShardedByS3Key',
130+
'S3Uri': 's3://my_bucket/tmp/source1'
131+
}
132+
}
133+
}]
134+
135+
output_data_config = {}
136+
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
137+
hyperparameters = {'a': 1, 'b': 'bee'}
138+
139+
with pytest.raises(RuntimeError):
140+
local_sagemaker_client.create_training_job("my-training-job", algo_spec, 'arn:my-role', input_data_config,
141+
output_data_config, resource_config, None, hyperparameters)
142+
143+
64144
@patch('sagemaker.local.local_session.LocalSession')
65145
def test_create_model(LocalSession):
66146
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
@@ -130,3 +210,34 @@ def test_create_endpoint_fails(serve, request, LocalSession):
130210

131211
with pytest.raises(RuntimeError):
132212
local_sagemaker_client.create_endpoint('my-endpoint', 'some-endpoint-config')
213+
214+
215+
def test_file_input_all_defaults():
216+
prefix = 'pre'
217+
actual = sagemaker.local.local_session.file_input(fileUri=prefix)
218+
expected = \
219+
{
220+
'DataSource': {
221+
'FileDataSource': {
222+
'FileDataDistributionType': 'FullyReplicated',
223+
'FileUri': prefix
224+
}
225+
}
226+
}
227+
assert actual.config == expected
228+
229+
230+
def test_file_input_content_type():
231+
prefix = 'pre'
232+
actual = sagemaker.local.local_session.file_input(fileUri=prefix, content_type='text/csv')
233+
expected = \
234+
{
235+
'DataSource': {
236+
'FileDataSource': {
237+
'FileDataDistributionType': 'FullyReplicated',
238+
'FileUri': prefix
239+
}
240+
},
241+
'ContentType': 'text/csv'
242+
}
243+
assert actual.config == expected

0 commit comments

Comments
 (0)