Skip to content

Commit e1d79d5

Browse files
authored
Add support for async fit() (aws#59)
when calling fit(wait=False) it will return immediately. The training job will carry on even if the process exits. by using attach() the estimator can be retrieved by providing the training job name. _prepare_init_params_from_job_description() is now a classmethod instead of being a static method. Each class is responsible to implement their specific logic to convert a training job description into arguments that can be passed to its own __init__()
1 parent 354ded3 commit e1d79d5

16 files changed

+556
-142
lines changed

README.rst

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ SageMaker Python SDK provides several high-level abstractions for working with A
9797
- **Estimators**: Encapsulate training on SageMaker. Can be ``fit()`` to run training, then the resulting model ``deploy()`` ed to a SageMaker Endpoint.
9898
- **Models**: Encapsulate built ML models. Can be ``deploy()`` ed to a SageMaker Endpoint.
9999
- **Predictors**: Provide real-time inference and transformation using Python data-types against a SageMaker Endpoint.
100-
- **Session**: Provides a collection of convience methods for working with SageMaker resources.
100+
- **Session**: Provides a collection of convenience methods for working with SageMaker resources.
101101

102102
Estimator and Model implementations for MXNet, TensorFlow, and Amazon ML algorithms are included. There's also an Estimator that runs SageMaker compatible custom Docker containers, allowing you to run your own ML algorithms via SageMaker Python SDK.
103103

@@ -1150,6 +1150,7 @@ Optional arguments
11501150
11511151
- ``wait (bool)``: Defaults to True, whether to block and wait for the
11521152
training script to complete before returning.
1153+
If set to False, it will return immediately, and can later be attached to.
11531154
- ``logs (bool)``: Defaults to True, whether to show logs produced by training
11541155
job in the Python session. Only meaningful when wait is True.
11551156
- ``run_tensorboard_locally (bool)``: Defaults to False. Executes TensorBoard in a different
@@ -1178,9 +1179,25 @@ the ``TensorFlow`` estimator parameter ``training_steps`` is finished or when th
11781179
job execution time reaches the ``TensorFlow`` estimator parameter ``train_max_run``.
11791180
11801181
When the training job finishes, a `TensorFlow serving <https://www.tensorflow.org/serving/serving_basic>`_
1181-
with the result of the training is generated and saved to the S3 location define by
1182+
with the result of the training is generated and saved to the S3 location defined by
11821183
the ``TensorFlow`` estimator parameter ``output_path``.
11831184
1185+
1186+
If the ``wait=False`` flag is passed to ``fit``, then it will return immediately. The training job will continue running
1187+
asynchronously. At a later time, a Tensorflow Estimator can be obtained by attaching to the existing training job. If
1188+
the training job is not finished it will start showing the standard output of training and wait until it completes.
1189+
After attaching, the estimator can be deployed as usual.
1190+
1191+
.. code:: python
1192+
1193+
tf_estimator.fit(your_input_data, wait=False)
1194+
training_job_name = tf_estimator.latest_training_job.name
1195+
1196+
# after some time, or in a separate python notebook, we can attach to it again.
1197+
1198+
tf_estimator = TensorFlow.attach(training_job_name=training_job_name)
1199+
1200+
11841201
The evaluation process
11851202
""""""""""""""""""""""
11861203
@@ -1244,6 +1261,8 @@ You can access TensorBoard locally at http://localhost:6006 or using your SakeMa
12441261
`https*workspace_base_url*proxy/6006/ <proxy/6006/>`_ (TensorBoard will not work if you forget to put the slash,
12451262
'/', in end of the url). If TensorBoard started on a different port, adjust these URLs to match.
12461263
1264+
Note that TensorBoard is not supported when passing wait=False to ``fit``.
1265+
12471266
12481267
Deploying TensorFlow Serving models
12491268
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,31 @@ def data_location(self, data_location):
6565
data_location = data_location + '/'
6666
self._data_location = data_location
6767

68+
@classmethod
69+
def _prepare_init_params_from_job_description(cls, job_details):
70+
"""Convert the job description to init params that can be handled by the class constructor
71+
72+
Args:
73+
job_details: the returned job details from a describe_training_job API call.
74+
75+
Returns:
76+
dictionary: The transformed init_params
77+
78+
"""
79+
init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description(job_details)
80+
81+
# The hyperparam names may not be the same as the class attribute that holds them,
82+
# for instance: local_lloyd_init_method is called local_init_method. We need to map these
83+
# and pass the correct name to the constructor.
84+
for attribute, value in cls.__dict__.items():
85+
if isinstance(value, hp):
86+
if value.name in init_params['hyperparameters']:
87+
init_params[attribute] = init_params['hyperparameters'][value.name]
88+
89+
del init_params['hyperparameters']
90+
del init_params['image']
91+
return init_params
92+
6893
def fit(self, records, mini_batch_size=None, **kwargs):
6994
"""Fit this Estimator on serialized Record objects, stored in S3.
7095

src/sagemaker/estimator.py

Lines changed: 137 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,60 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
152152
self.latest_training_job = _TrainingJob.start_new(self, inputs)
153153
if wait:
154154
self.latest_training_job.wait(logs=logs)
155-
else:
156-
raise NotImplemented('Asynchronous fit not available')
155+
156+
@classmethod
157+
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
158+
"""Create an Estimator from existing training job data.
159+
160+
Args:
161+
init_params (dict): The init_params the training job was created with.
162+
hyperparameters (dict): The hyperparameters the training job was created with.
163+
image (str): Container image (if any) the training job was created with
164+
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
165+
166+
Returns: An instance of the calling Estimator Class.
167+
168+
"""
169+
raise NotImplementedError()
170+
171+
@classmethod
172+
def attach(cls, training_job_name, sagemaker_session=None, job_details=None):
173+
"""Attach to an existing training job.
174+
175+
Create an Estimator bound to an existing training job, each subclass is responsible to implement
176+
``_prepare_init_params_from_job_description()`` as this method delegates the actual conversion of a training
177+
job description to the arguments that the class constructor expects. After attaching, if the training job has a
178+
Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.
179+
180+
If the training job is in progress, attach will block and display log messages
181+
from the training job, until the training job completes.
182+
183+
Args:
184+
training_job_name (str): The name of the training job to attach to.
185+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
186+
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
187+
using the default AWS configuration chain.
188+
189+
Examples:
190+
>>> my_estimator.fit(wait=False)
191+
>>> training_job_name = my_estimator.latest_training_job.name
192+
Later on:
193+
>>> attached_estimator = Estimator.attach(training_job_name)
194+
>>> attached_estimator.deploy()
195+
196+
Returns:
197+
Instance of the calling ``Estimator`` Class with the attached training job.
198+
"""
199+
sagemaker_session = sagemaker_session or Session()
200+
201+
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
202+
init_params = cls._prepare_init_params_from_job_description(job_details)
203+
204+
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
205+
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
206+
training_job_name=init_params['base_job_name'])
207+
estimator.latest_training_job.wait()
208+
return estimator
157209

