@@ -169,13 +169,13 @@ def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_sessi
169
169
raise NotImplementedError ()
170
170
171
171
@classmethod
172
- def attach (cls , training_job_name , sagemaker_session = None ):
172
+ def attach (cls , training_job_name , sagemaker_session = None , job_details = None ):
173
173
"""Attach to an existing training job.
174
174
175
175
Create an Estimator bound to an existing training job, each subclass is responsible to implement
176
- ``from_training_job ()`` as this method delegates the actual Estimator creation to it. After
177
- attaching, if the training job has a Complete status, it can be ``deploy()`` ed to create
178
- a SageMaker Endpoint and return a ``Predictor``.
176
+ ``_prepare_init_params_from_job_description ()`` as this method delegates the actual conversion of a training
177
+ job description to the arguments that the class constructor expects. After attaching, if the training job has a
178
+ Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.
179
179
180
180
If the training job is in progress, attach will block and display log messages
181
181
from the training job, until the training job completes.
@@ -198,13 +198,10 @@ def attach(cls, training_job_name, sagemaker_session=None):
198
198
"""
199
199
sagemaker_session = sagemaker_session or Session ()
200
200
201
- if training_job_name :
202
- job_details = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
203
- init_params , hp , image = cls ._prepare_estimator_params_from_job_description (job_details )
204
- else :
205
- raise ValueError ('must specify training_job name' )
201
+ job_details = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
202
+ init_params = cls ._prepare_init_params_from_job_description (job_details )
206
203
207
- estimator = cls . _from_training_job ( init_params , hp , image , sagemaker_session )
204
+ estimator = cls ( sagemaker_session = sagemaker_session , ** init_params )
208
205
estimator .latest_training_job = _TrainingJob (sagemaker_session = sagemaker_session ,
209
206
training_job_name = init_params ['base_job_name' ])
210
207
estimator .latest_training_job .wait ()
@@ -257,21 +254,33 @@ def create_model(self, **kwargs):
257
254
"""
258
255
pass
259
256
260
- @staticmethod
261
- def _prepare_estimator_params_from_job_description ( job_details ):
262
- estimator_params = dict ()
257
+ @classmethod
258
+ def _prepare_init_params_from_job_description ( cls , job_details ):
259
+ """Convert the job description to init params that can be handled by the class constructor
263
260
264
- estimator_params ['role' ] = job_details ['RoleArn' ]
265
- estimator_params ['train_instance_count' ] = job_details ['ResourceConfig' ]['InstanceCount' ]
266
- estimator_params ['train_instance_type' ] = job_details ['ResourceConfig' ]['InstanceType' ]
267
- estimator_params ['train_volume_size' ] = job_details ['ResourceConfig' ]['VolumeSizeInGB' ]
268
- estimator_params ['train_max_run' ] = job_details ['StoppingCondition' ]['MaxRuntimeInSeconds' ]
269
- estimator_params ['input_mode' ] = job_details ['AlgorithmSpecification' ]['TrainingInputMode' ]
270
- estimator_params ['base_job_name' ] = job_details ['TrainingJobName' ]
271
- estimator_params ['output_path' ] = job_details ['OutputDataConfig' ]['S3OutputPath' ]
272
- estimator_params ['output_kms_key' ] = job_details ['OutputDataConfig' ]['KmsKeyId' ]
261
+ Args:
262
+ job_details: the returned job details from a describe_training_job API call.
273
263
274
- return estimator_params , job_details ['HyperParameters' ], job_details ['AlgorithmSpecification' ]['TrainingImage' ]
264
+ Returns:
265
+ dictionary: The transformed init_params
266
+
267
+ """
268
+ init_params = dict ()
269
+
270
+ init_params ['role' ] = job_details ['RoleArn' ]
271
+ init_params ['train_instance_count' ] = job_details ['ResourceConfig' ]['InstanceCount' ]
272
+ init_params ['train_instance_type' ] = job_details ['ResourceConfig' ]['InstanceType' ]
273
+ init_params ['train_volume_size' ] = job_details ['ResourceConfig' ]['VolumeSizeInGB' ]
274
+ init_params ['train_max_run' ] = job_details ['StoppingCondition' ]['MaxRuntimeInSeconds' ]
275
+ init_params ['input_mode' ] = job_details ['AlgorithmSpecification' ]['TrainingInputMode' ]
276
+ init_params ['base_job_name' ] = job_details ['TrainingJobName' ]
277
+ init_params ['output_path' ] = job_details ['OutputDataConfig' ]['S3OutputPath' ]
278
+ init_params ['output_kms_key' ] = job_details ['OutputDataConfig' ]['KmsKeyId' ]
279
+
280
+ init_params ['hyperparameters' ] = job_details ['HyperParameters' ]
281
+ init_params ['image' ] = job_details ['AlgorithmSpecification' ]['TrainingImage' ]
282
+
283
+ return init_params
275
284
276
285
def delete_endpoint (self ):
277
286
"""Delete an Amazon SageMaker ``Endpoint``.
@@ -388,7 +397,8 @@ class Estimator(EstimatorBase):
388
397
389
398
def __init__ (self , image_name , role , train_instance_count , train_instance_type ,
390
399
train_volume_size = 30 , train_max_run = 24 * 60 * 60 , input_mode = 'File' ,
391
- output_path = None , output_kms_key = None , base_job_name = None , sagemaker_session = None ):
400
+ output_path = None , output_kms_key = None , base_job_name = None , sagemaker_session = None ,
401
+ hyperparameters = None ):
392
402
"""Initialize an ``Estimator`` instance.
393
403
394
404
Args:
@@ -420,9 +430,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
420
430
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
421
431
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
422
432
using the default AWS configuration chain.
433
+ hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with.
423
434
"""
424
435
self .image_name = image_name
425
- self .hyperparam_dict = {}
436
+ self .hyperparam_dict = hyperparameters . copy () if hyperparameters else {}
426
437
super (Estimator , self ).__init__ (role , train_instance_count , train_instance_type ,
427
438
train_volume_size , train_max_run , input_mode ,
428
439
output_path , output_kms_key , base_job_name , sagemaker_session )
@@ -478,23 +489,20 @@ def predict_wrapper(endpoint, session):
478
489
predictor_cls = predictor_cls , ** kwargs )
479
490
480
491
@classmethod
481
- def _from_training_job (cls , init_params , hyperparameters , image , sagemaker_session ):
482
- """Create an Estimator from existing training job data.
492
+ def _prepare_init_params_from_job_description (cls , job_details ):
493
+ """Convert the job description to init params that can be handled by the class constructor
483
494
484
495
Args:
485
- init_params (dict): The init_params the training job was created with.
486
- hyperparameters (dict): The hyperparameters the training job was created with.
487
- image (str): Container image (if any) the training job was created with
488
- sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
496
+ job_details: the returned job details from a describe_training_job API call.
489
497
490
- Returns: An instance of the calling Estimator Class.
498
+ Returns:
499
+ dictionary: The transformed init_params
491
500
492
501
"""
502
+ init_params = super (Estimator , cls )._prepare_init_params_from_job_description (job_details )
493
503
494
- estimator = cls (sagemaker_session = sagemaker_session , ** init_params )
495
- cls .set_hyperparameters (** hyperparameters )
496
-
497
- return estimator
504
+ init_params ['image_name' ] = init_params .pop ('image' )
505
+ return init_params
498
506
499
507
500
508
class Framework (EstimatorBase ):
@@ -602,35 +610,58 @@ def hyperparameters(self):
602
610
return self ._json_encode_hyperparameters (self ._hyperparameters )
603
611
604
612
@classmethod
605
- def _from_training_job (cls , init_params , hyperparameters , image , sagemaker_session ):
606
- """Create an Estimator from existing training job data.
613
+ def _prepare_init_params_from_job_description (cls , job_details ):
614
+ """Convert the job description to init params that can be handled by the class constructor
607
615
608
616
Args:
609
- init_params (dict): The init_params the training job was created with.
610
- hyperparameters (dict): The hyperparameters the training job was created with.
611
- image (str): Container image (if any) the training job was created with
612
- sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
617
+ job_details: the returned job details from a describe_training_job API call.
613
618
614
- Returns: An instance of the calling Estimator Class.
619
+ Returns:
620
+ dictionary: The transformed init_params
615
621
616
622
"""
623
+ init_params = super (Framework , cls )._prepare_init_params_from_job_description (job_details )
617
624
618
- # parameters for framework classes
619
- framework_init_params = dict ()
620
- framework_init_params ['entry_point' ] = json .loads (hyperparameters .get (SCRIPT_PARAM_NAME ))
621
- framework_init_params ['source_dir' ] = json .loads (hyperparameters .get (DIR_PARAM_NAME ))
622
- framework_init_params ['enable_cloudwatch_metrics' ] = json .loads (
623
- hyperparameters .get (CLOUDWATCH_METRICS_PARAM_NAME ))
624
- framework_init_params ['container_log_level' ] = json .loads (
625
- hyperparameters .get (CONTAINER_LOG_LEVEL_PARAM_NAME ))
625
+ init_params ['entry_point' ] = json .loads (init_params ['hyperparameters' ].get (SCRIPT_PARAM_NAME ))
626
+ init_params ['source_dir' ] = json .loads (init_params ['hyperparameters' ].get (DIR_PARAM_NAME ))
627
+ init_params ['enable_cloudwatch_metrics' ] = json .loads (
628
+ init_params ['hyperparameters' ].get (CLOUDWATCH_METRICS_PARAM_NAME ))
629
+ init_params ['container_log_level' ] = json .loads (
630
+ init_params ['hyperparameters' ].get (CONTAINER_LOG_LEVEL_PARAM_NAME ))
626
631
627
- # drop json and remove other SageMaker specific additions
628
- deserialized_hps = {entry : json .loads (hyperparameters [entry ]) for entry in hyperparameters }
629
- framework_init_params ['hyperparameters' ] = deserialized_hps
632
+ init_params ['hyperparameters' ] = {k : json .loads (v ) for k , v in init_params ['hyperparameters' ].items ()}
630
633
631
- init_params . update ( framework_init_params )
634
+ return init_params
632
635
633
- estimator = cls (sagemaker_session = sagemaker_session , ** init_params )
636
+ @classmethod
637
+ def attach (cls , training_job_name , sagemaker_session = None ):
638
+ """Attach to an existing training job.
639
+
640
+ Create an Estimator bound to an existing training job, each subclass is responsible to implement
641
+ ``_prepare_init_params_from_job_description()`` as this method delegates the actual conversion of a training
642
+ job description to the arguments that the class constructor expects. After attaching, if the training job has a
643
+ Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.
644
+
645
+ If the training job is in progress, attach will block and display log messages
646
+ from the training job, until the training job completes.
647
+
648
+ Args:
649
+ training_job_name (str): The name of the training job to attach to.
650
+ sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
651
+ Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
652
+ using the default AWS configuration chain.
653
+
654
+ Examples:
655
+ >>> my_estimator.fit(wait=False)
656
+ >>> training_job_name = my_estimator.latest_training_job.name
657
+ Later on:
658
+ >>> attached_estimator = Estimator.attach(training_job_name)
659
+ >>> attached_estimator.deploy()
660
+
661
+ Returns:
662
+ Instance of the calling ``Estimator`` Class with the attached training job.
663
+ """
664
+ estimator = super (Framework , cls ).attach (training_job_name , sagemaker_session )
634
665
estimator .uploaded_code = UploadedCode (estimator .source_dir , estimator .entry_point )
635
666
return estimator
636
667
0 commit comments