
Commit 704cd31

Add support for Batch Transform and update README with TF Pipe Mode (#298)
1 parent 5ea3fd0 commit 704cd31

13 files changed: +1093 -34 lines

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
@@ -2,6 +2,12 @@
 CHANGELOG
 =========

+1.7.0
+=====
+
+* feature: Transformer: add support for batch transform jobs
+* feature: Documentation: add instructions for using Pipe Mode with TensorFlow
+
 1.6.1
 =====

README.rst

Lines changed: 37 additions & 3 deletions
@@ -32,7 +32,8 @@ Table of Contents
 7. `AWS SageMaker Estimators <#aws-sagemaker-estimators>`__
 8. `BYO Docker Containers with SageMaker Estimators <#byo-docker-containers-with-sagemaker-estimators>`__
 9. `SageMaker Automatic Model Tuning <#sagemaker-automatic-model-tuning>`__
-10. `BYO Model <#byo-model>`__
+10. `SageMaker Batch Transform <#sagemaker-batch-transform>`__
+11. `BYO Model <#byo-model>`__


 Getting SageMaker Python SDK
@@ -50,7 +51,7 @@ You can install from source by cloning this repository and issuing a pip install

     git clone https://github.com/aws/sagemaker-python-sdk.git
     python setup.py sdist
-    pip install dist/sagemaker-1.6.1.tar.gz
+    pip install dist/sagemaker-1.7.0.tar.gz

 Supported Python versions
 ~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -375,6 +376,39 @@ For more detailed explanations of the classes that this library provides for aut
 - `API docs for analytics classes <https://sagemaker.readthedocs.io/en/latest/analytics.html>`__


+SageMaker Batch Transform
+-------------------------
+
+Once you have a trained model, you can use Amazon SageMaker Batch Transform to perform inferences with the model.
+Batch Transform manages all compute resources necessary, including launching instances to deploy endpoints and deleting them afterward.
+You can read more about SageMaker Batch Transform in the `AWS documentation <https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-batch.html>`__.
+
+If you have trained the model using a SageMaker Python SDK Estimator, you can simply invoke ``transformer()`` to create a ``Transformer`` for the training job:
+
+.. code:: python
+
+    transformer = estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge')
+
+Alternatively, if you already have a SageMaker Model, you can instantiate a ``Transformer`` directly with its constructor:
+
+.. code:: python
+
+    transformer = Transformer(model_name='my-previously-trained-model',
+                              instance_count=1,
+                              instance_type='ml.m4.xlarge')
+
+For a full list of the possible options to configure through either of these methods, please refer to the API docs for `Estimator <https://sagemaker.readthedocs.io/en/latest/estimators.html#sagemaker.estimator.Estimator.transformer>`__ or `Transformer <https://sagemaker.readthedocs.io/en/latest/transformer.html#sagemaker.transformer.Transformer>`__.
+
+Once you've created a ``Transformer`` object, you can invoke ``transform()`` to begin a batch transform job with the S3 location of your data.
+You can also specify other attributes about your data, such as the content type.
+
+.. code:: python
+
+    transformer.transform('s3://my-bucket/batch-transform-input')
+
+For more details about what can be specified here, please refer to the `API docs <https://sagemaker.readthedocs.io/en/latest/transformer.html#sagemaker.transformer.Transformer.transform>`__.
+
+
 FAQ
 ---

@@ -422,7 +456,7 @@ Example code using the TensorFlow predictor:


 BYO Model
------------------------------------------------
+---------
 You can also create an endpoint from an existing model rather than training one - i.e. bring your own model.

 First, package the files for the trained model into a ``.tar.gz`` file, and upload the archive to S3.
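
Tying the new README snippets together, a minimal end-to-end sketch follows. It is illustrative only: the bucket
names are placeholders, and the ``content_type`` keyword on ``transform()`` is assumed from the README's note that
the content type can be specified, not quoted from this diff.

.. code:: python

    # Assumes `estimator` is a SageMaker Python SDK Estimator whose fit() has already completed.
    transformer = estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge')

    # Start the batch transform job against a placeholder S3 prefix.
    transformer.transform('s3://my-bucket/batch-transform-input', content_type='text/csv')

Results land in the transformer's configured ``output_path``, or in a default bucket if none was given.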

doc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ The SageMaker Python SDK consists of a few primary interfaces:
    estimators
    tuner
    predictors
+   transformer
    session
    model
    analytics

doc/transformer.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+Transformer
+-----------
+
+.. autoclass:: sagemaker.transformer.Transformer
+    :members:
+    :undoc-members:
+    :show-inheritance:

setup.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def read(fname):


 setup(name="sagemaker",
-      version="1.6.1",
+      version="1.7.0",
       description="Open source library for training and deploying models on Amazon SageMaker.",
       packages=find_packages('src'),
       package_dir={'': 'src'},

src/sagemaker/estimator.py

Lines changed: 86 additions & 5 deletions
@@ -30,7 +30,8 @@
 from sagemaker.predictor import RealTimePredictor
 from sagemaker.session import Session
 from sagemaker.session import s3_input
-from sagemaker.utils import base_name_from_image, name_from_base, get_config_value
+from sagemaker.transformer import Transformer
+from sagemaker.utils import base_name_from_image, name_from_base, name_from_image, get_config_value


 class EstimatorBase(with_metaclass(ABCMeta, object)):
@@ -253,8 +254,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kw
             sagemaker.predictor.RealTimePredictor: A predictor that provides a ``predict()`` method,
                 which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences.
         """
-        if not self.latest_training_job:
-            raise RuntimeError('Estimator has not been fit yet.')
+        self._ensure_latest_training_job()
         endpoint_name = endpoint_name or self.latest_training_job.name
         self.deploy_instance_type = instance_type
         return self.create_model(**kwargs).deploy(
@@ -314,10 +314,43 @@ def delete_endpoint(self):
         Raises:
             ValueError: If the endpoint does not exist.
         """
-        if self.latest_training_job is None:
-            raise ValueError('Endpoint was not created yet')
+        self._ensure_latest_training_job(error_message='Endpoint was not created yet')
         self.sagemaker_session.delete_endpoint(self.latest_training_job.name)

+    def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
+                    output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
+                    max_payload=None, tags=None):
+        """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
+        SageMaker Session and base job name used by the Estimator.
+
+        Args:
+            instance_count (int): Number of EC2 instances to use.
+            instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
+            strategy (str): The strategy used to decide how to batch records in a single request (default: None).
+                Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
+            assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
+            output_path (str): S3 location for saving the transform result. If not specified, results are stored to
+                a default bucket.
+            output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
+            accept (str): The content type accepted by the endpoint deployed during the transform job.
+            env (dict): Environment variables to be set for use during the transform job (default: None).
+            max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
+                each individual transform container at one time.
+            max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
+            tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
+                the training job are used for the transform job.
+        """
+        self._ensure_latest_training_job()
+
+        model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name)
+        tags = tags or self.tags
+
+        return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
+                           output_path=output_path, output_kms_key=output_kms_key, accept=accept,
+                           max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
+                           env=env, tags=tags, base_transform_job_name=self.base_job_name,
+                           sagemaker_session=self.sagemaker_session)
+
     @property
     def training_job_analytics(self):
         """Return a ``TrainingJobAnalytics`` object for the current training job.
326359
raise ValueError('Estimator is not associated with a TrainingJob')
327360
return TrainingJobAnalytics(self._current_job_name, sagemaker_session=self.sagemaker_session)
328361

362+
def _ensure_latest_training_job(self, error_message='Estimator is not associated with a training job'):
363+
if self.latest_training_job is None:
364+
raise ValueError(error_message)
365+
329366

330367
class _TrainingJob(_Job):
331368
def __init__(self, sagemaker_session, training_job_name):
@@ -698,6 +735,50 @@ def _update_init_params(cls, hp, tf_arguments):
698735
updated_params[argument] = value
699736
return updated_params
700737

738+
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
739+
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
740+
max_payload=None, tags=None, model_server_workers=None):
741+
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
742+
SageMaker Session and base job name used by the Estimator.
743+
744+
Args:
745+
instance_count (int): Number of EC2 instances to use.
746+
instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
747+
strategy (str): The strategy used to decide how to batch records in a single request (default: None).
748+
Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
749+
assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
750+
output_path (str): S3 location for saving the transform result. If not specified, results are stored to
751+
a default bucket.
752+
output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
753+
accept (str): The content type accepted by the endpoint deployed during the transform job.
754+
env (dict): Environment variables to be set for use during the transform job (default: None).
755+
max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
756+
each individual transform container at one time.
757+
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
758+
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
759+
the training job are used for the transform job.
760+
model_server_workers (int): Optional. The number of worker processes used by the inference server.
761+
If None, server will use one worker per vCPU.
762+
"""
763+
self._ensure_latest_training_job()
764+
765+
model = self.create_model(model_server_workers=model_server_workers)
766+
767+
container_def = model.prepare_container_def(instance_type)
768+
model_name = model.name or name_from_image(container_def['Image'])
769+
self.sagemaker_session.create_model(model_name, self.role, container_def)
770+
771+
transform_env = model.env.copy()
772+
if env is not None:
773+
transform_env.update(env)
774+
775+
tags = tags or self.tags
776+
return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
777+
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
778+
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
779+
env=transform_env, tags=tags, base_transform_job_name=self.base_job_name,
780+
sagemaker_session=self.sagemaker_session)
781+
701782

702783
def _s3_uri_prefix(channel_name, s3_data):
703784
if isinstance(s3_data, s3_input):
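
The framework-level override above additionally accepts ``model_server_workers`` and merges ``env`` into the
model's own environment. A sketch under the same caveats (the fitted ``tf_estimator`` and all values shown are
hypothetical):

.. code:: python

    # `tf_estimator` stands in for a fitted framework estimator (for example, a TensorFlow estimator).
    transformer = tf_estimator.transformer(
        instance_count=2,
        instance_type='ml.c4.xlarge',
        strategy='MULTI_RECORD',
        assemble_with='Line',
        env={'SOME_ENV_VAR': 'value'},   # merged on top of the model's env, per the code above
        model_server_workers=2,          # None would mean one worker per vCPU
    )
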

src/sagemaker/job.py

Lines changed: 3 additions & 6 deletions
@@ -104,17 +104,14 @@ def _format_string_uri_input(input):
             elif input.startswith('file://'):
                 return file_input(input)
             else:
-                raise ValueError(
-                    'Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
-                    '"file://"')
+                raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
+                                 '"file://"')
         elif isinstance(input, s3_input):
             return input
         elif isinstance(input, file_input):
             return input
         else:
-            raise ValueError(
-                'Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(
-                    input))
+            raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(input))

     @staticmethod
     def _format_record_set_list_input(inputs):

src/sagemaker/session.py

Lines changed: 87 additions & 1 deletion
@@ -371,6 +371,52 @@ def stop_tuning_job(self, name):
             LOGGER.error('Error occurred while attempting to stop tuning job: {}. Please try again.'.format(name))
             raise

+    def transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, env,
+                  input_config, output_config, resource_config, tags):
+        """Create an Amazon SageMaker transform job.
+
+        Args:
+            job_name (str): Name of the transform job being created.
+            model_name (str): Name of the SageMaker model being used for the transform job.
+            strategy (str): The strategy used to decide how to batch records in a single request.
+                Possible values are 'MULTI_RECORD' and 'SINGLE_RECORD'.
+            max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
+                each individual transform container at one time.
+            max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
+            env (dict): Environment variables to be set for use during the transform job.
+            input_config (dict): A dictionary describing the input data (and its location) for the job.
+            output_config (dict): A dictionary describing the output location for the job.
+            resource_config (dict): A dictionary describing the resources to complete the job.
+            tags (list[dict]): List of tags for labeling a training job. For more, see
+                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+        """
+        transform_request = {
+            'TransformJobName': job_name,
+            'ModelName': model_name,
+            'TransformInput': input_config,
+            'TransformOutput': output_config,
+            'TransformResources': resource_config,
+        }
+
+        if strategy is not None:
+            transform_request['BatchStrategy'] = strategy
+
+        if max_concurrent_transforms is not None:
+            transform_request['MaxConcurrentTransforms'] = max_concurrent_transforms
+
+        if max_payload is not None:
+            transform_request['MaxPayloadInMB'] = max_payload
+
+        if env is not None:
+            transform_request['Environment'] = env
+
+        if tags is not None:
+            transform_request['Tags'] = tags
+
+        LOGGER.info('Creating transform job with name: {}'.format(job_name))
+        LOGGER.debug('Transform request: {}'.format(json.dumps(transform_request, indent=4)))
+        self.sagemaker_client.create_transform_job(**transform_request)
+
     def create_model(self, name, role, primary_container):
         """Create an Amazon SageMaker ``Model``.

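
For a sense of what the new ``Session.transform()`` expects, here is a minimal sketch. The request dictionaries
mirror the ``CreateTransformJob`` API shapes; the job, model, and bucket names are placeholders, not values from
this commit.

.. code:: python

    from sagemaker.session import Session

    session = Session()
    session.transform(
        job_name='my-transform-job',                   # placeholder job name
        model_name='my-previously-trained-model',      # an existing SageMaker Model
        strategy='MULTI_RECORD',
        max_concurrent_transforms=4,
        max_payload=6,
        env=None,
        input_config={
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://my-bucket/batch-transform-input',
                }
            },
            'ContentType': 'text/csv',
        },
        output_config={'S3OutputPath': 's3://my-bucket/batch-transform-output'},
        resource_config={'InstanceType': 'ml.m4.xlarge', 'InstanceCount': 1},
        tags=None,
    )
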
@@ -522,6 +568,23 @@ def wait_for_tuning_job(self, job, poll=5):
         self._check_job_status(job, desc, 'HyperParameterTuningJobStatus')
         return desc

+    def wait_for_transform_job(self, job, poll=5):
+        """Wait for an Amazon SageMaker transform job to complete.
+
+        Args:
+            job (str): Name of the transform job to wait for.
+            poll (int): Polling interval in seconds (default: 5).
+
+        Returns:
+            (dict): Return value from the ``DescribeTransformJob`` API.
+
+        Raises:
+            ValueError: If the transform job fails.
+        """
+        desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll)
+        self._check_job_status(job, desc, 'TransformJobStatus')
+        return desc
+
     def _check_job_status(self, job, desc, status_key_name):
         """Check to see if the job completed successfully and, if not, construct and
         raise a ValueError.
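
Continuing the sketch above, the new ``wait_for_transform_job()`` blocks until the job finishes and raises
``ValueError`` if it fails (the job name remains a placeholder):

.. code:: python

    desc = session.wait_for_transform_job('my-transform-job', poll=10)
    print(desc['TransformJobStatus'])   # e.g. 'Completed'
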
@@ -898,7 +961,7 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
             compression (str): Valid values: 'Gzip', None (default: None). This is used only in Pipe input mode.
             content_type (str): MIME type of the input data (default: None).
             record_wrapping (str): Valid values: 'RecordIO' (default: None).
-            s3_data_type (str): Value values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines
+            s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines
                 a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will
                 be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing
                 each s3 object to train on. The Manifest file format is described in the SageMaker API documentation:
@@ -982,6 +1045,29 @@ def _tuning_job_status(sagemaker_client, job_name):
     return desc


+def _transform_job_status(sagemaker_client, job_name):
+    transform_job_status_codes = {
+        'Completed': '!',
+        'InProgress': '.',
+        'Failed': '*',
+        'Stopped': 's',
+        'Stopping': '_'
+    }
+    in_progress_statuses = ['InProgress', 'Stopping']
+
+    desc = sagemaker_client.describe_transform_job(TransformJobName=job_name)
+    status = desc['TransformJobStatus']
+
+    print(transform_job_status_codes.get(status, '?'), end='')
+    sys.stdout.flush()
+
+    if status in in_progress_statuses:
+        return None
+
+    print('')
+    return desc
+
+
 def _deploy_done(sagemaker_client, endpoint_name):
     hosting_status_codes = {
         "OutOfService": "x",