158210
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs):
159211
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
@@ -202,21 +254,33 @@ def create_model(self, **kwargs):
202254
"""
203255
pass
204256

205-
@staticmethod
206-
def _prepare_estimator_params_from_job_description(job_details):
207-
estimator_params = dict()
257+
@classmethod
258+
def _prepare_init_params_from_job_description(cls, job_details):
259+
"""Convert the job description to init params that can be handled by the class constructor
260+
261+
Args:
262+
job_details: the returned job details from a describe_training_job API call.
263+
264+
Returns:
265+
dictionary: The transformed init_params
266+
267+
"""
268+
init_params = dict()
208269

209-
estimator_params['role'] = job_details['RoleArn']
210-
estimator_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount']
211-
estimator_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType']
212-
estimator_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB']
213-
estimator_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
214-
estimator_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode']
215-
estimator_params['base_job_name'] = job_details['TrainingJobName']
216-
estimator_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
217-
estimator_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']
270+
init_params['role'] = job_details['RoleArn']
271+
init_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount']
272+
init_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType']
273+
init_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB']
274+
init_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
275+
init_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode']
276+
init_params['base_job_name'] = job_details['TrainingJobName']
277+
init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
278+
init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']
218279

219-
return estimator_params, job_details['HyperParameters'], job_details['AlgorithmSpecification']['TrainingImage']
280+
init_params['hyperparameters'] = job_details['HyperParameters']
281+
init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage']
282+
283+
return init_params
220284

221285
def delete_endpoint(self):
222286
"""Delete an Amazon SageMaker ``Endpoint``.
@@ -333,7 +397,8 @@ class Estimator(EstimatorBase):
333397

334398
def __init__(self, image_name, role, train_instance_count, train_instance_type,
335399
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
336-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
400+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None,
401+
hyperparameters=None):
337402
"""Initialize an ``Estimator`` instance.
338403
339404
Args:
@@ -365,9 +430,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
365430
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
366431
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
367432
using the default AWS configuration chain.
433+
hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with.
368434
"""
369435
self.image_name = image_name
370-
self.hyperparam_dict = {}
436+
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
371437
super(Estimator, self).__init__(role, train_instance_count, train_instance_type,
372438
train_volume_size, train_max_run, input_mode,
373439
output_path, output_kms_key, base_job_name, sagemaker_session)
@@ -422,6 +488,22 @@ def predict_wrapper(endpoint, session):
422488
return Model(self.model_data, image or self.train_image(), self.role, sagemaker_session=self.sagemaker_session,
423489
predictor_cls=predictor_cls, **kwargs)
424490

