Skip to content

Commit ec18b1f

Browse files
author
Ignacio Quintero
committed
Add BYOA implementation and missing docs.
BYO was missing an implementation of _from_training_job(). This adds that as well as an integration test to verify that. Also addressed the PR comments and added information to the README.
1 parent 0e50790 commit ec18b1f

File tree

6 files changed

+175
-18
lines changed

6 files changed

+175
-18
lines changed

README.rst

+22-3
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ SageMaker Python SDK provides several high-level abstractions for working with A
9797
- **Estimators**: Encapsulate training on SageMaker. Can be ``fit()`` to run training, then the resulting model ``deploy()`` ed to a SageMaker Endpoint.
9898
- **Models**: Encapsulate built ML models. Can be ``deploy()`` ed to a SageMaker Endpoint.
9999
- **Predictors**: Provide real-time inference and transformation using Python data-types against a SageMaker Endpoint.
100-
- **Session**: Provides a collection of convience methods for working with SageMaker resources.
100+
- **Session**: Provides a collection of convenience methods for working with SageMaker resources.
101101

102102
Estimator and Model implementations for MXNet, TensorFlow, and Amazon ML algorithms are included. There's also an Estimator that runs SageMaker compatible custom Docker containers, allowing you to run your own ML algorithms via SageMaker Python SDK.
103103

