Skip to content

Commit 79bac68

Browse files
author
Ignacio Quintero
committed
Add support for async fit()
when calling fit(wait=False) it will return immediately. The training job will carry on even if the process exits. by using attach() the estimator can be retrieved by providing the training job name.
1 parent 54b3830 commit 79bac68

File tree

10 files changed

+305
-112
lines changed

10 files changed

+305
-112
lines changed

src/sagemaker/amazon/amazon_estimator.py

+14
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,20 @@ def data_location(self, data_location):
6464
data_location = data_location + '/'
6565
self._data_location = data_location
6666

67+
@classmethod
68+
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
69+
70+
# The hyperparam names may not be the same as the class attribute that holds them,
71+
# for instance: local_lloyd_init_method is called local_init_method. We need to map these
72+
# and pass the correct name to the constructor.
73+
74+
for attribute, value in cls.__dict__.items():
75+
if isinstance(value, hp):
76+
if value.name in hyperparameters:
77+
init_params[attribute] = hyperparameters[value.name]
78+
79+
return cls(sagemaker_session=sagemaker_session, **init_params)
80+
6781
def fit(self, records, mini_batch_size=None, **kwargs):
6882
"""Fit this Estimator on serialized Record objects, stored in S3.
6983

src/sagemaker/estimator.py

+49-42
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,47 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
152152
self.latest_training_job = _TrainingJob.start_new(self, inputs)
153153
if wait:
154154
self.latest_training_job.wait(logs=logs)
155+
156+
157+
@classmethod
158+
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
159+
raise NotImplementedError()
160+
161+
@classmethod
162+
def attach(cls, training_job_name, sagemaker_session=None, **kwargs):
163+
"""Attach to an existing training job.
164+
165+
Create an Estimator bound to an existing training job. After attaching, if
166+
the training job has a Complete status, it can be ``deploy()`` ed to create
167+
a SageMaker Endpoint and return a ``Predictor``.
168+
169+
If the training job is in progress, attach will block and display log messages
170+
from the training job, until the training job completes.
171+
172+
Args:
173+
training_job_name (str): The name of the training job to attach to.
174+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
175+
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
176+
using the default AWS configuration chain.
177+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor.
178+
179+
Returns:
180+
sagemaker.estimator.Framework: ``Estimator`` with the attached training job.
181+
"""
182+
sagemaker_session = sagemaker_session or Session()
183+
184+
if training_job_name is not None:
185+
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
186+
init_params, hp, image = cls._prepare_estimator_params_from_job_description(job_details)
187+
155188
else:
156-
raise NotImplemented('Asynchronous fit not available')
189+
raise ValueError('must specify training_job name')
190+
191+
estimator = cls._from_training_job(init_params, hp, image, sagemaker_session)
192+
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
193+
training_job_name=init_params['base_job_name'])
194+
estimator.latest_training_job.wait()
195+
return estimator
157196

158197
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs):
159198
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
@@ -528,56 +567,24 @@ def hyperparameters(self):
528567
return self._json_encode_hyperparameters(self._hyperparameters)
529568

530569
@classmethod
531-
def attach(cls, training_job_name, sagemaker_session=None, **kwargs):
532-
"""Attach to an existing training job.
533-
534-
Create an Estimator bound to an existing training job. After attaching, if
535-
the training job has a Complete status, it can be ``deploy()`` ed to create
536-
a SageMaker Endpoint and return a ``Predictor``.
537-
538-
If the training job is in progress, attach will block and display log messages
539-
from the training job, until the training job completes.
540-
541-
Args:
542-
training_job_name (str): The name of the training job to attach to.
543-
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
544-
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
545-
using the default AWS configuration chain.
546-
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor.
547-
548-
Returns:
549-
sagemaker.estimator.Framework: ``Estimator`` with the attached training job.
550-
"""
551-
sagemaker_session = sagemaker_session or Session()
552-
553-
if training_job_name is not None:
554-
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
555-
init_params, hp, _ = cls._prepare_estimator_params_from_job_description(job_details)
556-
557-
else:
558-
# this case is only valid when called from inheriting class and then the class must declare framework
559-
if not hasattr(cls, '__framework_name__'):
560-
raise ValueError('must specify training_job name')
561-
init_params = dict(kwargs)
562-
hp = init_params.pop('hyperparameters')
570+
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
563571

564572
# parameters for framework classes
565573
framework_init_params = dict()
566-
framework_init_params['entry_point'] = json.loads(hp.get(SCRIPT_PARAM_NAME))
567-
framework_init_params['source_dir'] = json.loads(hp.get(DIR_PARAM_NAME))
568-
framework_init_params['enable_cloudwatch_metrics'] = json.loads(hp.get(CLOUDWATCH_METRICS_PARAM_NAME))
569-
framework_init_params['container_log_level'] = json.loads(hp.get(CONTAINER_LOG_LEVEL_PARAM_NAME))
574+
framework_init_params['entry_point'] = json.loads(hyperparameters.get(SCRIPT_PARAM_NAME))
575+
framework_init_params['source_dir'] = json.loads(hyperparameters.get(DIR_PARAM_NAME))
576+
framework_init_params['enable_cloudwatch_metrics'] = json.loads(
577+
hyperparameters.get(CLOUDWATCH_METRICS_PARAM_NAME))
578+
framework_init_params['container_log_level'] = json.loads(
579+
hyperparameters.get(CONTAINER_LOG_LEVEL_PARAM_NAME))
570580

571581
# drop json and remove other SageMaker specific additions
572-
hyperparameters = {entry: json.loads(hp[entry]) for entry in hp}
573-
framework_init_params['hyperparameters'] = hyperparameters
582+
hp_map = {entry: json.loads(hyperparameters[entry]) for entry in hyperparameters}
583+
framework_init_params['hyperparameters'] = hp_map
574584

575585
init_params.update(framework_init_params)
576586

577587
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
578-
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
579-
training_job_name=init_params['base_job_name'])
580-
estimator.latest_training_job.wait()
581588
estimator.uploaded_code = UploadedCode(estimator.source_dir, estimator.entry_point)
582589
return estimator
583590

src/sagemaker/mxnet/estimator.py

+3-34
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from sagemaker.estimator import Framework
1515
from sagemaker.fw_utils import create_image_uri, framework_name_from_image
1616
from sagemaker.mxnet.model import MXNetModel
17-
from sagemaker.session import Session
1817

1918

2019
class MXNet(Framework):
@@ -83,42 +82,12 @@ def create_model(self, model_server_workers=None):
8382
sagemaker_session=self.sagemaker_session)
8483

8584
@classmethod
86-
def attach(cls, training_job_name, sagemaker_session=None):
87-
"""Attach to an existing training job.
88-
89-
Create an ``Estimator`` bound to an existing training job. After attaching, if
90-
the training job is in a Complete status, it can be ``deploy``ed to create
91-
a SageMaker ``Endpoint`` and return a ``Predictor``.
92-
93-
If the training job is in progress, attach will block and display log messages
94-
from the training job, until the training job completes.
95-
96-
Args:
97-
training_job_name (str): The name of the training job to attach to.
98-
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
99-
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
100-
using the default AWS configuration chain.
101-
102-
Returns:
103-
sagemaker.mxnet.estimator.MXNet: ``Estimator`` with the attached training job.
104-
105-
Raises:
106-
ValueError: If `training_job_name` is None or the image name does not match the framework.
107-
"""
108-
sagemaker_session = sagemaker_session or Session()
109-
110-
if training_job_name is None:
111-
raise ValueError("must specify training_job name")
112-
113-
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
114-
init_params, hp, image = cls._prepare_estimator_params_from_job_description(job_details)
115-
116-
init_params.update({'hyperparameters': hp})
117-
85+
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
11886
framework, py_version = framework_name_from_image(image)
11987
init_params.update({'py_version': py_version})
12088

89+
training_job_name = init_params['base_job_name']
12190
if framework != cls.__framework_name__:
12291
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
12392

124-
return super(MXNet, cls).attach(training_job_name=None, sagemaker_session=sagemaker_session, **init_params)
93+
return super(MXNet, cls)._from_training_job(init_params, hyperparameters, image, sagemaker_session)

src/sagemaker/tensorflow/estimator.py

+5-35
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import sagemaker.tensorflow
2121
from sagemaker.estimator import Framework
2222
from sagemaker.fw_utils import create_image_uri, framework_name_from_image
23-
from sagemaker.session import Session
2423
from sagemaker.tensorflow.model import TensorFlowModel
2524

2625
logging.basicConfig()
@@ -166,48 +165,19 @@ def fit_super():
166165
fit_super()
167166

168167
@classmethod
169-
def attach(cls, training_job_name, sagemaker_session=None):
170-
"""Attach to an existing training job.
171-
172-
Create an ``Estimator`` bound to an existing training job. After attaching, if
173-
the training job is in a Complete status, it can be ``deploy``ed to create
174-
a SageMaker ``Endpoint`` and return a ``Predictor``.
175-
176-
If the training job is in progress, attach will block and display log messages
177-
from the training job, until the training job completes.
178-
179-
Args:
180-
training_job_name (str): The name of the training job to attach to.
181-
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
182-
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
183-
using the default AWS configuration chain.
184-
185-
Returns:
186-
sagemaker.tensorflow.estimator.TensorFlow: ``Estimator`` with the attached training job.
187-
188-
Raises:
189-
ValueError: If `training_job_name` is None or the image name does not match the framework.
190-
"""
191-
sagemaker_session = sagemaker_session or Session()
192-
193-
if training_job_name is None:
194-
raise ValueError("must specify training_job name")
195-
196-
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
197-
init_params, hp, image = cls._prepare_estimator_params_from_job_description(job_details)
198-
199-
updated_params = cls._update_init_params(hp, ['checkpoint_path', 'training_steps', 'evaluation_steps'])
168+
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
169+
updated_params = cls._update_init_params(hyperparameters,
170+
['checkpoint_path', 'training_steps', 'evaluation_steps'])
200171
init_params.update(updated_params)
201172

202-
init_params.update({'hyperparameters': hp})
203-
204173
framework, py_version = framework_name_from_image(image)
205174
init_params.update({'py_version': py_version})
175+
training_job_name = init_params['base_job_name']
206176

207177
if framework != cls.__framework_name__:
208178
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
209179

210-
return super(TensorFlow, cls).attach(training_job_name=None, sagemaker_session=sagemaker_session, **init_params)
180+
return super(TensorFlow, cls)._from_training_job(init_params, hyperparameters, image, sagemaker_session)
211181

212182
def train_image(self):
213183
"""Return the Docker image to use for training.

