25
25
import subprocess
26
26
import sys
27
27
import tempfile
28
- from subprocess import Popen
28
+ from fcntl import fcntl , F_GETFL , F_SETFL
29
29
from six .moves .urllib .parse import urlparse
30
- from time import sleep
30
+ from threading import Thread
31
31
32
32
import yaml
33
33
@@ -91,42 +91,7 @@ def train(self, input_data_config, hyperparameters):
91
91
os .mkdir (shared_dir )
92
92
93
93
data_dir = self ._create_tmp_folder ()
94
- volumes = []
95
-
96
- # Set up the channels for the containers. For local data we will
97
- # mount the local directory to the container. For S3 Data we will download the S3 data
98
- # first.
99
- for channel in input_data_config :
100
- if channel ['DataSource' ] and 'S3DataSource' in channel ['DataSource' ]:
101
- uri = channel ['DataSource' ]['S3DataSource' ]['S3Uri' ]
102
- elif channel ['DataSource' ] and 'FileDataSource' in channel ['DataSource' ]:
103
- uri = channel ['DataSource' ]['FileDataSource' ]['FileUri' ]
104
- else :
105
- raise ValueError ('Need channel[\' DataSource\' ] to have [\' S3DataSource\' ] or [\' FileDataSource\' ]' )
106
-
107
- parsed_uri = urlparse (uri )
108
- key = parsed_uri .path .lstrip ('/' )
109
-
110
- channel_name = channel ['ChannelName' ]
111
- channel_dir = os .path .join (data_dir , channel_name )
112
- os .mkdir (channel_dir )
113
-
114
- if parsed_uri .scheme == 's3' :
115
- bucket_name = parsed_uri .netloc
116
- self ._download_folder (bucket_name , key , channel_dir )
117
- elif parsed_uri .scheme == 'file' :
118
- path = parsed_uri .path
119
- volumes .append (_Volume (path , channel = channel_name ))
120
- else :
121
- raise ValueError ('Unknown URI scheme {}' .format (parsed_uri .scheme ))
122
-
123
- # If the training script directory is a local directory, mount it to the container.
124
- training_dir = json .loads (hyperparameters [sagemaker .estimator .DIR_PARAM_NAME ])
125
- parsed_uri = urlparse (training_dir )
126
- if parsed_uri .scheme == 'file' :
127
- volumes .append (_Volume (parsed_uri .path , '/opt/ml/code' ))
128
- # Also mount a directory that all the containers can access.
129
- volumes .append (_Volume (shared_dir , '/opt/ml/shared' ))
94
+ volumes = self ._prepare_training_volumes (data_dir , input_data_config , hyperparameters )
130
95
131
96
# Create the configuration files for each container that we will create
132
97
# Each container will map the additional local volumes (if any).
@@ -139,7 +104,15 @@ def train(self, input_data_config, hyperparameters):
139
104
compose_command = self ._compose ()
140
105
141
106
_ecr_login_if_needed (self .sagemaker_session .boto_session , self .image )
142
- _execute_and_stream_output (compose_command )
107
+ process = subprocess .Popen (compose_command , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
108
+
109
+ try :
110
+ _stream_output (process )
111
+ except RuntimeError as e :
112
+ # _stream_output() doesn't have the command line. We will handle the exception
113
+ # which contains the exit code and append the command line to it.
114
+ msg = "Failed to run: %s, %s" % (compose_command , e .message )
115
+ raise RuntimeError (msg )
143
116
144
117
s3_artifacts = self .retrieve_artifacts (compose_data )
145
118
@@ -196,7 +169,7 @@ def serve(self, primary_container):
196
169
additional_volumes = volumes )
197
170
compose_command = self ._compose ()
198
171
self .container = _HostingContainer (compose_command )
199
- self .container .up ()
172
+ self .container .start ()
200
173
201
174
def stop_serving (self ):
202
175
"""Stop the serving container.
@@ -205,6 +178,7 @@ def stop_serving(self):
205
178
"""
206
179
if self .container :
207
180
self .container .down ()
181
+ self .container .join ()
208
182
self ._cleanup ()
209
183
# for serving we can delete everything in the container root.
210
184
_delete_tree (self .container_root )
@@ -304,6 +278,47 @@ def _download_folder(self, bucket_name, prefix, target):
304
278
305
279
obj .download_file (file_path )
306
280
281
def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters):
    """Build the list of volume mounts needed for a local training job.

    Each input channel becomes a directory visible to the training
    container: ``file://`` channels are mounted directly from the local
    filesystem, while ``s3://`` channels are first downloaded into a
    per-channel subdirectory of ``data_dir``. When the training script
    itself lives on the local filesystem it is mounted at /opt/ml/code,
    along with a shared scratch directory at /opt/ml/shared.

    Args:
        data_dir (str): local staging directory for downloaded S3 data.
        input_data_config (list): channel definitions, each with a
            'ChannelName' and a 'DataSource' containing either an
            'S3DataSource' or a 'FileDataSource'.
        hyperparameters (dict): training hyperparameters; the value under
            ``sagemaker.estimator.DIR_PARAM_NAME`` is a JSON-encoded URI
            pointing at the training script directory.

    Returns:
        list of _Volume: the container volume mappings to use.

    Raises:
        ValueError: if a channel has no recognized data source, or a URI
            uses a scheme other than 's3' or 'file'.
    """
    shared_dir = os.path.join(self.container_root, 'shared')
    volumes = []

    # Resolve every channel to a URI, then either download (s3) or
    # mount (file) it for the container.
    for channel in input_data_config:
        data_source = channel['DataSource']
        if data_source and 'S3DataSource' in data_source:
            uri = data_source['S3DataSource']['S3Uri']
        elif data_source and 'FileDataSource' in data_source:
            uri = data_source['FileDataSource']['FileUri']
        else:
            raise ValueError('Need channel[\'DataSource\'] to have'
                             ' [\'S3DataSource\'] or [\'FileDataSource\']')

        parsed_uri = urlparse(uri)
        key = parsed_uri.path.lstrip('/')

        channel_name = channel['ChannelName']
        channel_dir = os.path.join(data_dir, channel_name)
        os.mkdir(channel_dir)

        if parsed_uri.scheme == 's3':
            self._download_folder(parsed_uri.netloc, key, channel_dir)
        elif parsed_uri.scheme == 'file':
            volumes.append(_Volume(parsed_uri.path, channel=channel_name))
        else:
            raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))

    # If the training script directory is a local directory, mount it to the container.
    training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
    if urlparse(training_dir).scheme == 'file':
        volumes.append(_Volume(urlparse(training_dir).path, '/opt/ml/code'))
        # Also mount a directory that all the containers can access.
        volumes.append(_Volume(shared_dir, '/opt/ml/shared'))

    return volumes