@@ -1149,7 +1149,8 @@ Optional arguments
11491149
''''''''''''''''''
11501150
11511151
- ``wait (bool)``: Defaults to True, whether to block and wait for the
1152-
training script to complete before returning.
1152+
training script to complete before returning. If set to False, the call returns immediately and the training job can later be
1153+
``attach`` ed to.
11531154
- ``logs (bool)``: Defaults to True, whether to show logs produced by training
11541155
job in the Python session. Only meaningful when wait is True.
11551156
- ``run_tensorboard_locally (bool)``: Defaults to False. Executes TensorBoard in a different
@@ -1178,9 +1179,25 @@ the ``TensorFlow`` estimator parameter ``training_steps`` is finished or when th
11781179
job execution time reaches the ``TensorFlow`` estimator parameter ``train_max_run``.
11791180
11801181
When the training job finishes, a `TensorFlow serving <https://www.tensorflow.org/serving/serving_basic>`_
1181-
with the result of the training is generated and saved to the S3 location define by
1182+
with the result of the training is generated and saved to the S3 location defined by
11821183
the ``TensorFlow`` estimator parameter ``output_path``.
11831184
1185+
1186+
If the ``wait=False`` flag is passed to ``fit``, then it will return immediately. The training job will continue running
1187+
asynchronously. At a later time, a TensorFlow Estimator can be obtained by attaching to the existing training job. If
1188+
the training job is not finished it will start showing the standard output of training and wait until it completes.
1189+
After attaching, the estimator can be deployed as usual.
1190+
1191+
.. code:: python
1192+
1193+
tf_estimator.fit(your_input_data, wait=False)
1194+
training_job_name = tf_estimator.latest_training_job.name
1195+
1196+
# after some time, or in a separate python notebook, we can attach to it again.
1197+
1198+
tf_estimator = TensorFlow.attach(training_job_name=training_job_name)
1199+
1200+
11841201
The evaluation process
11851202
""""""""""""""""""""""
11861203
@@ -1244,6 +1261,8 @@ You can access TensorBoard locally at http://localhost:6006 or using your SakeMa
12441261
`https*workspace_base_url*proxy/6006/ <proxy/6006/>`_ (TensorBoard will not work if you forget to put the slash,
12451262
'/', in end of the url). If TensorBoard started on a different port, adjust these URLs to match.
12461263
1264+
Note that TensorBoard is not supported when passing ``wait=False`` to ``fit``.
1265+
12471266
12481267
Deploying TensorFlow Serving models
12491268
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

src/sagemaker/amazon/amazon_estimator.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,21 @@ def data_location(self, data_location):
6666

6767
@classmethod
6868
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
69+
"""Create an Estimator from existing training job data.
70+
71+
Args:
72+
init_params (dict): The init_params the training job was created with.
73+
hyperparameters (dict): The hyperparameters the training job was created with.
74+
image (str): Container image (if any) the training job was created with
75+
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
76+
77+
Returns: An instance of the calling Estimator Class.
78+
79+
"""
6980

7081
# The hyperparam names may not be the same as the class attribute that holds them,
7182
# for instance: local_lloyd_init_method is called local_init_method. We need to map these
7283
# and pass the correct name to the constructor.
73-
7484
for attribute, value in cls.__dict__.items():
7585
if isinstance(value, hp):
7686
if value.name in hyperparameters:

src/sagemaker/estimator.py

+55-8
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,26 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
155155

156156
@classmethod
157157
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
158+
"""Create an Estimator from existing training job data.
159+
160+
Args:
161+
init_params (dict): The init_params the training job was created with.
162+
hyperparameters (dict): The hyperparameters the training job was created with.
163+
image (str): Container image (if any) the training job was created with
164+
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
165+
166+
Returns: An instance of the calling Estimator Class.
167+
168+
"""
158169
raise NotImplementedError()
159170

160171
@classmethod
161172
def attach(cls, training_job_name, sagemaker_session=None):
162173
"""Attach to an existing training job.
163174
164-
Create an Estimator bound to an existing training job. After attaching, if
165-
the training job has a Complete status, it can be ``deploy()`` ed to create
175+
Create an Estimator bound to an existing training job. Each subclass is responsible for implementing
176+
``_from_training_job()``, as this method delegates the actual Estimator creation to it. After
177+
attaching, if the training job has a Complete status, it can be ``deploy()`` ed to create
166178
a SageMaker Endpoint and return a ``Predictor``.
167179
168180
If the training job is in progress, attach will block and display log messages
@@ -173,17 +185,22 @@ def attach(cls, training_job_name, sagemaker_session=None):
173185
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
174186
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
175187
using the default AWS configuration chain.
176-
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Estimator` constructor.
188+
189+
Examples:
190+
>>> my_estimator.fit(wait=False)
191+
>>> training_job_name = my_estimator.latest_training_job.name
192+
Later on:
193+
>>> attached_estimator = Estimator.attach(training_job_name)
194+
>>> attached_estimator.deploy()
177195
178196
Returns:
179-
sagemaker.estimator.Framework: ``Estimator`` with the attached training job.
197+
Instance of the calling ``Estimator`` Class with the attached training job.
180198
"""
181199
sagemaker_session = sagemaker_session or Session()
182200

183-
if training_job_name is not None:
201+
if training_job_name:
184202
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
185203
init_params, hp, image = cls._prepare_estimator_params_from_job_description(job_details)
186-
187204
else:
188205
raise ValueError('must specify training_job name')
189206

@@ -460,6 +477,25 @@ def predict_wrapper(endpoint, session):
460477
return Model(self.model_data, image or self.train_image(), self.role, sagemaker_session=self.sagemaker_session,
461478
predictor_cls=predictor_cls, **kwargs)
462479

480+
@classmethod
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
    """Create an Estimator from existing training job data.

    Args:
        init_params (dict): The init_params the training job was created with.
        hyperparameters (dict): The hyperparameters the training job was created with.
        image (str): Container image (if any) the training job was created with.
            Not used by this implementation.
        sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.

    Returns: An instance of the calling Estimator Class.

    """
    estimator = cls(sagemaker_session=sagemaker_session, **init_params)
    # set_hyperparameters() is an instance method: it must be invoked on the
    # estimator that was just built. Calling it on ``cls`` raises TypeError
    # (no ``self``) and would never configure the returned estimator.
    estimator.set_hyperparameters(**hyperparameters)
    return estimator
498+
463499

464500
class Framework(EstimatorBase):
465501
"""Base class that cannot be instantiated directly.
@@ -567,6 +603,17 @@ def hyperparameters(self):
567603

568604
@classmethod
569605
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
606+
"""Create an Estimator from existing training job data.
607+
608+
Args:
609+
init_params (dict): The init_params the training job was created with.
610+
hyperparameters (dict): The hyperparameters the training job was created with.
611+
image (str): Container image (if any) the training job was created with
612+
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
613+
614+
Returns: An instance of the calling Estimator Class.
615+
616+
"""
570617

571618
# parameters for framework classes
572619
framework_init_params = dict()
@@ -578,8 +625,8 @@ def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_sessi
578625
hyperparameters.get(CONTAINER_LOG_LEVEL_PARAM_NAME))
579626

580627
# drop json and remove other SageMaker specific additions
581-
hp_map = {entry: json.loads(hyperparameters[entry]) for entry in hyperparameters}
582-
framework_init_params['hyperparameters'] = hp_map
628+
deserialized_hps = {entry: json.loads(hyperparameters[entry]) for entry in hyperparameters}
629+
framework_init_params['hyperparameters'] = deserialized_hps
583630

584631
init_params.update(framework_init_params)
585632

src/sagemaker/mxnet/estimator.py

+12
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,18 @@ def create_model(self, model_server_workers=None):
8383

8484
@classmethod
8585
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
86+
"""Create an Estimator from existing training job data.
87+
88+
Args:
89+
init_params (dict): The init_params the training job was created with.
90+
hyperparameters (dict): The hyperparameters the training job was created with.
91+
image (str): Container image (if any) the training job was created with
92+
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
93+
94+
Returns: An instance of the calling Estimator Class.
95+
96+
"""
97+
8698
framework, py_version = framework_name_from_image(image)
8799
init_params.update({'py_version': py_version})
88100

src/sagemaker/tensorflow/estimator.py

+12
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,18 @@ def fit_super():
169169

170170
@classmethod
171171
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
172+
"""Create an Estimator from existing training job data.
173+
174+
Args:
175+
init_params (dict): The init_params the training job was created with.
176+
hyperparameters (dict): The hyperparameters the training job was created with.
177+
image (str): Container image (if any) the training job was created with
178+
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
179+
180+
Returns: An instance of the calling Estimator Class.
181+
182+
"""
183+
172184
updated_params = cls._update_init_params(hyperparameters,
173185
['checkpoint_path', 'training_steps', 'evaluation_steps'])
174186
init_params.update(updated_params)

tests/integ/test_byo_estimator.py

+63-6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
3030

3131

32+
def fm_serializer(data):
    """Serialize rows of feature vectors into a JSON request body.

    Args:
        data: Iterable of rows; each row must provide ``tolist()``
            (for example a numpy array).

    Returns:
        str: JSON text of the form
        ``{"instances": [{"features": [...]}, ...]}``, one entry per row.
    """
    return json.dumps({'instances': [{'features': row.tolist()} for row in data]})
37+
38+
3239
def test_byo_estimator():
3340
"""Use Factorization Machines algorithm as an example here.
3441
@@ -79,12 +86,6 @@ def test_byo_estimator():
7986

8087
endpoint_name = name_from_base('byo')
8188

82-
def fm_serializer(data):
83-
js = {'instances': []}
84-
for row in data:
85-
js['instances'].append({'features': row.tolist()})
86-
return json.dumps(js)
87-
8889
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20):
8990
model = estimator.create_model()
9091
predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
@@ -97,3 +98,59 @@ def fm_serializer(data):
9798
assert len(result['predictions']) == 10
9899
for prediction in result['predictions']:
99100
assert prediction['score'] is not None
101+
102+
103+
def test_async_byo_estimator():
    """Train with ``wait=False``, then attach to the job by name and deploy it.

    Uses the Factorization Machines container image as the bring-your-own
    algorithm. The first phase uploads a small training set and starts the
    training job without blocking; the second phase re-creates the estimator
    with ``Estimator.attach`` and checks that the deployed endpoint returns
    one scored prediction per input row.
    """
    image_name = registry(REGION) + "/factorization-machines:1"
    endpoint_name = name_from_base('byo')
    training_job_name = ""

    with timeout(minutes=5):
        sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
        data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
        pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

        with gzip.open(data_path, 'rb') as f:
            train_set, _, _ = pickle.load(f, **pickle_args)

        # take 100 examples for faster execution
        vectors = np.array([t.tolist() for t in train_set[0][:100]]).astype('float32')
        # training labels must be 'float32'
        labels = np.where(np.array([t.tolist() for t in train_set[1][:100]]) == 0, 1.0, 0.0).astype('float32')

        buf = io.BytesIO()
        write_numpy_to_dense_tensor(buf, vectors, labels)
        buf.seek(0)

        bucket = sagemaker_session.default_bucket()
        prefix = 'test_byo_estimator'
        key = 'recordio-pb-data'
        # S3 keys always use '/' as the separator. Build the key once (not with
        # os.path.join, which is platform-dependent) so the uploaded object and
        # the training URI can never disagree.
        train_key = '{}/train/{}'.format(prefix, key)
        boto3.resource('s3').Bucket(bucket).Object(train_key).upload_fileobj(buf)
        s3_train_data = 's3://{}/{}'.format(bucket, train_key)

        estimator = Estimator(image_name=image_name,
                              role='SageMakerRole', train_instance_count=1,
                              train_instance_type='ml.c4.xlarge',
                              sagemaker_session=sagemaker_session, base_job_name='test-byo')

        estimator.set_hyperparameters(num_factors=10,
                                      feature_dim=784,
                                      mini_batch_size=100,
                                      predictor_type='binary_classifier')

        # start training without blocking; only the job name is kept for later
        estimator.fit({'train': s3_train_data}, wait=False)
        training_job_name = estimator.latest_training_job.name

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=30):
        # rebuild the estimator purely from the training job name
        estimator = Estimator.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
        model = estimator.create_model()
        predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
        predictor.serializer = fm_serializer
        predictor.content_type = 'application/json'
        predictor.deserializer = sagemaker.predictor.json_deserializer

        result = predictor.predict(train_set[0][:10])

        assert len(result['predictions']) == 10
        for prediction in result['predictions']:
            assert prediction['score'] is not None

0 commit comments

Comments
 (0)