@@ -152,8 +152,47 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
152
152
self .latest_training_job = _TrainingJob .start_new (self , inputs )
153
153
if wait :
154
154
self .latest_training_job .wait (logs = logs )
155
+
156
+
157
+ @classmethod
158
+ def _from_training_job (cls , init_params , hyperparameters , image , sagemaker_session ):
159
+ raise NotImplementedError ()
160
+
161
+ @classmethod
162
+ def attach (cls , training_job_name , sagemaker_session = None , ** kwargs ):
163
+ """Attach to an existing training job.
164
+
165
+ Create an Estimator bound to an existing training job. After attaching, if
166
+ the training job has a Complete status, it can be ``deploy()`` ed to create
167
+ a SageMaker Endpoint and return a ``Predictor``.
168
+
169
+ If the training job is in progress, attach will block and display log messages
170
+ from the training job, until the training job completes.
171
+
172
+ Args:
173
+ training_job_name (str): The name of the training job to attach to.
174
+ sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
175
+ Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
176
+ using the default AWS configuration chain.
177
+ **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor.
178
+
179
+ Returns:
180
+ sagemaker.estimator.Framework: ``Estimator`` with the attached training job.
181
+ """
182
+ sagemaker_session = sagemaker_session or Session ()
183
+
184
+ if training_job_name is not None :
185
+ job_details = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
186
+ init_params , hp , image = cls ._prepare_estimator_params_from_job_description (job_details )
187
+
155
188
else :
156
- raise NotImplemented ('Asynchronous fit not available' )
189
+ raise ValueError ('must specify training_job name' )
190
+
191
+ estimator = cls ._from_training_job (init_params , hp , image , sagemaker_session )
192
+ estimator .latest_training_job = _TrainingJob (sagemaker_session = sagemaker_session ,
193
+ training_job_name = init_params ['base_job_name' ])
194
+ estimator .latest_training_job .wait ()
195
+ return estimator
157
196
158
197
def deploy (self , initial_instance_count , instance_type , endpoint_name = None , ** kwargs ):
159
198
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
@@ -528,56 +567,24 @@ def hyperparameters(self):
528
567
return self ._json_encode_hyperparameters (self ._hyperparameters )
529
568
530
569
@classmethod
531
- def attach (cls , training_job_name , sagemaker_session = None , ** kwargs ):
532
- """Attach to an existing training job.
533
-
534
- Create an Estimator bound to an existing training job. After attaching, if
535
- the training job has a Complete status, it can be ``deploy()`` ed to create
536
- a SageMaker Endpoint and return a ``Predictor``.
537
-
538
- If the training job is in progress, attach will block and display log messages
539
- from the training job, until the training job completes.
540
-
541
- Args:
542
- training_job_name (str): The name of the training job to attach to.
543
- sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
544
- Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
545
- using the default AWS configuration chain.
546
- **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor.
547
-
548
- Returns:
549
- sagemaker.estimator.Framework: ``Estimator`` with the attached training job.
550
- """
551
- sagemaker_session = sagemaker_session or Session ()
552
-
553
- if training_job_name is not None :
554
- job_details = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
555
- init_params , hp , _ = cls ._prepare_estimator_params_from_job_description (job_details )
556
-
557
- else :
558
- # this case is only valid when called from inheriting class and then the class must declare framework
559
- if not hasattr (cls , '__framework_name__' ):
560
- raise ValueError ('must specify training_job name' )
561
- init_params = dict (kwargs )
562
- hp = init_params .pop ('hyperparameters' )
570
+ def _from_training_job (cls , init_params , hyperparameters , image , sagemaker_session ):
563
571
564
572
# parameters for framework classes
565
573
framework_init_params = dict ()
566
- framework_init_params ['entry_point' ] = json .loads (hp .get (SCRIPT_PARAM_NAME ))
567
- framework_init_params ['source_dir' ] = json .loads (hp .get (DIR_PARAM_NAME ))
568
- framework_init_params ['enable_cloudwatch_metrics' ] = json .loads (hp .get (CLOUDWATCH_METRICS_PARAM_NAME ))
569
- framework_init_params ['container_log_level' ] = json .loads (hp .get (CONTAINER_LOG_LEVEL_PARAM_NAME ))
574
+ framework_init_params ['entry_point' ] = json .loads (hyperparameters .get (SCRIPT_PARAM_NAME ))
575
+ framework_init_params ['source_dir' ] = json .loads (hyperparameters .get (DIR_PARAM_NAME ))
576
+ framework_init_params ['enable_cloudwatch_metrics' ] = json .loads (
577
+ hyperparameters .get (CLOUDWATCH_METRICS_PARAM_NAME ))
578
+ framework_init_params ['container_log_level' ] = json .loads (
579
+ hyperparameters .get (CONTAINER_LOG_LEVEL_PARAM_NAME ))
570
580
571
581
# drop json and remove other SageMaker specific additions
572
- hyperparameters = {entry : json .loads (hp [entry ]) for entry in hp }
573
- framework_init_params ['hyperparameters' ] = hyperparameters
582
+ hp_map = {entry : json .loads (hyperparameters [entry ]) for entry in hyperparameters }
583
+ framework_init_params ['hyperparameters' ] = hp_map
574
584
575
585
init_params .update (framework_init_params )
576
586
577
587
estimator = cls (sagemaker_session = sagemaker_session , ** init_params )
578
- estimator .latest_training_job = _TrainingJob (sagemaker_session = sagemaker_session ,
579
- training_job_name = init_params ['base_job_name' ])
580
- estimator .latest_training_job .wait ()
581
588
estimator .uploaded_code = UploadedCode (estimator .source_dir , estimator .entry_point )
582
589
return estimator
583
590
0 commit comments