Skip to content

Commit 8d96cec

Browse files
author
Ignacio Quintero
committed
Refactor attach to remove _from_training_job()
_prepare_init_params_from_job_description() is now a classmethod instead of being a static method. Each class is responsible to implement their specific logic to convert a training job description into arguments that can be passed to its own __init__()
1 parent 24ae51e commit 8d96cec

File tree

8 files changed

+132
-111
lines changed

8 files changed

+132
-111
lines changed

src/sagemaker/amazon/amazon_estimator.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -65,28 +65,29 @@ def data_location(self, data_location):
6565
self._data_location = data_location
6666

6767
@classmethod
68-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
69-
"""Create an Estimator from existing training job data.
68+
def _prepare_init_params_from_job_description(cls, job_details):
69+
"""Convert the job description to init params that can be handled by the class constructor
7070
7171
Args:
72-
init_params (dict): The init_params the training job was created with.
73-
hyperparameters (dict): The hyperparameters the training job was created with.
74-
image (str): Container image (if any) the training job was created with
75-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
72+
job_details: the returned job details from a describe_training_job API call.
7673
77-
Returns: An instance of the calling Estimator Class.
74+
Returns:
75+
dictionary: The transformed init_params
7876
7977
"""
78+
init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description(job_details)
8079

8180
# The hyperparam names may not be the same as the class attribute that holds them,
8281
# for instance: local_lloyd_init_method is called local_init_method. We need to map these
8382
# and pass the correct name to the constructor.
8483
for attribute, value in cls.__dict__.items():
8584
if isinstance(value, hp):
86-
if value.name in hyperparameters:
87-
init_params[attribute] = hyperparameters[value.name]
85+
if value.name in init_params['hyperparameters']:
86+
init_params[attribute] = init_params['hyperparameters'][value.name]
8887

89-
return cls(sagemaker_session=sagemaker_session, **init_params)
88+
del init_params['hyperparameters']
89+
del init_params['image']
90+
return init_params
9091

9192
def fit(self, records, mini_batch_size=None, **kwargs):
9293
"""Fit this Estimator on serialized Record objects, stored in S3.

src/sagemaker/estimator.py

+87-56
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,13 @@ def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_sessi
169169
raise NotImplementedError()
170170

171171
@classmethod
172-
def attach(cls, training_job_name, sagemaker_session=None):
172+
def attach(cls, training_job_name, sagemaker_session=None, job_details=None):
173173
"""Attach to an existing training job.
174174
175175
Create an Estimator bound to an existing training job, each subclass is responsible to implement
176-
``from_training_job()`` as this method delegates the actual Estimator creation to it. After
177-
attaching, if the training job has a Complete status, it can be ``deploy()`` ed to create
178-
a SageMaker Endpoint and return a ``Predictor``.
176+
``_prepare_init_params_from_job_description()`` as this method delegates the actual conversion of a training
177+
job description to the arguments that the class constructor expects. After attaching, if the training job has a
178+
Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.
179179
180180
If the training job is in progress, attach will block and display log messages
181181
from the training job, until the training job completes.
@@ -198,13 +198,10 @@ def attach(cls, training_job_name, sagemaker_session=None):
198198
"""
199199
sagemaker_session = sagemaker_session or Session()
200200

201-
if training_job_name:
202-
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
203-
init_params, hp, image = cls._prepare_estimator_params_from_job_description(job_details)
204-
else:
205-
raise ValueError('must specify training_job name')
201+
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
202+
init_params = cls._prepare_init_params_from_job_description(job_details)
206203

207-
estimator = cls._from_training_job(init_params, hp, image, sagemaker_session)
204+
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
208205
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
209206
training_job_name=init_params['base_job_name'])
210207
estimator.latest_training_job.wait()
@@ -257,21 +254,33 @@ def create_model(self, **kwargs):
257254
"""
258255
pass
259256

260-
@staticmethod
261-
def _prepare_estimator_params_from_job_description(job_details):
262-
estimator_params = dict()
257+
@classmethod
258+
def _prepare_init_params_from_job_description(cls, job_details):
259+
"""Convert the job description to init params that can be handled by the class constructor
263260
264-
estimator_params['role'] = job_details['RoleArn']
265-
estimator_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount']
266-
estimator_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType']
267-
estimator_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB']
268-
estimator_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
269-
estimator_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode']
270-
estimator_params['base_job_name'] = job_details['TrainingJobName']
271-
estimator_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
272-
estimator_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']
261+
Args:
262+
job_details: the returned job details from a describe_training_job API call.
273263
274-
return estimator_params, job_details['HyperParameters'], job_details['AlgorithmSpecification']['TrainingImage']
264+
Returns:
265+
dictionary: The transformed init_params
266+
267+
"""
268+
init_params = dict()
269+
270+
init_params['role'] = job_details['RoleArn']
271+
init_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount']
272+
init_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType']
273+
init_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB']
274+
init_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
275+
init_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode']
276+
init_params['base_job_name'] = job_details['TrainingJobName']
277+
init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
278+
init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']
279+
280+
init_params['hyperparameters'] = job_details['HyperParameters']
281+
init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage']
282+
283+
return init_params
275284

276285
def delete_endpoint(self):
277286
"""Delete an Amazon SageMaker ``Endpoint``.
@@ -388,7 +397,8 @@ class Estimator(EstimatorBase):
388397

