Skip to content

Commit f441e4d

Browse files
author
Ignacio Quintero
committed
Fix aws#451 while we are touching local mode train()
1 parent 6c30025 commit f441e4d

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

src/sagemaker/local/image.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,15 @@ def write_config_files(self, host, hyperparameters, input_data_config):
262262
'hosts': self.hosts
263263
}
264264

265-
json_input_data_config = {
266-
c['ChannelName']: {'ContentType': 'application/octet-stream'} for c in input_data_config
267-
}
265+
print(input_data_config)
266+
json_input_data_config = {}
267+
for c in input_data_config:
268+
channel_name = c['ChannelName']
269+
json_input_data_config[channel_name] = {
270+
'TrainingInputMode': 'File'
271+
}
272+
if 'ContentType' in c:
273+
json_input_data_config[channel_name]['ContentType'] = c['ContentType']
268274

269275
_write_json_file(os.path.join(config_path, 'hyperparameters.json'), hyperparameters)
270276
_write_json_file(os.path.join(config_path, 'resourceconfig.json'), resource_config)

tests/unit/test_image.py

+54
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,60 @@ def test_write_config_file(LocalSession, tmpdir):
126126
assert channel['ChannelName'] in input_data_config_data
127127

128128

129+
@patch('sagemaker.local.local_session.LocalSession')
def test_write_config_files_input_content_type(LocalSession, tmpdir):
    """Check how ``write_config_files`` handles a channel's ``ContentType``.

    Two channels are written: ``channel_a`` declares ``ContentType`` and
    ``channel_b`` does not. The generated ``inputdataconfig.json`` must list
    both channels, carry ``text/csv`` for ``channel_a`` and have no
    ``ContentType`` key at all for ``channel_b``.
    """
    sagemaker_container = _SageMakerContainer('local', 1, 'my-image')
    sagemaker_container.container_root = str(tmpdir.mkdir('container-root'))
    host = 'algo-1'

    sagemaker.local.image._create_config_file_directories(sagemaker_container.container_root, host)

    container_root = sagemaker_container.container_root
    config_file_root = os.path.join(container_root, host, 'input', 'config')

    input_data_config_file = os.path.join(config_file_root, 'inputdataconfig.json')

    # Write the config files, and then let's check they exist and have the right content.
    input_data_config = [
        {
            'ChannelName': 'channel_a',
            'DataUri': 'file:///tmp/source1',
            'ContentType': 'text/csv',
            'DataSource': {
                'FileDataSource': {
                    'FileDataDistributionType': 'FullyReplicated',
                    'FileUri': 'file:///tmp/source1'
                }
            }
        },
        {
            'ChannelName': 'channel_b',
            'DataUri': 's3://my-own-bucket/prefix',
            'DataSource': {
                'S3DataSource': {
                    'S3DataDistributionType': 'FullyReplicated',
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://my-own-bucket/prefix'
                }
            }
        }
    ]
    sagemaker_container.write_config_files(host, HYPERPARAMETERS, input_data_config)

    assert os.path.exists(input_data_config_file)
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(input_data_config_file) as config_file:
        parsed_input_config = json.load(config_file)

    # Validate Input Data Config
    for channel in input_data_config:
        assert channel['ChannelName'] in parsed_input_config

    # Channel A has a content type
    assert 'ContentType' in parsed_input_config['channel_a']
    assert parsed_input_config['channel_a']['ContentType'] == 'text/csv'

    # Channel B does not have content type
    assert 'ContentType' not in parsed_input_config['channel_b']
182+
129183
@patch('sagemaker.local.local_session.LocalSession')
130184
def test_retrieve_artifacts(LocalSession, tmpdir):
131185
sagemaker_container = _SageMakerContainer('local', 2, 'my-image')

0 commit comments

Comments
 (0)