tests/integ/test_kmeans.py

+50
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@
1616

1717
import boto3
1818
import os
19+
import time
1920

2021
import sagemaker
2122
from sagemaker import KMeans, KMeansModel
2223
from sagemaker.utils import name_from_base
2324
from tests.integ import DATA_DIR, REGION
2425
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2526

27+
import pytest
2628

29+
30+
@pytest.mark.skip(reason="no way of currently testing this")
2731
def test_kmeans():
2832

2933
with timeout(minutes=15):
@@ -60,3 +64,49 @@ def test_kmeans():
6064
for record in result:
6165
assert record.label["closest_cluster"] is not None
6266
assert record.label["distance_to_cluster"] is not None
67+
68+
69+
def test_async_kmeans():
70+
71+
training_job_name = ""
72+
endpoint_name = name_from_base('kmeans')
73+
74+
with timeout(minutes=15):
75+
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
76+
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
77+
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
78+
79+
# Load the data into memory as numpy arrays
80+
with gzip.open(data_path, 'rb') as f:
81+
train_set, _, _ = pickle.load(f, **pickle_args)
82+
83+
kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
84+
train_instance_type='ml.c4.xlarge',
85+
k=10, sagemaker_session=sagemaker_session, base_job_name='test-kmeans')
86+
87+
kmeans.init_method = 'random'
88+
kmeans.max_iterators = 1
89+
kmeans.tol = 1
90+
kmeans.num_trials = 1
91+
kmeans.local_init_method = 'kmeans++'
92+
kmeans.half_life_time_size = 1
93+
kmeans.epochs = 1
94+
kmeans.center_factor = 1
95+
96+
kmeans.fit(kmeans.record_set(train_set[0][:100]), wait=False)
97+
training_job_name = kmeans.latest_training_job.name
98+
99+
print("Detached from training job. Will re-attach in 20 seconds")
100+
time.sleep(20)
101+
print("attaching now...")
102+
103+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20):
104+
estimator = KMeans.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
105+
model = KMeansModel(estimator.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session)
106+
predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
107+
result = predictor.predict(train_set[0][:10])
108+
109+
assert len(result) == 10
110+
for record in result:
111+
assert record.label["closest_cluster"] is not None
112+
assert record.label["distance_to_cluster"] is not None

0 commit comments

Comments
 (0)