389398
def __init__(self, image_name, role, train_instance_count, train_instance_type,
390399
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
391-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
400+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None,
401+
hyperparameters=None):
392402
"""Initialize an ``Estimator`` instance.
393403
394404
Args:
@@ -420,9 +430,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
420430
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
421431
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
422432
using the default AWS configuration chain.
433+
hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with.
423434
"""
424435
self.image_name = image_name
425-
self.hyperparam_dict = {}
436+
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
426437
super(Estimator, self).__init__(role, train_instance_count, train_instance_type,
427438
train_volume_size, train_max_run, input_mode,
428439
output_path, output_kms_key, base_job_name, sagemaker_session)
@@ -478,23 +489,20 @@ def predict_wrapper(endpoint, session):
478489
predictor_cls=predictor_cls, **kwargs)
479490

480491
@classmethod
481-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
482-
"""Create an Estimator from existing training job data.
492+
def _prepare_init_params_from_job_description(cls, job_details):
493+
"""Convert the job description to init params that can be handled by the class constructor
483494
484495
Args:
485-
init_params (dict): The init_params the training job was created with.
486-
hyperparameters (dict): The hyperparameters the training job was created with.
487-
image (str): Container image (if any) the training job was created with
488-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
496+
job_details: the returned job details from a describe_training_job API call.
489497
490-
Returns: An instance of the calling Estimator Class.
498+
Returns:
499+
dictionary: The transformed init_params
491500
492501
"""
502+
init_params = super(Estimator, cls)._prepare_init_params_from_job_description(job_details)
493503

494-
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
495-
cls.set_hyperparameters(**hyperparameters)
496-
497-
return estimator
504+
init_params['image_name'] = init_params.pop('image')
505+
return init_params
498506

499507

500508
class Framework(EstimatorBase):
@@ -602,35 +610,58 @@ def hyperparameters(self):
602610
return self._json_encode_hyperparameters(self._hyperparameters)
603611

604612
@classmethod
605-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
606-
"""Create an Estimator from existing training job data.
613+
def _prepare_init_params_from_job_description(cls, job_details):
614+
"""Convert the job description to init params that can be handled by the class constructor
607615
608616
Args:
609-
init_params (dict): The init_params the training job was created with.
610-
hyperparameters (dict): The hyperparameters the training job was created with.
611-
image (str): Container image (if any) the training job was created with
612-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
617+
job_details: the returned job details from a describe_training_job API call.
613618
614-
Returns: An instance of the calling Estimator Class.
619+
Returns:
620+
dictionary: The transformed init_params
615621
616622
"""
623+
init_params = super(Framework, cls)._prepare_init_params_from_job_description(job_details)
617624

618-
# parameters for framework classes
619-
framework_init_params = dict()
620-
framework_init_params['entry_point'] = json.loads(hyperparameters.get(SCRIPT_PARAM_NAME))
621-
framework_init_params['source_dir'] = json.loads(hyperparameters.get(DIR_PARAM_NAME))
622-
framework_init_params['enable_cloudwatch_metrics'] = json.loads(
623-
hyperparameters.get(CLOUDWATCH_METRICS_PARAM_NAME))
624-
framework_init_params['container_log_level'] = json.loads(
625-
hyperparameters.get(CONTAINER_LOG_LEVEL_PARAM_NAME))
625+
init_params['entry_point'] = json.loads(init_params['hyperparameters'].get(SCRIPT_PARAM_NAME))
626+
init_params['source_dir'] = json.loads(init_params['hyperparameters'].get(DIR_PARAM_NAME))
627+
init_params['enable_cloudwatch_metrics'] = json.loads(
628+
init_params['hyperparameters'].get(CLOUDWATCH_METRICS_PARAM_NAME))
629+
init_params['container_log_level'] = json.loads(
630+
init_params['hyperparameters'].get(CONTAINER_LOG_LEVEL_PARAM_NAME))
626631

627-
# drop json and remove other SageMaker specific additions
628-
deserialized_hps = {entry: json.loads(hyperparameters[entry]) for entry in hyperparameters}
629-
framework_init_params['hyperparameters'] = deserialized_hps
632+
init_params['hyperparameters'] = {k: json.loads(v) for k, v in init_params['hyperparameters'].items()}
630633

631-
init_params.update(framework_init_params)
634+
return init_params
632635