307
322
def _generate_compose_file (self , command , additional_volumes = None , additional_env_vars = None ):
308
323
"""Writes a config file describing a training/hosting environment.
309
324
@@ -452,15 +467,23 @@ def _cleanup(self):
452
467
pass
453
468
454
469
455
- class _HostingContainer (object ):
456
- def __init__ (self , command , startup_delay = 5 ):
470
+ class _HostingContainer (Thread ):
471
+ def __init__ (self , command ):
472
+ Thread .__init__ (self )
457
473
self .command = command
458
- self .startup_delay = startup_delay
459
474
self .process = None
460
475
461
- def up (self ):
462
- self .process = Popen (self .command )
463
- sleep (self .startup_delay )
476
+ def run (self ):
477
+ self .process = subprocess .Popen (self .command ,
478
+ stdout = subprocess .PIPE ,
479
+ stderr = subprocess .PIPE )
480
+ try :
481
+ _stream_output (self .process )
482
+ except RuntimeError as e :
483
+ # _stream_output() doesn't have the command line. We will handle the exception
484
+ # which contains the exit code and append the command line to it.
485
+ msg = "Failed to run: %s, %s" % (self .command , e .message )
486
+ raise RuntimeError (msg )
464
487
465
488
def down (self ):
466
489
self .process .terminate ()
@@ -495,26 +518,41 @@ def __init__(self, host_dir, container_dir=None, channel=None):
495
518
self .map = '{}:{}' .format (self .host_dir , self .container_dir )
496
519
497
520
498
- def _execute_and_stream_output (cmd ):
499
- """Execute a command and stream the output to stdout
521
+ def _stream_output (process ):
522
+ """Stream the output of a process to stdout
523
+
524
+ This function takes an existing process that will be polled for output. Both stdout and
525
+ stderr will be polled and both will be sent to sys.stdout.
500
526
501
527
Args:
502
- cmd(str or List): either a string or a List (in Popen Format) with the command to execute.
528
+ process(subprocess.Popen): a process that has been started with
529
+ stdout=PIPE and stderr=PIPE
503
530
504
- Returns (int): command exit code
531
+ Returns (int): process exit code
505
532
"""
506
- if isinstance (cmd , str ):
507
- cmd = shlex .split (cmd )
508
- process = subprocess .Popen (cmd , stdout = subprocess .PIPE )
509
533
exit_code = None
534
+
535
+ # Get the current flags for the stderr file descriptor
536
+ # And add the NONBLOCK flag to allow us to read even if there is no data.
537
+ # Since usually stderr will be empty unless there is an error.
538
+ flags = fcntl (process .stderr , F_GETFL ) # get current process.stderr flags
539
+ fcntl (process .stderr , F_SETFL , flags | os .O_NONBLOCK )
540
+
510
541
while exit_code is None :
511
542
stdout = process .stdout .readline ().decode ("utf-8" )
512
543
sys .stdout .write (stdout )
544
+ try :
545
+ stderr = process .stderr .readline ().decode ("utf-8" )
546
+ sys .stdout .write (stderr )
547
+ except IOError :
548
+ # If there is nothing to read on stderr we will get an IOError
549
+ # this is fine.
550
+ pass
513
551
514
552
exit_code = process .poll ()
515
553
516
554
if exit_code != 0 :
517
- raise Exception ( "Failed to run %s, exit code: %s" % ( "," . join ( cmd ), exit_code ) )
555
+ raise RuntimeError ( "Process exited with code: %s" % exit_code )
518
556
519
557
return exit_code
520
558
0 commit comments