
Commit 2ae0aab

Author: Ignacio Quintero (committed)
Added remaining data source and recordio split
1 parent 3a61975 commit 2ae0aab

10 files changed: +288 -102 lines changed

src/sagemaker/amazon/common.py

Lines changed: 2 additions & 2 deletions
@@ -153,7 +153,7 @@ def write_spmatrix_to_sparse_tensor(file, array, labels=None):
 def read_records(file):
     """Eagerly read a collection of amazon Record protobuf objects from file."""
     records = []
-    for record_data in _read_recordio(file):
+    for record_data in read_recordio(file):
         record = Record()
         record.ParseFromString(record_data)
         records.append(record)
@@ -183,7 +183,7 @@ def _write_recordio(f, data):
         f.write(padding[pad])
 
 
-def _read_recordio(f):
+def read_recordio(f):
     while(True):
         try:
             read_kmagic, = struct.unpack('I', f.read(4))
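
For context, a minimal usage sketch of the now-public read_recordio generator and the read_records helper shown above; the file name is a made-up illustration and the file is assumed to contain RecordIO-framed protobuf records:

# Sketch only: iterate over the raw payloads in a RecordIO file.
from sagemaker.amazon.common import read_records, read_recordio

with open('train.pbr', 'rb') as f:        # hypothetical RecordIO file
    payloads = list(read_recordio(f))     # raw protobuf bytes, one entry per record

with open('train.pbr', 'rb') as f:
    records = read_records(f)             # same data, parsed into Record protobuf objects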

src/sagemaker/local/data.py

Lines changed: 108 additions & 16 deletions
@@ -14,36 +14,50 @@
 
 import os
 import sys
+import tempfile
 from six.moves.urllib.parse import urlparse
 
+from sagemaker.amazon.common import read_recordio
+from sagemaker.local.utils import download_folder
+from sagemaker.utils import get_config_value
+
 
 class DataSourceFactory(object):
 
     @staticmethod
-    def get_instance(data_source):
+    def get_instance(data_source, sagemaker_session):
         parsed_uri = urlparse(data_source)
         if parsed_uri.scheme == 'file':
             return LocalFileDataSource(parsed_uri.path)
-        else:
-            # TODO Figure S3 and S3Manifest.
-            return None
+        elif parsed_uri.scheme == 's3':
+            return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
+
 
 class DataSource(object):
 
     def get_file_list(self):
         pass
 
+    def get_root_dir(self):
+        pass
+
 
 class LocalFileDataSource(DataSource):
+    """
+    Represents a data source within the local filesystem.
+    """
 
     def __init__(self, root_path):
-        self.root_path = root_path
-
-    def get_file_list(self):
+        self.root_path = os.path.abspath(root_path)
         if not os.path.exists(self.root_path):
             raise RuntimeError('Invalid data source: %s Does not exist.' % self.root_path)
 
-        files = []
+    def get_file_list(self):
+        """Retrieve the list of absolute paths to all the files in this data source.
+
+        Returns:
+            List[string] List of absolute paths.
+        """
         if os.path.isdir(self.root_path):
             files = [os.path.join(self.root_path, f) for f in os.listdir(self.root_path)
                      if os.path.isfile(os.path.join(self.root_path, f))]
@@ -52,12 +66,47 @@ def get_file_list(self):
 
         return files
 
+    def get_root_dir(self):
+        """Retrieve the absolute path to the root directory of this data source.
+
+        Returns:
+            string: absolute path to the root directory of this data source.
+        """
+        if os.path.isdir(self.root_path):
+            return self.root_path
+        else:
+            return os.path.dirname(self.root_path)
+
+
 class S3DataSource(DataSource):
-    pass
+    """Defines a data source given by a bucket and s3 prefix. The contents will be downloaded
+    and then processed as local data.
+    """
+
+    def __init__(self, bucket, prefix, sagemaker_session):
+        """Create an S3DataSource instance
+
+        Args:
+            bucket (str): s3 bucket name
+            prefix (str): s3 prefix path to the data
+            sagemaker_session (sagemaker.Session): a sagemaker_session with the desired settings to talk to s3
+
+        """
+
+        # Create a temporary dir to store the S3 contents
+        root_dir = get_config_value('local.container_root', sagemaker_session.config)
+        if root_dir:
+            root_dir = os.path.abspath(root_dir)
 
+        working_dir = tempfile.mkdtemp(dir=root_dir)
+        download_folder(bucket, prefix, working_dir, sagemaker_session)
+        self.files = LocalFileDataSource(working_dir)
 
-class S3ManifestDataSource(DataSource):
-    pass
+    def get_file_list(self):
+        return self.files.get_file_list()
+
+    def get_root_dir(self):
+        return self.files.get_root_dir()
 
 
 class SplitterFactory(object):
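
As a rough illustration of the data source API above, a sketch of resolving a URI through DataSourceFactory; the URIs and paths are made up, and sagemaker.Session() is assumed to be configured with valid AWS credentials (it is only exercised for the 's3' scheme in this diff):

# Sketch only: resolve a data source URI and enumerate its files.
import sagemaker
from sagemaker.local.data import DataSourceFactory

session = sagemaker.Session()   # assumed: valid AWS configuration
# 'file' URIs map to LocalFileDataSource; 's3' URIs map to S3DataSource,
# which downloads the prefix into a temporary local directory first.
data_source = DataSourceFactory.get_instance('file:///tmp/my_dataset', session)
print(data_source.get_root_dir())
for path in data_source.get_file_list():
    print(path)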
@@ -79,23 +128,37 @@ class Splitter(object):
     def split(self, file):
         pass
 
+
 class NoneSplitter(Splitter):
+    """Does not split records, essentially reads the whole file.
+    """
 
     def split(self, file):
         with open(file, 'r') as f:
             yield f.read()
 
+
 class LineSplitter(Splitter):
+    """Split records by new line.
+
+    """
 
     def split(self, file):
         with open(file, 'r') as f:
             for line in f:
                 yield line
 
+
 class RecordIOSplitter(Splitter):
+    """Split using Amazon Recordio.
+
+    Not useful for string content.
 
+    """
     def split(self, file):
-        pass
+        with open(file, 'rb') as f:
+            for record in read_recordio(f):
+                yield record
 
 
 class BatchStrategyFactory(object):
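
To make the splitter behavior concrete, a small sketch; the input files are hypothetical, and NoneSplitter/LineSplitter read text while RecordIOSplitter expects RecordIO-framed binary data:

# Sketch only: each splitter yields the records contained in a single file.
from sagemaker.local.data import LineSplitter, NoneSplitter, RecordIOSplitter

lines = list(LineSplitter().split('/tmp/data.csv'))         # one record per line (hypothetical file)
whole = list(NoneSplitter().split('/tmp/data.csv'))         # a single record holding the whole file
payloads = list(RecordIOSplitter().split('/tmp/data.pbr'))  # raw bytes per RecordIO record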
@@ -109,13 +172,19 @@ def get_instance(strategy, splitter):
         else:
             return None
 
+
 class BatchStrategy(object):
 
     def pad(self, file, size):
         pass
 
+
 class MultiRecordStrategy(BatchStrategy):
+    """Feed multiple records at a time for batch inference.
 
+    Will group up as many records as possible within the payload specified.
+
+    """
     def __init__(self, splitter):
         self.splitter = splitter
 
@@ -133,7 +202,10 @@ def pad(self, file, size=6):
 
 
 class SingleRecordStrategy(BatchStrategy):
+    """Feed a single record at a time for batch inference.
 
+    If a single record does not fit within the payload specified it will throw a Runtime error.
+    """
     def __init__(self, splitter):
         self.splitter = splitter
 
@@ -144,14 +216,34 @@ def pad(self, file, size=6):
 
 
 def _payload_size_within_limit(payload, size):
+    """
+
+    Args:
+        payload:
+        size:
+
+    Returns:
+
+    """
     size_in_bytes = size * 1024 * 1024
     if size == 0:
         return True
     else:
-        print('size_of_payload: %s > %s' % (sys.getsizeof(payload), size_in_bytes))
         return sys.getsizeof(payload) < size_in_bytes
 
+
 def _validate_payload_size(payload, size):
-    if not _payload_size_within_limit(payload, size):
-        raise RuntimeError('Record is larger than %sMB. Please increase your max_payload' % size)
-    return True
+    """Check if a payload is within the size in MB threshold. Raise an exception otherwise.
+
+    Args:
+        payload: data that will be checked
+        size (int): max size in MB
+
+    Returns (bool): True if within bounds. if size=0 it will always return True
+    Raises:
+        RuntimeError: If the payload is larger a runtime error is thrown.
+    """
+
+    if not _payload_size_within_limit(payload, size):
+        raise RuntimeError('Record is larger than %sMB. Please increase your max_payload' % size)
+    return True
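
The batch strategies compose with a splitter: pad(file, size) yields request payloads, grouping as many split records as fit under size MB for MultiRecordStrategy, or exactly one record per payload for SingleRecordStrategy, which raises RuntimeError when a single record is over the limit (size=0 disables the check). A hedged sketch, with a made-up input file:

# Sketch only: batch a file's records into payloads of at most 6 MB.
from sagemaker.local.data import LineSplitter, MultiRecordStrategy, SingleRecordStrategy

multi = MultiRecordStrategy(LineSplitter())
for payload in multi.pad('/tmp/data.csv', size=6):    # size is in MB
    print(len(payload))

single = SingleRecordStrategy(LineSplitter())
payloads = list(single.pad('/tmp/data.csv', size=6))  # raises RuntimeError if one record > 6 MB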

src/sagemaker/local/entities.py

Lines changed: 18 additions & 13 deletions
@@ -18,10 +18,10 @@
 import tempfile
 import time
 import urllib3
-from six.moves.urllib.parse import urlparse
 
 from sagemaker.local.data import BatchStrategyFactory, DataSourceFactory, SplitterFactory
 from sagemaker.local.image import _SageMakerContainer
+from sagemaker.local.utils import copy_directory_structure, move_to_destination
 from sagemaker.utils import get_config_value
 
 logger = logging.getLogger(__name__)
@@ -117,12 +117,13 @@ def start(self, input_data, output_data, **kwargs):
     def _batch_inference(self, input_data, output_data, **kwargs):
         # TODO - Figure if we should pass FileDataSource here instead. Ideally not but the semantics
         # are just weird.
+        print(output_data)
         input_path = input_data['DataSource']['S3DataSource']['S3Uri']
 
         # Transform the input data to feed the serving container. We need to first gather the files
         # from S3 or Local FileSystem. Split them as required (Line, RecordIO, None) and finally batch them
         # according to the batch strategy and limit the request size.
-        data_source = DataSourceFactory.get_instance(input_path)
+        data_source = DataSourceFactory.get_instance(input_path, self.local_session)
         split_type = input_data['SplitType'] if 'SplitType' in input_data else None
         splitter = SplitterFactory.get_instance(split_type)
 
@@ -131,32 +132,33 @@ def _batch_inference(self, input_data, output_data, **kwargs):
         if 'BatchStrategy' in kwargs:
             batch_strategy = kwargs['BatchStrategy']
 
-
         max_payload = 6
         if 'MaxPayloadInMB' in kwargs:
             max_payload = int(kwargs['MaxPayloadInMB'])
 
         final_data = BatchStrategyFactory.get_instance(batch_strategy, splitter)
 
-
         # Output settings
         accept = output_data['Accept'] if 'Accept' in output_data else None
         # TODO - add a warning that we don't support KMS in Local Mode.
 
-
         # Root dir to use for intermediate data location. To make things simple we will write here regardless
         # of the final destination. At the end the files will either be moved or uploaded to S3 and deleted.
         root_dir = get_config_value('local.container_root', self.local_session.config)
         if root_dir:
             root_dir = os.path.abspath(root_dir)
 
+        working_dir = tempfile.mkdtemp(dir=root_dir)
+        dataset_dir = data_source.get_root_dir()
 
-        out_fd, out_path = tempfile.mkstemp(dir=root_dir)
+        for file in data_source.get_file_list():
 
-        output_files = {}
-        with os.fdopen(out_fd, 'w') as f:
-            for file in data_source.get_file_list():
+            relative_path = os.path.dirname(os.path.relpath(file, dataset_dir))
+            filename = os.path.basename(file)
+            copy_directory_structure(working_dir, relative_path)
+            destination_path = os.path.join(working_dir, relative_path, filename + '.out')
 
+            with open(destination_path, 'w') as f:
                 for item in final_data.pad(file, max_payload):
                     # call the container and add the result to inference.
                     response = self.local_session.sagemaker_runtime_client.invoke_endpoint(
@@ -166,12 +168,16 @@ def _batch_inference(self, input_data, output_data, **kwargs):
                     data = response_body.read()
                     response_body.close()
                     print('data: %s' % data)
-                    # TODO - AssembleWith determines if we add a new line or not.
                     f.write(data)
+                    if 'AssembleWith' in output_data and output_data['AssembleWith'] == 'Line':
+                        f.write('\n')
 
-        print(out_path)
+        print(working_dir)
+        move_to_destination(working_dir, output_data['S3OutputPath'], self.local_session)
+        print(output_data['S3OutputPath'])
         self.container.stop_serving()
 
+
 class _LocalModel(object):
 
     def __init__(self, model_name, primary_container):
@@ -244,7 +250,6 @@ def serve(self):
         # the container is running and it passed the healthcheck status is now InService
        self.state = _LocalEndpoint._IN_SERVICE
 
-
    def stop(self):
        if self.container:
            self.container.stop_serving()
@@ -269,7 +274,7 @@ def _wait_for_serving_container(serving_port):
    while True:
        i += 1
        if i >= HEALTH_CHECK_TIMEOUT_LIMIT:
-            raise RuntimeError('Giving up, endpoint: %s didn\'t launch correctly' % self.name)
+            raise RuntimeError('Giving up, endpoint didn\'t launch correctly')
 
        logger.info('Checking if serving container is up, attempt: %s' % i)
        try:
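
To make the new output layout in _batch_inference concrete: each input file's predictions are written under working_dir at the file's relative path with an '.out' suffix, and the whole directory is then moved to the configured output location (with a newline appended per prediction when AssembleWith is 'Line'). A small sketch of just the path arithmetic, with made-up directories:

# Sketch only: how the destination path for one input file is derived above.
import os

dataset_dir = '/tmp/dataset'                 # stands in for data_source.get_root_dir()
working_dir = '/tmp/working'                 # stands in for tempfile.mkdtemp(dir=root_dir)
file = '/tmp/dataset/part1/data.csv'         # hypothetical input file

relative_path = os.path.dirname(os.path.relpath(file, dataset_dir))   # 'part1'
filename = os.path.basename(file)                                     # 'data.csv'
destination_path = os.path.join(working_dir, relative_path, filename + '.out')
print(destination_path)                      # /tmp/working/part1/data.csv.out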
