 import os
 import sys
+import tempfile

 from six.moves.urllib.parse import urlparse

+from sagemaker.amazon.common import read_recordio
+from sagemaker.local.utils import download_folder
+from sagemaker.utils import get_config_value
+

 class DataSourceFactory(object):

     @staticmethod
-    def get_instance(data_source):
+    def get_instance(data_source, sagemaker_session):
         parsed_uri = urlparse(data_source)
         if parsed_uri.scheme == 'file':
             return LocalFileDataSource(parsed_uri.path)
-        else:
-            # TODO Figure S3 and S3Manifest.
-            return None
+        elif parsed_uri.scheme == 's3':
+            return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
+
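The factory dispatches on the URI scheme, so callers can hand it either a local path or an S3 location and get back the same interface. A minimal usage sketch, not part of the change; the paths are illustrative and `session` stands in for a `sagemaker.Session` (only exercised for `s3://` URIs):

```python
# 'file://' URIs never touch the session, so None is fine here;
# the local path must exist or LocalFileDataSource raises.
local_source = DataSourceFactory.get_instance('file:///tmp/input', sagemaker_session=None)
for path in local_source.get_file_list():
    print(path)

# 's3://' URIs need a real session to download the data.
s3_source = DataSourceFactory.get_instance('s3://some-bucket/some/prefix', session)
```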


 class DataSource(object):

     def get_file_list(self):
         pass

+    def get_root_dir(self):
+        pass
+

 class LocalFileDataSource(DataSource):
+    """
+    Represents a data source within the local filesystem.
+    """

     def __init__(self, root_path):
-        self.root_path = root_path
-
-    def get_file_list(self):
+        self.root_path = os.path.abspath(root_path)
         if not os.path.exists(self.root_path):
             raise RuntimeError('Invalid data source: %s Does not exist.' % self.root_path)

-        files = []
+    def get_file_list(self):
+        """Retrieve the list of absolute paths to all the files in this data source.
+
+        Returns:
+            List[str]: List of absolute paths.
+        """
         if os.path.isdir(self.root_path):
             files = [os.path.join(self.root_path, f) for f in os.listdir(self.root_path)
                      if os.path.isfile(os.path.join(self.root_path, f))]
@@ -52,12 +66,47 @@ def get_file_list(self):

         return files

+    def get_root_dir(self):
+        """Retrieve the absolute path to the root directory of this data source.
+
+        Returns:
+            str: Absolute path to the root directory of this data source.
+        """
+        if os.path.isdir(self.root_path):
+            return self.root_path
+        else:
+            return os.path.dirname(self.root_path)
+
+
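Note that `get_root_dir` collapses a single file to its parent directory, presumably so the whole directory can be mapped into the container. An illustrative check (paths are made up, and they must exist since `__init__` validates them):

```python
# A source rooted at one file reports that file's parent directory...
file_source = LocalFileDataSource('/tmp/data/train.csv')
assert file_source.get_root_dir() == '/tmp/data'

# ...while a source rooted at a directory reports the directory itself.
dir_source = LocalFileDataSource('/tmp/data')
assert dir_source.get_root_dir() == '/tmp/data'
```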
 class S3DataSource(DataSource):
-    pass
+    """Defines a data source given by a bucket and S3 prefix. The contents will be downloaded
+    and then processed as local data.
+    """
+
+    def __init__(self, bucket, prefix, sagemaker_session):
+        """Create an S3DataSource instance.
+
+        Args:
+            bucket (str): S3 bucket name
+            prefix (str): S3 prefix path to the data
+            sagemaker_session (sagemaker.Session): a sagemaker_session with the desired settings
+                to talk to S3
+        """
+
+        # Create a temporary dir to store the S3 contents
+        root_dir = get_config_value('local.container_root', sagemaker_session.config)
+        if root_dir:
+            root_dir = os.path.abspath(root_dir)

+        working_dir = tempfile.mkdtemp(dir=root_dir)
+        download_folder(bucket, prefix, working_dir, sagemaker_session)
+        self.files = LocalFileDataSource(working_dir)

-class S3ManifestDataSource(DataSource):
-    pass
+    def get_file_list(self):
+        return self.files.get_file_list()
+
+    def get_root_dir(self):
+        return self.files.get_root_dir()
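Because `__init__` eagerly downloads the prefix into a temp directory and wraps it in a `LocalFileDataSource`, both methods simply delegate, and downstream code can treat S3 and local inputs identically. A hedged sketch; the bucket, prefix, and `session` are placeholders:

```python
# Hypothetical: after construction the data is already on local disk,
# in a tempfile.mkdtemp() dir (under local.container_root, if configured).
source = S3DataSource('some-bucket', '/training/data', session)
print(source.get_root_dir())   # e.g. a /tmp/tmpXXXXXXXX directory
print(source.get_file_list())  # absolute paths to the downloaded files
```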


 class SplitterFactory(object):
@@ -79,23 +128,37 @@ class Splitter(object):
     def split(self, file):
         pass

+
 class NoneSplitter(Splitter):
+    """Does not split records; essentially reads the whole file.
+    """

     def split(self, file):
         with open(file, 'r') as f:
             yield f.read()

+
 class LineSplitter(Splitter):
+    """Split records by newline.
+    """

     def split(self, file):
         with open(file, 'r') as f:
             for line in f:
                 yield line

+
 class RecordIOSplitter(Splitter):
+    """Split using Amazon RecordIO.
+
+    Not useful for string content.
+    """
     def split(self, file):
-        pass
+        with open(file, 'rb') as f:
+            for record in read_recordio(f):
+                yield record
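Every splitter is a generator, so records stream out lazily and large inputs never have to fit in memory at once. A usage sketch; the input path and the `handle()` callback are placeholders:

```python
# Iterate over newline-delimited records in a local file.
splitter = LineSplitter()
for record in splitter.split('/tmp/input/train.csv'):
    handle(record)  # placeholder for batching / inference

# RecordIOSplitter is used the same way on binary files, except it
# opens in 'rb' mode and yields raw bytes per record via read_recordio().
```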


 class BatchStrategyFactory(object):
@@ -109,13 +172,19 @@ def get_instance(strategy, splitter):
         else:
             return None

+
 class BatchStrategy(object):

     def pad(self, file, size):
         pass

+
 class MultiRecordStrategy(BatchStrategy):
+    """Feed multiple records at a time for batch inference.
+
+    Will group up as many records as possible within the payload specified.
+    """
     def __init__(self, splitter):
         self.splitter = splitter

@@ -133,7 +202,10 @@ def pad(self, file, size=6):


 class SingleRecordStrategy(BatchStrategy):
+    """Feed a single record at a time for batch inference.
+
+    If a single record does not fit within the payload specified, a RuntimeError is raised.
+    """
     def __init__(self, splitter):
         self.splitter = splitter

@@ -144,14 +216,34 @@ def pad(self, file, size=6):


 def _payload_size_within_limit(payload, size):
+    """Check whether a payload fits within a size (in MB) threshold.
+
+    Args:
+        payload: data to be checked
+        size (int): max size in MB; 0 disables the check
+
+    Returns:
+        bool: True if the payload is within the limit (always True when size == 0).
+    """
     size_in_bytes = size * 1024 * 1024
     if size == 0:
         return True
     else:
-        print('size_of_payload: %s > %s' % (sys.getsizeof(payload), size_in_bytes))
         return sys.getsizeof(payload) < size_in_bytes

+
 def _validate_payload_size(payload, size):
-    if not _payload_size_within_limit(payload, size):
-        raise RuntimeError('Record is larger than %sMB. Please increase your max_payload' % size)
-    return True
+    """Check if a payload is within the size (in MB) threshold. Raise an exception otherwise.
+
+    Args:
+        payload: data that will be checked
+        size (int): max size in MB
+
+    Returns:
+        bool: True if within bounds. If size=0 it will always return True.
+
+    Raises:
+        RuntimeError: If the payload is larger than the threshold.
+    """
+
+    if not _payload_size_within_limit(payload, size):
+        raise RuntimeError('Record is larger than %sMB. Please increase your max_payload' % size)
+    return True
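Here `size` is interpreted in MB (`size * 1024 * 1024` bytes), and the comparison uses `sys.getsizeof`, which measures the in-memory Python object including interpreter overhead rather than the raw payload length. A small worked example; the values are approximate:

```python
import sys

payload = 'x' * (2 * 1024 * 1024)         # a string holding ~2MB of content
print(sys.getsizeof(payload))              # slightly over 2097152 due to object overhead

_validate_payload_size(payload, size=6)    # returns True: ~2MB < 6MB
_validate_payload_size(payload, size=0)    # returns True: size 0 disables the check
# _validate_payload_size(payload, size=1)  # raises RuntimeError: ~2MB > 1MB limit
```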