-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add support for async fit() #59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
79bac68
ec53250
e595cc5
0e50790
24ae51e
7918d8b
03fa7b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -152,8 +152,60 @@ def fit(self, inputs, wait=True, logs=True, job_name=None): | |
self.latest_training_job = _TrainingJob.start_new(self, inputs) | ||
if wait: | ||
self.latest_training_job.wait(logs=logs) | ||
else: | ||
raise NotImplemented('Asynchronous fit not available') | ||
|
||
@classmethod | ||
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add documentation here, it's important. You're introducing a protocol that subclasses need to follow, so this should be documented. |
||
"""Create an Estimator from existing training job data. | ||
|
||
Args: | ||
init_params (dict): The init_params the training job was created with. | ||
hyperparameters (dict): The hyperparameters the training job was created with. | ||
image (str): Container image (if any) the training job was created with | ||
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator. | ||
|
||
Returns: An instance of the calling Estimator Class. | ||
|
||
""" | ||
raise NotImplementedError() | ||
|
||
@classmethod | ||
def attach(cls, training_job_name, sagemaker_session=None, job_details=None): | ||
"""Attach to an existing training job. | ||
|
||
Create an Estimator bound to an existing training job, each subclass is responsible to implement | ||
``_prepare_init_params_from_job_description()`` as this method delegates the actual conversion of a training | ||
job description to the arguments that the class constructor expects. After attaching, if the training job has a | ||
Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``. | ||
|
||
If the training job is in progress, attach will block and display log messages | ||
from the training job, until the training job completes. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add some code examples here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will add an example. |
||
|
||
Args: | ||
training_job_name (str): The name of the training job to attach to. | ||
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with | ||
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one | ||
using the default AWS configuration chain. | ||
|
||
Examples: | ||
>>> my_estimator.fit(wait=False) | ||
>>> training_job_name = my_estimator.latest_training_job.name | ||
Later on: | ||
>>> attached_estimator = Estimator.attach(training_job_name) | ||
>>> attached_estimator.deploy() | ||
|
||
Returns: | ||
Instance of the calling ``Estimator`` Class with the attached training job. | ||
""" | ||
sagemaker_session = sagemaker_session or Session() | ||
|
||
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name) | ||
init_params = cls._prepare_init_params_from_job_description(job_details) | ||
|
||
estimator = cls(sagemaker_session=sagemaker_session, **init_params) | ||
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session, | ||
training_job_name=init_params['base_job_name']) | ||
estimator.latest_training_job.wait() | ||
return estimator | ||
|
||
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs): | ||
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object. | ||
|
@@ -202,21 +254,33 @@ def create_model(self, **kwargs): | |
""" | ||
pass | ||
|
||
@staticmethod | ||
def _prepare_estimator_params_from_job_description(job_details): | ||
estimator_params = dict() | ||
@classmethod | ||
def _prepare_init_params_from_job_description(cls, job_details): | ||
"""Convert the job description to init params that can be handled by the class constructor | ||
|
||
Args: | ||
job_details: the returned job details from a describe_training_job API call. | ||
|
||
Returns: | ||
dictionary: The transformed init_params | ||
|
||
""" | ||
init_params = dict() | ||
|
||
estimator_params['role'] = job_details['RoleArn'] | ||
estimator_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount'] | ||
estimator_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType'] | ||
estimator_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB'] | ||
estimator_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds'] | ||
estimator_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode'] | ||
estimator_params['base_job_name'] = job_details['TrainingJobName'] | ||
estimator_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath'] | ||
estimator_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId'] | ||
init_params['role'] = job_details['RoleArn'] | ||
init_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount'] | ||
init_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType'] | ||
init_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB'] | ||
init_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds'] | ||
init_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode'] | ||
init_params['base_job_name'] = job_details['TrainingJobName'] | ||
init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath'] | ||
init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId'] | ||
|
||
return estimator_params, job_details['HyperParameters'], job_details['AlgorithmSpecification']['TrainingImage'] | ||
init_params['hyperparameters'] = job_details['HyperParameters'] | ||
init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage'] | ||
|
||
return init_params | ||
|
||
def delete_endpoint(self): | ||
"""Delete an Amazon SageMaker ``Endpoint``. | ||
|
@@ -333,7 +397,8 @@ class Estimator(EstimatorBase): | |
|
||
def __init__(self, image_name, role, train_instance_count, train_instance_type, | ||
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File', | ||
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None): | ||
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, | ||
hyperparameters=None): | ||
"""Initialize an ``Estimator`` instance. | ||
|
||
Args: | ||
|
@@ -365,9 +430,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type, | |
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with | ||
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one | ||
using the default AWS configuration chain. | ||
hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with. | ||
""" | ||
self.image_name = image_name | ||
self.hyperparam_dict = {} | ||
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {} | ||
super(Estimator, self).__init__(role, train_instance_count, train_instance_type, | ||
train_volume_size, train_max_run, input_mode, | ||
output_path, output_kms_key, base_job_name, sagemaker_session) | ||
|
@@ -422,6 +488,22 @@ def predict_wrapper(endpoint, session): | |
return Model(self.model_data, image or self.train_image(), self.role, sagemaker_session=self.sagemaker_session, | ||
predictor_cls=predictor_cls, **kwargs) | ||
|
||
@classmethod | ||
def _prepare_init_params_from_job_description(cls, job_details): | ||
"""Convert the job description to init params that can be handled by the class constructor | ||
|
||
Args: | ||
job_details: the returned job details from a describe_training_job API call. | ||
|
||
Returns: | ||
dictionary: The transformed init_params | ||
|
||
""" | ||
init_params = super(Estimator, cls)._prepare_init_params_from_job_description(job_details) | ||
|
||
init_params['image_name'] = init_params.pop('image') | ||
return init_params | ||
|
||
|
||
class Framework(EstimatorBase): | ||
"""Base class that cannot be instantiated directly. | ||
|
@@ -528,12 +610,37 @@ def hyperparameters(self): | |
return self._json_encode_hyperparameters(self._hyperparameters) | ||
|
||
@classmethod | ||
def attach(cls, training_job_name, sagemaker_session=None, **kwargs): | ||
def _prepare_init_params_from_job_description(cls, job_details): | ||
"""Convert the job description to init params that can be handled by the class constructor | ||
|
||
Args: | ||
job_details: the returned job details from a describe_training_job API call. | ||
|
||
Returns: | ||
dictionary: The transformed init_params | ||
|
||
""" | ||
init_params = super(Framework, cls)._prepare_init_params_from_job_description(job_details) | ||
|
||
init_params['entry_point'] = json.loads(init_params['hyperparameters'].get(SCRIPT_PARAM_NAME)) | ||
init_params['source_dir'] = json.loads(init_params['hyperparameters'].get(DIR_PARAM_NAME)) | ||
init_params['enable_cloudwatch_metrics'] = json.loads( | ||
init_params['hyperparameters'].get(CLOUDWATCH_METRICS_PARAM_NAME)) | ||
init_params['container_log_level'] = json.loads( | ||
init_params['hyperparameters'].get(CONTAINER_LOG_LEVEL_PARAM_NAME)) | ||
|
||
init_params['hyperparameters'] = {k: json.loads(v) for k, v in init_params['hyperparameters'].items()} | ||
|
||
return init_params | ||
|
||
@classmethod | ||
def attach(cls, training_job_name, sagemaker_session=None): | ||
"""Attach to an existing training job. | ||
|
||
Create an Estimator bound to an existing training job. After attaching, if | ||
the training job has a Complete status, it can be ``deploy()`` ed to create | ||
a SageMaker Endpoint and return a ``Predictor``. | ||
Create an Estimator bound to an existing training job, each subclass is responsible to implement | ||
``_prepare_init_params_from_job_description()`` as this method delegates the actual conversion of a training | ||
job description to the arguments that the class constructor expects. After attaching, if the training job has a | ||
Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``. | ||
|
||
If the training job is in progress, attach will block and display log messages | ||
from the training job, until the training job completes. | ||
|
@@ -543,41 +650,18 @@ def attach(cls, training_job_name, sagemaker_session=None, **kwargs): | |
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with | ||
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one | ||
using the default AWS configuration chain. | ||
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor. | ||
|
||
Examples: | ||
>>> my_estimator.fit(wait=False) | ||
>>> training_job_name = my_estimator.latest_training_job.name | ||
Later on: | ||
>>> attached_estimator = Estimator.attach(training_job_name) | ||
>>> attached_estimator.deploy() | ||
|
||
Returns: | ||
sagemaker.estimator.Framework: ``Estimator`` with the attached training job. | ||
Instance of the calling ``Estimator`` Class with the attached training job. | ||
""" | ||
sagemaker_session = sagemaker_session or Session() | ||
|
||
if training_job_name is not None: | ||
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name) | ||
init_params, hp, _ = cls._prepare_estimator_params_from_job_description(job_details) | ||
|
||
else: | ||
# this case is only valid when called from inheriting class and then the class must declare framework | ||
if not hasattr(cls, '__framework_name__'): | ||
raise ValueError('must specify training_job name') | ||
init_params = dict(kwargs) | ||
hp = init_params.pop('hyperparameters') | ||
|
||
# parameters for framework classes | ||
framework_init_params = dict() | ||
framework_init_params['entry_point'] = json.loads(hp.get(SCRIPT_PARAM_NAME)) | ||
framework_init_params['source_dir'] = json.loads(hp.get(DIR_PARAM_NAME)) | ||
framework_init_params['enable_cloudwatch_metrics'] = json.loads(hp.get(CLOUDWATCH_METRICS_PARAM_NAME)) | ||
framework_init_params['container_log_level'] = json.loads(hp.get(CONTAINER_LOG_LEVEL_PARAM_NAME)) | ||
|
||
# drop json and remove other SageMaker specific additions | ||
hyperparameters = {entry: json.loads(hp[entry]) for entry in hp} | ||
framework_init_params['hyperparameters'] = hyperparameters | ||
|
||
init_params.update(framework_init_params) | ||
|
||
estimator = cls(sagemaker_session=sagemaker_session, **init_params) | ||
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session, | ||
training_job_name=init_params['base_job_name']) | ||
estimator.latest_training_job.wait() | ||
estimator = super(Framework, cls).attach(training_job_name, sagemaker_session) | ||
estimator.uploaded_code = UploadedCode(estimator.source_dir, estimator.entry_point) | ||
return estimator | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accessing dict to find a class attribute shows that
_from_training_job
should not be a class method but an instance method.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is fine here, hyperparameters are implemented as descriptor objects.