Skip to content

Add support for async fit() #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 1, 2018
Merged

Add support for async fit() #59

merged 7 commits into from
Feb 1, 2018

Conversation

iquintero
Copy link
Contributor

@iquintero iquintero commented Jan 25, 2018

when calling fit(wait=False) it will return immediately. The training
job will carry on even if the process exits. by using attach() the
estimator can be retrieved by providing the training job name.

This fixes: #4

when calling fit(wait=False) it will return immediately. The training
job will carry on even if the process exits. by using attach() the
estimator can be retrieved by providing the training job name.
@iquintero iquintero requested a review from owen-t January 25, 2018 18:57
@iquintero
Copy link
Contributor Author

Note:

I have done a big refactor of how attach() works. This is the main change here, everything else is mostly tests.

@iquintero iquintero force-pushed the async_fit branch 3 times, most recently from 55abeba to a082319 Compare January 25, 2018 22:14
Also fixed the timeouts for all the async fit integ tests.
Previously we allowed 15 min timeout for training, and 20 min for
hosting.

With async fit the sections are split so we allow 5 min timeout for the
intial fit call and setup. And then 35 min for the attach() + hosting
calls. The total runtime is the same just split  differently for async
tests.
Fix the PCA and factorization machines async fit integration tests
and add an exception when running Tensorboard with async fit.
mvsusp
mvsusp previously requested changes Jan 29, 2018
Copy link
Contributor

@mvsusp mvsusp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update the Readme.md as well.

@@ -64,6 +64,20 @@ def data_location(self, data_location):
data_location = data_location + '/'
self._data_location = data_location

@classmethod
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add Docstrings to this method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do

# and pass the correct name to the constructor.

for attribute, value in cls.__dict__.items():
if isinstance(value, hp):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accessing dict to find a class attribute shows that _from_training_job should not be a class method but an instance method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine here, hyperparameters are implemented as descriptor objects.

a SageMaker Endpoint and return a ``Predictor``.

If the training job is in progress, attach will block and display log messages
from the training job, until the training job completes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some code examples here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add an example.

sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kwargs is not in the method signature

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will get rid of it. It used to be and I forgot to get rid of it here.

"""
sagemaker_session = sagemaker_session or Session()

if training_job_name is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not training_job_name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do


# drop json and remove other SageMaker specific additions
hyperparameters = {entry: json.loads(hp[entry]) for entry in hp}
framework_init_params['hyperparameters'] = hyperparameters
hp_map = {entry: json.loads(hyperparameters[entry]) for entry in hyperparameters}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hp is a map already. Maybe rename to deseriealized_hps or something in this lines?
My following up question is why is this method responsible to deserialize the hps?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will rename it :D

  • this was already done by attach(). I don't think its worth to move this deserialization to its own method.

time.sleep(20)
print("attaching now...")

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=35):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

35 minutes is too long. Can we make a test under 15 minutes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test itself is taking the same time as the synchronous fit one, the difference is this timeout is accounting for fit() + deploy(). If you notice the timeout is consistent with the synchronous tests that we have.

The difference is the call to

with timeout_and_delete_endpoint_by_name() is spending a lot of time on attach() this time is usually spent in fit() in the synchronous tests. In this case the fit() returns right away so it doesn't account for much of the runtime. So the test takes the same amount of time to complete.

time.sleep(20)
print("attaching now...")

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=35):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too long

# and pass the correct name to the constructor.

for attribute, value in cls.__dict__.items():
if isinstance(value, hp):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine here, hyperparameters are implemented as descriptor objects.

@@ -152,8 +152,46 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
self.latest_training_job = _TrainingJob.start_new(self, inputs)
if wait:
self.latest_training_job.wait(logs=logs)

@classmethod
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add documentation here, it's important. You're introducing a protocol that subclasses need to follow, so this should be documented.

Copy link
Contributor

@owen-t owen-t left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like BYO is missing implementation.

In addition - please resolve the specific issues raised by MVS.

else:
raise NotImplemented('Asynchronous fit not available')
raise ValueError('must specify training_job name')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, this isn't necessary, just let the underlying call fail.

@iquintero iquintero force-pushed the async_fit branch 2 times, most recently from ec18b1f to a21d2e6 Compare January 30, 2018 01:23
BYO was missing an implementation of _from_training_job(). This adds
that as well as an integration test to verify that.

Also addressed the PR comments and added information to the README.
@iquintero iquintero force-pushed the async_fit branch 2 times, most recently from 5bedd29 to 8d96cec Compare January 31, 2018 22:12
_prepare_init_params_from_job_description() is now a classmethod instead
of being a static method. Each class is responsible to implement their
specific logic to convert a training job description into arguments that
can be passed to its own __init__()
@iquintero iquintero dismissed mvsusp’s stale review February 1, 2018 18:01

comments already addressed. Got approval from main reviewer Owen too.

@iquintero iquintero merged commit e1d79d5 into aws:master Feb 1, 2018
@iquintero iquintero deleted the async_fit branch February 1, 2018 18:02
jalabort added a commit to hudl/sagemaker-python-sdk that referenced this pull request Mar 1, 2018
* Add data_type to hyperparameters (aws#54)

When we describe a training job the data type of the hyper parameters is
lost because we use a dict[str, str]. This adds a new field to
Hyperparameter so that we can convert the datatypes at runtime.

instead of validating with isinstance(), we cast the hp value to the type it
is meant to be. This enforces a "strongly typed" value. When we
deserialize from the API string responses it becomes easier to deal with
too.

* Add wrapper for LDA. (aws#56)

Update CHANGELOG and bump the version number.

* Add support for async fit() (aws#59)

when calling fit(wait=False) it will return immediately. The training
job will carry on even if the process exits. by using attach() the
estimator can be retrieved by providing the training job name.

_prepare_init_params_from_job_description() is now a classmethod instead
of being a static method. Each class is responsible to implement their
specific logic to convert a training job description into arguments that
can be passed to its own __init__()

* Fix Estimator role expansion (aws#68)

Instead of manually constructing the role ARN, use the IAM boto client
to do it. This properly expands service-roles and regular roles.

* Add FM and LDA to the documentation. (aws#66)

* Fix description of an argument of sagemaker.session.train (aws#69)

* Fix description of an argument of sagemaker.session.train

'input_config' should be an array which has channel objects.

* Add a link to the botocore docs

* Use 'list' instead of 'array' in the description

* Add ntm algorithm with doc, unit tests, integ tests (aws#73)

* JSON serializer: predictor.predict accepts dictionaries (aws#62)

Add support for serializing python dictionaries to json
Add prediction with dictionary in tf iris integ test

* Fixing timeouts for PCA async integration test. (aws#78)

Execute tf_cifar test without logs to eliminate delay to detect that job has finished.

* Fixes in LinearLearner and unit tests addition. (aws#77)

* Print out billable seconds after training completes (aws#30)

* Added: print out billable seconds after training completes

* Fixed: test_session.py to pass unit tests

* Fixed: removed offending tzlocal()

* Use sagemaker_timestamp when creating endpoint names in integration tests. (aws#81)

* Support TensorFlow-1.5.0 and MXNet-1.0.0  (aws#82)

* Update .gitignore to ignore pytest_cache.

* Support TensorFlow-1.5.0 and MXNet-1.0.0

* Update and refactor tests. Add tests for fw_utils.

* Fix typo.

* Update changelog for 1.1.0 (aws#85)
apacker pushed a commit to apacker/sagemaker-python-sdk that referenced this pull request Nov 15, 2018
Update tensorflow_resnet_cifar10_with_tensorboard.ipynb
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Asynchronous fit
3 participants