import yaml

import sagemaker
+import sagemaker.local.data
+import sagemaker.local.utils
+import sagemaker.utils

CONTAINER_PREFIX = 'algo'
DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
@@ -78,7 +81,7 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
        self.container_root = None
        self.container = None

-    def train(self, input_data_config, hyperparameters, job_name):
+    def train(self, input_data_config, output_data_config, hyperparameters, job_name):
        """Run a training job locally using docker-compose.
        Args:
            input_data_config (dict): The Input Data Configuration, this contains data such as the
@@ -126,23 +129,17 @@ def train(self, input_data_config, hyperparameters, job_name):
            msg = "Failed to run: %s, %s" % (compose_command, str(e))
            raise RuntimeError(msg)

-        s3_artifacts = self.retrieve_artifacts(compose_data)
+        artifacts = self.retrieve_artifacts(compose_data, output_data_config, job_name)

        # free up the training data directory as it may contain
        # lots of data downloaded from S3. This doesn't delete any local
        # data that was just mounted to the container.
-        _delete_tree(data_dir)
-        _delete_tree(shared_dir)
-        # Also free the container config files.
-        for host in self.hosts:
-            container_config_path = os.path.join(self.container_root, host)
-            _delete_tree(container_config_path)
-
-        self._cleanup()
-        # Print our Job Complete line to have a simmilar experience to training on SageMaker where you
+        dirs_to_delete = [data_dir, shared_dir]
+        self._cleanup(dirs_to_delete)
+        # Print our Job Complete line to have a similar experience to training on SageMaker where you
        # see this line at the end.
        print('===== Job Complete =====')
-        return s3_artifacts
+        return artifacts
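        # Illustrative call, with argument shapes inferred from how they are used in this file
        # (not an exhaustive spec):
        #   input_data_config  = [{'ChannelName': 'training', 'DataUri': 'file:///tmp/data'}]
        #   output_data_config = {'S3OutputPath': ''}  # '' keeps the artifacts on the local filesystem
        #   hyperparameters    = {'epochs': '1'}
        # The return value is whatever retrieve_artifacts() produces: the path or URI of model.tar.gz.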

    def serve(self, model_dir, environment):
        """Host a local endpoint using docker-compose.
@@ -188,7 +185,7 @@ def stop_serving(self):
        # for serving we can delete everything in the container root.
        _delete_tree(self.container_root)

-    def retrieve_artifacts(self, compose_data):
+    def retrieve_artifacts(self, compose_data, output_data_config, job_name):
        """Get the model artifacts from all the container nodes.

        Used after training completes to gather the data from all the individual containers. As the
@@ -201,26 +198,49 @@ def retrieve_artifacts(self, compose_data):
        Returns: Local path to the collected model artifacts.

        """
-        # Grab the model artifacts from all the Nodes.
-        s3_artifacts = os.path.join(self.container_root, 's3_artifacts')
-        os.mkdir(s3_artifacts)
+        # We need a directory to store the artifacts from all the nodes
+        # and another one to contain the compressed final artifacts
+        artifacts = os.path.join(self.container_root, 'artifacts')
+        compressed_artifacts = os.path.join(self.container_root, 'compressed_artifacts')
+        os.mkdir(artifacts)
+
+        model_artifacts = os.path.join(artifacts, 'model')
+        output_artifacts = os.path.join(artifacts, 'output')

-        s3_model_artifacts = os.path.join(s3_artifacts, 'model')
-        s3_output_artifacts = os.path.join(s3_artifacts, 'output')
-        os.mkdir(s3_model_artifacts)
-        os.mkdir(s3_output_artifacts)
+        artifact_dirs = [model_artifacts, output_artifacts, compressed_artifacts]
+        for d in artifact_dirs:
+            os.mkdir(d)

+        # Gather the artifacts from all nodes into artifacts/model and artifacts/output
        for host in self.hosts:
            volumes = compose_data['services'][str(host)]['volumes']
-
            for volume in volumes:
                host_dir, container_dir = volume.split(':')
                if container_dir == '/opt/ml/model':
-                    sagemaker.local.utils.recursive_copy(host_dir, s3_model_artifacts)
+                    sagemaker.local.utils.recursive_copy(host_dir, model_artifacts)
                elif container_dir == '/opt/ml/output':
-                    sagemaker.local.utils.recursive_copy(host_dir, s3_output_artifacts)
+                    sagemaker.local.utils.recursive_copy(host_dir, output_artifacts)
+
+        # Tar Artifacts -> model.tar.gz and output.tar.gz
+        model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)]
+        output_files = [os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)]
+        sagemaker.utils.create_tar_file(model_files, os.path.join(compressed_artifacts, 'model.tar.gz'))
+        sagemaker.utils.create_tar_file(output_files, os.path.join(compressed_artifacts, 'output.tar.gz'))
+
+        if output_data_config['S3OutputPath'] == '':
+            output_data = 'file://%s' % compressed_artifacts
+        else:
+            # Now we just need to move the compressed artifacts to wherever they are required
+            output_data = sagemaker.local.utils.move_to_destination(
+                compressed_artifacts,
+                output_data_config['S3OutputPath'],
+                job_name,
+                self.sagemaker_session)
+
+        _delete_tree(model_artifacts)
+        _delete_tree(output_artifacts)

-        return s3_model_artifacts
+        return os.path.join(output_data, 'model.tar.gz')
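        # Illustrative return values (assumed for exposition, not produced verbatim by the code):
        #   S3OutputPath == ''              -> 'file://<container_root>/compressed_artifacts/model.tar.gz'
        #   S3OutputPath == 's3://bucket/p' -> an S3 URI for this job under that prefix,
        #                                      as laid out by sagemaker.local.utils.move_to_destination()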

    def write_config_files(self, host, hyperparameters, input_data_config):
        """Write the config files for the training containers.
@@ -235,7 +255,6 @@ def write_config_files(self, host, hyperparameters, input_data_config):
        Returns: None

        """
-
        config_path = os.path.join(self.container_root, host, 'input', 'config')

        resource_config = {
@@ -261,29 +280,13 @@ def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters
        # mount the local directory to the container. For S3 Data we will download the S3 data
        # first.
        for channel in input_data_config:
-            if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
-                uri = channel['DataSource']['S3DataSource']['S3Uri']
-            elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
-                uri = channel['DataSource']['FileDataSource']['FileUri']
-            else:
-                raise ValueError('Need channel[\'DataSource\'] to have'
-                                 ' [\'S3DataSource\'] or [\'FileDataSource\']')
-
-            parsed_uri = urlparse(uri)
-            key = parsed_uri.path.lstrip('/')
-
+            uri = channel['DataUri']
            channel_name = channel['ChannelName']
            channel_dir = os.path.join(data_dir, channel_name)
            os.mkdir(channel_dir)

-            if parsed_uri.scheme == 's3':
-                bucket_name = parsed_uri.netloc
-                sagemaker.utils.download_folder(bucket_name, key, channel_dir, self.sagemaker_session)
-            elif parsed_uri.scheme == 'file':
-                path = parsed_uri.path
-                volumes.append(_Volume(path, channel=channel_name))
-            else:
-                raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
+            data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
+            volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))
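            # get_data_source_instance() accepts both 'file://' and 's3://' DataUris (illustrative
            # examples: 'file:///home/user/train', 's3://my-bucket/train'). For S3 data the files are
            # downloaded first, so get_root_dir() always yields a local path that can be mounted.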

        # If there is a training script directory and it is a local directory,
        # mount it to the container.
@@ -301,25 +304,20 @@ def _prepare_serving_volumes(self, model_location):
        volumes = []
        host = self.hosts[0]
        # Make the model available to the container. If this is a local file just mount it to
-        # the container as a volume. If it is an S3 location download it and extract the tar file.
+        # the container as a volume. If it is an S3 location, the DataSource will download it, we
+        # just need to extract the tar file.
        host_dir = os.path.join(self.container_root, host)
        os.makedirs(host_dir)

-        if model_location.startswith('s3'):
-            container_model_dir = os.path.join(self.container_root, host, 'model')
-            os.makedirs(container_model_dir)
+        model_data_source = sagemaker.local.data.get_data_source_instance(
+            model_location, self.sagemaker_session)

-            parsed_uri = urlparse(model_location)
-            filename = os.path.basename(parsed_uri.path)
-            tar_location = os.path.join(container_model_dir, filename)
-            sagemaker.utils.download_file(parsed_uri.netloc, parsed_uri.path, tar_location, self.sagemaker_session)
+        for filename in model_data_source.get_file_list():
+            if tarfile.is_tarfile(filename):
+                with tarfile.open(filename) as tar:
+                    tar.extractall(path=model_data_source.get_root_dir())

-            if tarfile.is_tarfile(tar_location):
-                with tarfile.open(tar_location) as tar:
-                    tar.extractall(path=container_model_dir)
-            volumes.append(_Volume(container_model_dir, '/opt/ml/model'))
-        else:
-            volumes.append(_Volume(model_location, '/opt/ml/model'))
+        volumes.append(_Volume(model_data_source.get_root_dir(), '/opt/ml/model'))
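        # model_location may be a local path or an S3 URI (illustrative examples:
        # 'file:///tmp/model.tar.gz', 's3://my-bucket/model.tar.gz'). Either way the DataSource
        # exposes the files locally, any tarball is extracted in place, and the resulting
        # directory is mounted at /opt/ml/model.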

        return volumes
@@ -368,7 +366,6 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
            'networks': {
                'sagemaker-local': {'name': 'sagemaker-local'}
            }
-
        }

        docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)
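        # Rough shape of the generated file (illustrative only; keys taken from this method and
        # from how retrieve_artifacts() reads it back, values abbreviated):
        #   services:
        #     algo-1:
        #       volumes: [...]
        #   networks:
        #     sagemaker-local: {name: sagemaker-local}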
@@ -469,9 +466,15 @@ def _build_optml_volumes(self, host, subdirs):

        return volumes

-    def _cleanup(self):
-        # we don't need to cleanup anything at the moment
-        pass
+    def _cleanup(self, dirs_to_delete=None):
+        if dirs_to_delete:
+            for d in dirs_to_delete:
+                _delete_tree(d)
+
+        # Free the container config files.
+        for host in self.hosts:
+            container_config_path = os.path.join(self.container_root, host)
+            _delete_tree(container_config_path)
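        # Illustrative usage, mirroring the call in train() above:
        #   self._cleanup([data_dir, shared_dir])  # delete the given dirs plus per-host config dirs
        #   self._cleanup()                        # only delete the per-host container config dirs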


class _HostingContainer(Thread):
@@ -610,7 +613,7 @@ def _aws_credentials(session):
                'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key))
            ]
        elif not _aws_credentials_available_in_metadata_service():
-            logger.warn("Using the short-lived AWS credentials found in session. They might expire while running.")
+            logger.warning("Using the short-lived AWS credentials found in session. They might expire while running.")
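            # Note: logger.warn() is a deprecated alias of logger.warning() in the standard
            # library logging module, hence the rename here.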
            return [
                'AWS_ACCESS_KEY_ID=%s' % (str(access_key)),
                'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)),