491+
@classmethod
492+
def _prepare_init_params_from_job_description(cls, job_details):
493+
"""Convert the job description to init params that can be handled by the class constructor
494+
495+
Args:
496+
job_details: the returned job details from a describe_training_job API call.
497+
498+
Returns:
499+
dictionary: The transformed init_params
500+
501+
"""
502+
init_params = super(Estimator, cls)._prepare_init_params_from_job_description(job_details)
503+
504+
init_params['image_name'] = init_params.pop('image')
505+
return init_params
506+
425507

426508
class Framework(EstimatorBase):
427509
"""Base class that cannot be instantiated directly.
@@ -528,12 +610,37 @@ def hyperparameters(self):
528610
return self._json_encode_hyperparameters(self._hyperparameters)
529611

530612
@classmethod
531-
def attach(cls, training_job_name, sagemaker_session=None, **kwargs):
613+
def _prepare_init_params_from_job_description(cls, job_details):
614+
"""Convert the job description to init params that can be handled by the class constructor
615+
616+
Args:
617+
job_details: the returned job details from a describe_training_job API call.
618+
619+
Returns:
620+
dictionary: The transformed init_params
621+
622+
"""
623+
init_params = super(Framework, cls)._prepare_init_params_from_job_description(job_details)
624+
625+
init_params['entry_point'] = json.loads(init_params['hyperparameters'].get(SCRIPT_PARAM_NAME))
626+
init_params['source_dir'] = json.loads(init_params['hyperparameters'].get(DIR_PARAM_NAME))
627+
init_params['enable_cloudwatch_metrics'] = json.loads(
628+
init_params['hyperparameters'].get(CLOUDWATCH_METRICS_PARAM_NAME))
629+
init_params['container_log_level'] = json.loads(
630+
init_params['hyperparameters'].get(CONTAINER_LOG_LEVEL_PARAM_NAME))
631+
632+
init_params['hyperparameters'] = {k: json.loads(v) for k, v in init_params['hyperparameters'].items()}
633+
634+
return init_params
635+
636+
@classmethod
637+
def attach(cls, training_job_name, sagemaker_session=None):
532638
"""Attach to an existing training job.
533639
534-
Create an Estimator bound to an existing training job. After attaching, if
535-
the training job has a Complete status, it can be ``deploy()`` ed to create
536-
a SageMaker Endpoint and return a ``Predictor``.
640+
Create an Estimator bound to an existing training job, each subclass is responsible to implement
641+
``_prepare_init_params_from_job_description()`` as this method delegates the actual conversion of a training
642+
job description to the arguments that the class constructor expects. After attaching, if the training job has a
643+
Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.
537644
538645
If the training job is in progress, attach will block and display log messages
539646
from the training job, until the training job completes.
@@ -543,41 +650,18 @@ def attach(cls, training_job_name, sagemaker_session=None, **kwargs):
543650
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
544651
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
545652
using the default AWS configuration chain.
546-
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor.
653+
654+
Examples:
655+
>>> my_estimator.fit(wait=False)
656+
>>> training_job_name = my_estimator.latest_training_job.name
657+
Later on:
658+
>>> attached_estimator = Estimator.attach(training_job_name)
659+
>>> attached_estimator.deploy()
547660
548661
Returns:
549-
sagemaker.estimator.Framework: ``Estimator`` with the attached training job.
662+
Instance of the calling ``Estimator`` Class with the attached training job.
550663
"""
551-
sagemaker_session = sagemaker_session or Session()
552-
553-
if training_job_name is not None:
554-
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
555-
init_params, hp, _ = cls._prepare_estimator_params_from_job_description(job_details)
556-
557-
else:
558-
# this case is only valid when called from inheriting class and then the class must declare framework
559-
if not hasattr(cls, '__framework_name__'):
560-
raise ValueError('must specify training_job name')
561-
init_params = dict(kwargs)
562-
hp = init_params.pop('hyperparameters')
563-
564-
# parameters for framework classes
565-
framework_init_params = dict()
566-
framework_init_params['entry_point'] = json.loads(hp.get(SCRIPT_PARAM_NAME))
567-
framework_init_params['source_dir'] = json.loads(hp.get(DIR_PARAM_NAME))
568-
framework_init_params['enable_cloudwatch_metrics'] = json.loads(hp.get(CLOUDWATCH_METRICS_PARAM_NAME))
569-
framework_init_params['container_log_level'] = json.loads(hp.get(CONTAINER_LOG_LEVEL_PARAM_NAME))
570-
571-
# drop json and remove other SageMaker specific additions
572-
hyperparameters = {entry: json.loads(hp[entry]) for entry in hp}
573-
framework_init_params['hyperparameters'] = hyperparameters
574-
575-
init_params.update(framework_init_params)
576-
577-
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
578-
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
579-
training_job_name=init_params['base_job_name'])
580-
estimator.latest_training_job.wait()
664+
estimator = super(Framework, cls).attach(training_job_name, sagemaker_session)
581665
estimator.uploaded_code = UploadedCode(estimator.source_dir, estimator.entry_point)
582666
return estimator
583667

0 commit comments

Comments
 (0)