633-
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
636+
@classmethod
637+
def attach(cls, training_job_name, sagemaker_session=None):
638+
"""Attach to an existing training job.
639+
640+
Create an Estimator bound to an existing training job, each subclass is responsible to implement
641+
``_prepare_init_params_from_job_description()`` as this method delegates the actual conversion of a training
642+
job description to the arguments that the class constructor expects. After attaching, if the training job has a
643+
Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.
644+
645+
If the training job is in progress, attach will block and display log messages
646+
from the training job, until the training job completes.
647+
648+
Args:
649+
training_job_name (str): The name of the training job to attach to.
650+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
651+
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
652+
using the default AWS configuration chain.
653+
654+
Examples:
655+
>>> my_estimator.fit(wait=False)
656+
>>> training_job_name = my_estimator.latest_training_job.name
657+
Later on:
658+
>>> attached_estimator = Estimator.attach(training_job_name)
659+
>>> attached_estimator.deploy()
660+
661+
Returns:
662+
Instance of the calling ``Estimator`` Class with the attached training job.
663+
"""
664+
estimator = super(Framework, cls).attach(training_job_name, sagemaker_session)
634665
estimator.uploaded_code = UploadedCode(estimator.source_dir, estimator.entry_point)
635666
return estimator
636667

src/sagemaker/mxnet/estimator.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,23 @@ def create_model(self, model_server_workers=None):
8282
sagemaker_session=self.sagemaker_session)
8383

8484
@classmethod
85-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
86-
"""Create an Estimator from existing training job data.
85+
def _prepare_init_params_from_job_description(cls, job_details):
86+
"""Convert the job description to init params that can be handled by the class constructor
8787
8888
Args:
89-
init_params (dict): The init_params the training job was created with.
90-
hyperparameters (dict): The hyperparameters the training job was created with.
91-
image (str): Container image (if any) the training job was created with
92-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
89+
job_details: the returned job details from a describe_training_job API call.
9390
94-
Returns: An instance of the calling Estimator Class.
91+
Returns:
92+
dictionary: The transformed init_params
9593
9694
"""
95+
init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details)
96+
framework, py_version = framework_name_from_image(init_params.pop('image'))
9797

98-
framework, py_version = framework_name_from_image(image)
99-
init_params.update({'py_version': py_version})
100-
98+
init_params['py_version'] = py_version
10199
training_job_name = init_params['base_job_name']
100+
102101
if framework != cls.__framework_name__:
103102
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
104103

105-
return super(MXNet, cls)._from_training_job(init_params, hyperparameters, image, sagemaker_session)
104+
return init_params

src/sagemaker/tensorflow/estimator.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
import logging
14+
import os
1415
import subprocess
1516
import tempfile
1617
import threading
1718

18-
import os
19-
2019
import sagemaker.tensorflow
2120
from sagemaker.estimator import Framework
2221
from sagemaker.fw_utils import create_image_uri, framework_name_from_image
@@ -168,31 +167,32 @@ def fit_super():
168167
fit_super()
169168

170169
@classmethod
171-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
172-
"""Create an Estimator from existing training job data.
170+
def _prepare_init_params_from_job_description(cls, job_details):
171+
"""Convert the job description to init params that can be handled by the class constructor
173172
174173
Args:
175-
init_params (dict): The init_params the training job was created with.
176-
hyperparameters (dict): The hyperparameters the training job was created with.
177-
image (str): Container image (if any) the training job was created with
178-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
174+
job_details: the returned job details from a describe_training_job API call.
179175
180-
Returns: An instance of the calling Estimator Class.
176+
Returns:
177+
dictionary: The transformed init_params
181178
182179
"""
180+
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details)
183181

184-
updated_params = cls._update_init_params(hyperparameters,
185-
['checkpoint_path', 'training_steps', 'evaluation_steps'])
186-
init_params.update(updated_params)
182+
# Move some of the tensorflow specific init params from hyperparameters into the main init params.
183+
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']:
184+
value = init_params['hyperparameters'].pop(argument, None)
185+
if value is not None:
186+
init_params[argument] = value
187187

188-
framework, py_version = framework_name_from_image(image)
189-
init_params.update({'py_version': py_version})
190-
training_job_name = init_params['base_job_name']
188+
framework, py_version = framework_name_from_image(init_params.pop('image'))
189+
init_params['py_version'] = py_version
191190

191+
training_job_name = init_params['base_job_name']
192192
if framework != cls.__framework_name__:
193193
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
194194

195-
return super(TensorFlow, cls)._from_training_job(init_params, hyperparameters, image, sagemaker_session)
195+
return init_params
196196

197197
def train_image(self):
198198
"""Return the Docker image to use for training.

tests/integ/test_byo_estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,5 @@ def test_async_byo_estimator():
154154
assert len(result['predictions']) == 10
155155
for prediction in result['predictions']:
156156
assert prediction['score'] is not None
157+
158+
assert estimator.train_image() == image_name

0 commit comments

Comments
 (0)