diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9657bb9ac7..625381a942 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,12 @@ CHANGELOG ========= +1.7.0 +===== + +* feature: Transformer: add support for batch transform jobs +* feature: Documentation: add instructions for using Pipe Mode with TensorFlow + 1.6.1 ===== diff --git a/README.rst b/README.rst index e6efb11673..92fb8ea5a0 100644 --- a/README.rst +++ b/README.rst @@ -32,7 +32,8 @@ Table of Contents 7. `AWS SageMaker Estimators <#aws-sagemaker-estimators>`__ 8. `BYO Docker Containers with SageMaker Estimators <#byo-docker-containers-with-sagemaker-estimators>`__ 9. `SageMaker Automatic Model Tuning <#sagemaker-automatic-model-tuning>`__ -10. `BYO Model <#byo-model>`__ +10. `SageMaker Batch Transform <#sagemaker-batch-transform>`__ +11. `BYO Model <#byo-model>`__ Getting SageMaker Python SDK @@ -50,7 +51,7 @@ You can install from source by cloning this repository and issuing a pip install git clone https://github.com/aws/sagemaker-python-sdk.git python setup.py sdist - pip install dist/sagemaker-1.6.1.tar.gz + pip install dist/sagemaker-1.7.0.tar.gz Supported Python versions ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -375,6 +376,39 @@ For more detailed explanations of the classes that this library provides for aut - `API docs for analytics classes `__ +SageMaker Batch Transform +------------------------- + +Once you have a trained model, you can use Amazon SageMaker Batch Transform to perform inferences with the model. +Batch Transform manages all necessary compute resources, including launching instances and deleting them after the job completes. +You can read more about SageMaker Batch Transform in the `AWS documentation `__. + +If you trained the model using a SageMaker Python SDK Estimator, you can invoke ``transformer()`` to create a ``Transformer`` for the training job: + +.. code:: python + + transformer = estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge') + +Alternatively, if you already have a SageMaker Model, you can instantiate a ``Transformer`` directly with its constructor: + +.. code:: python + + transformer = Transformer(model_name='my-previously-trained-model', + instance_count=1, + instance_type='ml.m4.xlarge') + +For a full list of the possible options to configure through either of these methods, please refer to the API docs for `Estimator `__ or `Transformer `__. + +Once you've created a ``Transformer`` object, you can invoke ``transform()`` to begin a batch transform job with the S3 location of your data. +You can also specify other attributes of your data, such as the content type. + +.. code:: python + + transformer.transform('s3://my-bucket/batch-transform-input') + +For more details about what can be specified here, please refer to the `API docs `__. + + FAQ --- @@ -422,7 +456,7 @@ Example code using the TensorFlow predictor: BYO Model ------------------------------------------------ +--------- You can also create an endpoint from an existing model rather than training one - i.e. bring your own model. First, package the files for the trained model into a ``.tar.gz`` file, and upload the archive to S3.
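To make the new README section above concrete, here is a minimal end-to-end sketch of the batch transform flow this change introduces. The model name, bucket, and S3 paths are placeholders, and the ``content_type``/``split_type`` values are only illustrative; ``Transformer``, ``transform()``, and ``wait()`` are the pieces added in this change.

.. code:: python

    from sagemaker.transformer import Transformer

    # Placeholder model name and S3 locations; instance settings are examples.
    transformer = Transformer(model_name='my-previously-trained-model',
                              instance_count=1,
                              instance_type='ml.m4.xlarge',
                              output_path='s3://my-bucket/batch-transform-output')

    # The input here is assumed to be newline-delimited CSV, so we declare the
    # MIME type and split records on 'Line'.
    transformer.transform('s3://my-bucket/batch-transform-input',
                          content_type='text/csv',
                          split_type='Line')

    # Block until the job reaches a terminal state; raises ValueError on failure.
    transformer.wait()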
diff --git a/doc/index.rst b/doc/index.rst index 3793797a34..d8f04696bd 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -16,6 +16,7 @@ The SageMaker Python SDK consists of a few primary interfaces: estimators tuner predictors + transformer session model analytics diff --git a/doc/transformer.rst b/doc/transformer.rst new file mode 100644 index 0000000000..1c49ac9945 --- /dev/null +++ b/doc/transformer.rst @@ -0,0 +1,7 @@ +Transformer +----------- + +.. autoclass:: sagemaker.transformer.Transformer + :members: + :undoc-members: + :show-inheritance: diff --git a/setup.py b/setup.py index 0d21678731..492113d272 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ def read(fname): setup(name="sagemaker", - version="1.6.1", + version="1.7.0", description="Open source library for training and deploying models on Amazon SageMaker.", packages=find_packages('src'), package_dir={'': 'src'}, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index fba27ed661..182c00796b 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -30,7 +30,8 @@ from sagemaker.predictor import RealTimePredictor from sagemaker.session import Session from sagemaker.session import s3_input -from sagemaker.utils import base_name_from_image, name_from_base, get_config_value +from sagemaker.transformer import Transformer +from sagemaker.utils import base_name_from_image, name_from_base, name_from_image, get_config_value class EstimatorBase(with_metaclass(ABCMeta, object)): @@ -253,8 +254,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kw sagemaker.predictor.RealTimePredictor: A predictor that provides a ``predict()`` method, which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences. """ - if not self.latest_training_job: - raise RuntimeError('Estimator has not been fit yet.') + self._ensure_latest_training_job() endpoint_name = endpoint_name or self.latest_training_job.name self.deploy_instance_type = instance_type return self.create_model(**kwargs).deploy( @@ -314,10 +314,43 @@ def delete_endpoint(self): Raises: ValueError: If the endpoint does not exist. """ - if self.latest_training_job is None: - raise ValueError('Endpoint was not created yet') + self._ensure_latest_training_job(error_message='Endpoint was not created yet') self.sagemaker_session.delete_endpoint(self.latest_training_job.name) + def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, + output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, + max_payload=None, tags=None): + """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the + SageMaker Session and base job name used by the Estimator. + + Args: + instance_count (int): Number of EC2 instances to use. + instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'. + strategy (str): The strategy used to decide how to batch records in a single request (default: None). + Valid values: 'MultiRecord' and 'SingleRecord'. + assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'. + output_path (str): S3 location for saving the transform result. If not specified, results are stored to + a default bucket. + output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None). + accept (str): The MIME type used to specify the output data of the transform job (default: None).
+ env (dict): Environment variables to be set for use during the transform job (default: None). + max_concurrent_transforms (int): The maximum number of HTTP requests to be made to + each individual transform container at one time. + max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. + tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for + the training job are used for the transform job. + """ + self._ensure_latest_training_job() + + model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name) + tags = tags or self.tags + + return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, + output_path=output_path, output_kms_key=output_kms_key, accept=accept, + max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, + env=env, tags=tags, base_transform_job_name=self.base_job_name, + sagemaker_session=self.sagemaker_session) + @property def training_job_analytics(self): """Return a ``TrainingJobAnalytics`` object for the current training job. @@ -326,6 +359,10 @@ def training_job_analytics(self): raise ValueError('Estimator is not associated with a TrainingJob') return TrainingJobAnalytics(self._current_job_name, sagemaker_session=self.sagemaker_session) + def _ensure_latest_training_job(self, error_message='Estimator is not associated with a training job'): + if self.latest_training_job is None: + raise ValueError(error_message) + class _TrainingJob(_Job): def __init__(self, sagemaker_session, training_job_name): @@ -698,6 +735,50 @@ def _update_init_params(cls, hp, tf_arguments): updated_params[argument] = value return updated_params + def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, + output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, + max_payload=None, tags=None, model_server_workers=None): + """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the + SageMaker Session and base job name used by the Estimator. + + Args: + instance_count (int): Number of EC2 instances to use. + instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'. + strategy (str): The strategy used to decide how to batch records in a single request (default: None). + Valid values: 'MultiRecord' and 'SingleRecord'. + assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'. + output_path (str): S3 location for saving the transform result. If not specified, results are stored to + a default bucket. + output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None). + accept (str): The MIME type used to specify the output data of the transform job (default: None). + env (dict): Environment variables to be set for use during the transform job (default: None). + max_concurrent_transforms (int): The maximum number of HTTP requests to be made to + each individual transform container at one time. + max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. + tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for + the training job are used for the transform job. + model_server_workers (int): Optional. The number of worker processes used by the inference server. + If None, the server will use one worker per vCPU.
+ """ + self._ensure_latest_training_job() + + model = self.create_model(model_server_workers=model_server_workers) + + container_def = model.prepare_container_def(instance_type) + model_name = model.name or name_from_image(container_def['Image']) + self.sagemaker_session.create_model(model_name, self.role, container_def) + + transform_env = model.env.copy() + if env is not None: + transform_env.update(env) + + tags = tags or self.tags + return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, + output_path=output_path, output_kms_key=output_kms_key, accept=accept, + max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, + env=transform_env, tags=tags, base_transform_job_name=self.base_job_name, + sagemaker_session=self.sagemaker_session) + def _s3_uri_prefix(channel_name, s3_data): if isinstance(s3_data, s3_input): diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 0787646aa6..92e2314c4c 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -104,17 +104,14 @@ def _format_string_uri_input(input): elif input.startswith('file://'): return file_input(input) else: - raise ValueError( - 'Training input data must be a valid S3 or FILE URI: must start with "s3://" or ' - '"file://"') + raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or ' + '"file://"') elif isinstance(input, s3_input): return input elif isinstance(input, file_input): return input else: - raise ValueError( - 'Cannot format input {}. Expecting one of str, s3_input, or file_input'.format( - input)) + raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(input)) @staticmethod def _format_record_set_list_input(inputs): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index fc01dfc5b2..d08c4fc307 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -371,6 +371,52 @@ def stop_tuning_job(self, name): LOGGER.error('Error occurred while attempting to stop tuning job: {}. Please try again.'.format(name)) raise + def transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, env, + input_config, output_config, resource_config, tags): + """Create an Amazon SageMaker transform job. + + Args: + job_name (str): Name of the transform job being created. + model_name (str): Name of the SageMaker model being used for the transform job. + strategy (str): The strategy used to decide how to batch records in a single request. + Possible values are 'MULTI_RECORD' and 'SINGLE_RECORD'. + max_concurrent_transforms (int): The maximum number of HTTP requests to be made to + each individual transform container at one time. + max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. + env (dict): Environment variables to be set for use during the transform job. + input_config (dict): A dictionary describing the input data (and its location) for the job. + output_config (dict): A dictionary describing the output location for the job. + resource_config (dict): A dictionary describing the resources to complete the job. + tags (list[dict]): List of tags for labeling a training job. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. 
+ """ + transform_request = { + 'TransformJobName': job_name, + 'ModelName': model_name, + 'TransformInput': input_config, + 'TransformOutput': output_config, + 'TransformResources': resource_config, + } + + if strategy is not None: + transform_request['BatchStrategy'] = strategy + + if max_concurrent_transforms is not None: + transform_request['MaxConcurrentTransforms'] = max_concurrent_transforms + + if max_payload is not None: + transform_request['MaxPayloadInMB'] = max_payload + + if env is not None: + transform_request['Environment'] = env + + if tags is not None: + transform_request['Tags'] = tags + + LOGGER.info('Creating transform job with name: {}'.format(job_name)) + LOGGER.debug('Transform request: {}'.format(json.dumps(transform_request, indent=4))) + self.sagemaker_client.create_transform_job(**transform_request) + def create_model(self, name, role, primary_container): """Create an Amazon SageMaker ``Model``. @@ -522,6 +568,23 @@ def wait_for_tuning_job(self, job, poll=5): self._check_job_status(job, desc, 'HyperParameterTuningJobStatus') return desc + def wait_for_transform_job(self, job, poll=5): + """Wait for an Amazon SageMaker transform job to complete. + + Args: + job (str): Name of the transform job to wait for. + poll (int): Polling interval in seconds (default: 5). + + Returns: + (dict): Return value from the ``DescribeTransformJob`` API. + + Raises: + ValueError: If the transform job fails. + """ + desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll) + self._check_job_status(job, desc, 'TransformJobStatus') + return desc + def _check_job_status(self, job, desc, status_key_name): """Check to see if the job completed successfully and, if not, construct and raise a ValueError. @@ -898,7 +961,7 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, compression (str): Valid values: 'Gzip', None (default: None). This is used only in Pipe input mode. content_type (str): MIME type of the input data (default: None). record_wrapping (str): Valid values: 'RecordIO' (default: None). - s3_data_type (str): Value values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines + s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing each s3 object to train on. 
The Manifest file format is described in the SageMaker API documentation: @@ -982,6 +1045,29 @@ def _tuning_job_status(sagemaker_client, job_name): return desc +def _transform_job_status(sagemaker_client, job_name): + transform_job_status_codes = { + 'Completed': '!', + 'InProgress': '.', + 'Failed': '*', + 'Stopped': 's', + 'Stopping': '_' + } + in_progress_statuses = ['InProgress', 'Stopping'] + + desc = sagemaker_client.describe_transform_job(TransformJobName=job_name) + status = desc['TransformJobStatus'] + + print(transform_job_status_codes.get(status, '?'), end='') + sys.stdout.flush() + + if status in in_progress_statuses: + return None + + print('') + return desc + + def _deploy_done(sagemaker_client, endpoint_name): hosting_status_codes = { "OutOfService": "x", diff --git a/src/sagemaker/tensorflow/README.rst b/src/sagemaker/tensorflow/README.rst index 3c6154838e..051a05ff8f 100644 --- a/src/sagemaker/tensorflow/README.rst +++ b/src/sagemaker/tensorflow/README.rst @@ -760,6 +760,73 @@ An example of ``output_fn`` for the accept type "application/python-pickle" can An example with ``input_fn`` and ``output_fn`` above can be found `here `_. +Training with Pipe Mode using PipeModeDataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Amazon SageMaker allows users to create training jobs using Pipe input mode. +With Pipe input mode, your dataset is streamed directly to your training instances instead of being downloaded first. +This means that your training jobs start sooner, finish quicker, and need less disk space. + +SageMaker TensorFlow provides an implementation of ``tf.data.Dataset`` that makes it easy to take advantage of Pipe +input mode in SageMaker. You can replace your ``tf.data.Dataset`` with a ``sagemaker_tensorflow.PipeModeDataset`` to +read TFRecords as they are streamed to your training instances. + +In your ``entry_point`` script, you can use ``PipeModeDataset`` like a ``Dataset``. In this example, we create a +``PipeModeDataset`` to read TFRecords from the 'training' channel: + + +.. code:: python + + from sagemaker_tensorflow import PipeModeDataset + + features = { + 'data': tf.FixedLenFeature([], tf.string), + 'labels': tf.FixedLenFeature([], tf.int64), + } + + def parse(record): + parsed = tf.parse_single_example(record, features) + return ({ + 'data': tf.decode_raw(parsed['data'], tf.float64) + }, parsed['labels']) + + ds = PipeModeDataset(channel='training', record_format='TFRecord') + num_epochs = 20 + ds = ds.repeat(num_epochs) + ds = ds.prefetch(10) + ds = ds.map(parse, num_parallel_calls=10) + ds = ds.batch(64) + + +To run a training job with Pipe input mode, pass in ``input_mode='Pipe'`` to your TensorFlow Estimator: + + +.. code:: python + + from sagemaker.tensorflow import TensorFlow + + tf_estimator = TensorFlow(entry_point='tf-train-with-pipemodedataset.py', role='SageMakerRole', + training_steps=10000, evaluation_steps=100, + train_instance_count=1, train_instance_type='ml.p2.xlarge', + input_mode='Pipe') + + tf_estimator.fit('s3://bucket/path/to/training/data') + + +If your TFRecords are compressed, you can train on Gzipped TFRecords by passing in ``compression='Gzip'`` to the call to +``fit()``, and SageMaker will automatically unzip the records as data is streamed to your training instances: + ..
code:: python + + tf_estimator.fit('s3://bucket/path/to/training/data', compression='Gzip') + + +You can learn more about ``PipeModeDataset`` in the sagemaker-tensorflow-extensions repository: https://github.com/aws/sagemaker-tensorflow-extensions + + + SageMaker TensorFlow Docker containers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py new file mode 100644 index 0000000000..2d707eaa54 --- /dev/null +++ b/src/sagemaker/transformer.py @@ -0,0 +1,245 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.job import _Job +from sagemaker.session import Session +from sagemaker.utils import base_name_from_image, name_from_base + + +class Transformer(object): + """A class for creating and interacting with Amazon SageMaker transform jobs. + """ + + def __init__(self, model_name, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, + output_kms_key=None, accept=None, max_concurrent_transforms=None, max_payload=None, tags=None, + env=None, base_transform_job_name=None, sagemaker_session=None): + """Initialize a ``Transformer``. + + Args: + model_name (str): Name of the SageMaker model being used for the transform job. + instance_count (int): Number of EC2 instances to use. + instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'. + strategy (str): The strategy used to decide how to batch records in a single request (default: None). + Valid values: 'MultiRecord' and 'SingleRecord'. + assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'. + output_path (str): S3 location for saving the transform result. If not specified, results are stored to + a default bucket. + output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None). + accept (str): The MIME type used to specify the output data of the transform job (default: None). + max_concurrent_transforms (int): The maximum number of HTTP requests to be made to + each individual transform container at one time. + max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. + env (dict): Environment variables to be set for use during the transform job (default: None). + tags (list[dict]): List of tags for labeling a transform job (default: None). For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + base_transform_job_name (str): Prefix for the transform job name when the + :meth:`~sagemaker.transformer.Transformer.transform` method launches a job. If not specified, a default prefix + will be generated based on the training image name that was used to train the model associated with + the transform job. + sagemaker_session (sagemaker.session.Session): Session object which manages interactions with + Amazon SageMaker APIs and any other AWS services needed.
If not specified, one is created + using the default AWS configuration chain. + """ + self.model_name = model_name + self.strategy = strategy + self.env = env + + self.output_path = output_path + self.output_kms_key = output_kms_key + self.accept = accept + self.assemble_with = assemble_with + + self.instance_count = instance_count + self.instance_type = instance_type + + self.max_concurrent_transforms = max_concurrent_transforms + self.max_payload = max_payload + self.tags = tags + + self.base_transform_job_name = base_transform_job_name + self._current_job_name = None + self.latest_transform_job = None + + self.sagemaker_session = sagemaker_session or Session() + + def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, + job_name=None): + """Start a new transform job. + + Args: + data (str): Input data location in S3. + data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values: + + * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will be used as + inputs for the transform job. + * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object to use as + an input for the transform job. + + content_type (str): MIME type of the input data (default: None). + compression_type (str): Compression type of the input data, if compressed (default: None). + Valid values: 'Gzip', None. + split_type (str): The record delimiter for the input object (default: 'None'). + Valid values: 'None', 'Line', and 'RecordIO'. + job_name (str): Job name (default: None). If not specified, one will be generated. + """ + if not data.startswith('s3://'): + raise ValueError('Invalid S3 URI: {}'.format(data)) + + if job_name is not None: + self._current_job_name = job_name + else: + base_name = self.base_transform_job_name or base_name_from_image(self._retrieve_image_name()) + self._current_job_name = name_from_base(base_name) + + if self.output_path is None: + self.output_path = 's3://{}/{}'.format(self.sagemaker_session.default_bucket(), self._current_job_name) + + self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type, + split_type) + + def _retrieve_image_name(self): + model_desc = self.sagemaker_session.sagemaker_client.describe_model(ModelName=self.model_name) + return model_desc['PrimaryContainer']['Image'] + + def wait(self): + self._ensure_last_transform_job() + self.latest_transform_job.wait() + + def _ensure_last_transform_job(self): + if self.latest_transform_job is None: + raise ValueError('No transform job available') + + @classmethod + def attach(cls, transform_job_name, sagemaker_session=None): + """Attach an existing transform job to a new Transformer instance. + + Args: + transform_job_name (str): Name for the transform job to be attached. + sagemaker_session (sagemaker.session.Session): Session object which manages interactions with + Amazon SageMaker APIs and any other AWS services needed. If not specified, one will be created + using the default AWS configuration chain. + + Returns: + sagemaker.transformer.Transformer: The Transformer instance with the specified transform job attached.
+ + """ + sagemaker_session = sagemaker_session or Session() + + job_details = sagemaker_session.sagemaker_client.describe_transform_job(TransformJobName=transform_job_name) + init_params = cls._prepare_init_params_from_job_description(job_details) + transformer = cls(sagemaker_session=sagemaker_session, **init_params) + transformer.latest_transform_job = _TransformJob(sagemaker_session=sagemaker_session, + transform_job_name=init_params['base_transform_job_name']) + + return transformer + + @classmethod + def _prepare_init_params_from_job_description(cls, job_details): + """Convert the transform job description to init params that can be handled by the class constructor + + Args: + job_details (dict): the returned job details from a describe_transform_job API call. + + Returns: + dict: The transformed init_params + """ + init_params = dict() + + init_params['model_name'] = job_details['ModelName'] + init_params['instance_count'] = job_details['TransformResources']['InstanceCount'] + init_params['instance_type'] = job_details['TransformResources']['InstanceType'] + init_params['strategy'] = job_details.get('BatchStrategy') + init_params['assemble_with'] = job_details['TransformOutput'].get('AssembleWith') + init_params['output_path'] = job_details['TransformOutput']['S3OutputPath'] + init_params['output_kms_key'] = job_details['TransformOutput'].get('KmsKeyId') + init_params['accept'] = job_details['TransformOutput'].get('Accept') + init_params['max_concurrent_transforms'] = job_details.get('MaxConcurrentTransforms') + init_params['max_payload'] = job_details.get('MaxPayloadInMB') + init_params['base_transform_job_name'] = job_details['TransformJobName'] + + return init_params + + +class _TransformJob(_Job): + def __init__(self, sagemaker_session, transform_job_name): + super(_TransformJob, self).__init__(sagemaker_session, transform_job_name) + + @classmethod + def start_new(cls, transformer, data, data_type, content_type, compression_type, split_type): + config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer) + + transformer.sagemaker_session.transform(job_name=transformer._current_job_name, + model_name=transformer.model_name, strategy=transformer.strategy, + max_concurrent_transforms=transformer.max_concurrent_transforms, + max_payload=transformer.max_payload, env=transformer.env, + input_config=config['input_config'], + output_config=config['output_config'], + resource_config=config['resource_config'], tags=transformer.tags) + + return cls(transformer.sagemaker_session, transformer._current_job_name) + + def wait(self): + self.sagemaker_session.wait_for_transform_job(self.job_name) + + @staticmethod + def _load_config(data, data_type, content_type, compression_type, split_type, transformer): + input_config = _TransformJob._format_inputs_to_input_config(data, data_type, content_type, + compression_type, split_type) + + output_config = _TransformJob._prepare_output_config(transformer.output_path, transformer.output_kms_key, + transformer.assemble_with, transformer.accept) + + resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type) + + return {'input_config': input_config, + 'output_config': output_config, + 'resource_config': resource_config} + + @staticmethod + def _format_inputs_to_input_config(data, data_type, content_type, compression_type, split_type): + config = { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': data_type, + 'S3Uri': data, + } + } + } + + if 
content_type is not None: + config['ContentType'] = content_type + + if compression_type is not None: + config['CompressionType'] = compression_type + + if split_type is not None: + config['SplitType'] = split_type + + return config + + @staticmethod + def _prepare_output_config(s3_path, kms_key_id, assemble_with, accept): + config = super(_TransformJob, _TransformJob)._prepare_output_config(s3_path, kms_key_id) + + if assemble_with is not None: + config['AssembleWith'] = assemble_with + + if accept is not None: + config['Accept'] = accept + + return config + + @staticmethod + def _prepare_resource_config(instance_count, instance_type): + return {'InstanceCount': instance_count, 'InstanceType': instance_type} diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index b39df5f29d..8831d35b76 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -21,9 +21,10 @@ from mock import Mock, patch from sagemaker.estimator import Estimator, Framework, _TrainingJob -from sagemaker.session import s3_input from sagemaker.model import FrameworkModel from sagemaker.predictor import RealTimePredictor +from sagemaker.session import s3_input +from sagemaker.transformer import Transformer MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -40,6 +41,8 @@ IMAGE_NAME = 'fakeimage' REGION = 'us-west-2' JOB_NAME = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) +TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] +OUTPUT_PATH = 's3://bucket/prefix' COMMON_TRAIN_ARGS = { 'volume_size': 30, @@ -64,6 +67,18 @@ } } +MODEL_CONTAINER_DEF = { + 'Environment': { + 'SAGEMAKER_PROGRAM': ENTRY_POINT, + 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz', + 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', + 'SAGEMAKER_REGION': REGION, + 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false' + }, + 'Image': MODEL_IMAGE, + 'ModelDataUrl': MODEL_DATA, +} + class DummyFramework(Framework): __framework_name__ = 'dummy' @@ -71,7 +86,7 @@ class DummyFramework(Framework): def train_image(self): return IMAGE_NAME - def create_model(self): + def create_model(self, model_server_workers=None): return DummyFrameworkModel(self.sagemaker_session) @classmethod @@ -89,6 +104,9 @@ def __init__(self, sagemaker_session, **kwargs): def create_predictor(self, endpoint_name): return None + def prepare_container_def(self, instance_type): + return MODEL_CONTAINER_DEF + @pytest.fixture() def sagemaker_session(): @@ -364,10 +382,9 @@ def test_attach_framework_with_tuning(sagemaker_session): @patch('time.strftime', return_value=TIMESTAMP) def test_fit_verify_job_name(strftime, sagemaker_session): - tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=True, tags=tags) + enable_cloudwatch_metrics=True, tags=TAGS) fw.fit(inputs=s3_input('s3://mybucket/train')) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] @@ -375,7 +392,7 @@ def test_fit_verify_job_name(strftime, sagemaker_session): assert train_kwargs['hyperparameters']['sagemaker_enable_cloudwatch_metrics'] assert train_kwargs['image'] == IMAGE_NAME assert train_kwargs['input_mode'] == 'File' - assert train_kwargs['tags'] == tags + assert train_kwargs['tags'] == TAGS assert train_kwargs['job_name'] == JOB_NAME assert fw.latest_training_job.name == JOB_NAME @@ -431,6 +448,128 @@ def test_init_with_source_dir_s3(strftime, 
sagemaker_session): assert fw._hyperparameters == expected_hyperparameters +@patch('sagemaker.estimator.name_from_image', return_value=MODEL_IMAGE) +def test_framework_transformer_creation(name_from_image, sagemaker_session): + fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session) + fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) + + transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE) + + name_from_image.assert_called_with(MODEL_IMAGE) + sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF) + + assert isinstance(transformer, Transformer) + assert transformer.sagemaker_session == sagemaker_session + assert transformer.instance_count == INSTANCE_COUNT + assert transformer.instance_type == INSTANCE_TYPE + assert transformer.model_name == MODEL_IMAGE + assert transformer.tags is None + assert transformer.env == {} + + +@patch('sagemaker.estimator.name_from_image', return_value=MODEL_IMAGE) +def test_framework_transformer_creation_with_optional_params(name_from_image, sagemaker_session): + base_name = 'foo' + fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session, + base_job_name=base_name) + fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) + + strategy = 'MultiRecord' + assemble_with = 'Line' + kms_key = 'key' + accept = 'text/csv' + max_concurrent_transforms = 1 + max_payload = 6 + env = {'FOO': 'BAR'} + + transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with, + output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS, + max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, + env=env, model_server_workers=1) + + assert transformer.strategy == strategy + assert transformer.assemble_with == assemble_with + assert transformer.output_path == OUTPUT_PATH + assert transformer.output_kms_key == kms_key + assert transformer.accept == accept + assert transformer.max_concurrent_transforms == max_concurrent_transforms + assert transformer.max_payload == max_payload + assert transformer.env == env + assert transformer.base_transform_job_name == base_name + assert transformer.tags == TAGS + + +def test_ensure_latest_training_job(sagemaker_session): + fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session) + fw.latest_training_job = Mock(name='training_job') + + fw._ensure_latest_training_job() + + +def test_ensure_latest_training_job_failure(sagemaker_session): + fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session) + + with pytest.raises(ValueError) as e: + fw._ensure_latest_training_job() + assert 'Estimator is not associated with a training job' in str(e) + + +def test_estimator_transformer_creation(sagemaker_session): + estimator = Estimator(image_name=IMAGE_NAME, role=ROLE, train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session) + estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) + sagemaker_session.create_model_from_job.return_value = JOB_NAME + +
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE) + + sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME) + assert isinstance(transformer, Transformer) + assert transformer.sagemaker_session == sagemaker_session + assert transformer.instance_count == INSTANCE_COUNT + assert transformer.instance_type == INSTANCE_TYPE + assert transformer.model_name == JOB_NAME + assert transformer.tags is None + + +def test_estimator_transformer_creation_with_optional_params(sagemaker_session): + base_name = 'foo' + estimator = Estimator(image_name=IMAGE_NAME, role=ROLE, train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session, + base_job_name=base_name) + estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) + sagemaker_session.create_model_from_job.return_value = JOB_NAME + + strategy = 'MultiRecord' + assemble_with = 'Line' + kms_key = 'key' + accept = 'text/csv' + max_concurrent_transforms = 1 + max_payload = 6 + env = {'FOO': 'BAR'} + + transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with, + output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS, + max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, + env=env) + + assert transformer.strategy == strategy + assert transformer.assemble_with == assemble_with + assert transformer.output_path == OUTPUT_PATH + assert transformer.output_kms_key == kms_key + assert transformer.accept == accept + assert transformer.max_concurrent_transforms == max_concurrent_transforms + assert transformer.max_payload == max_payload + assert transformer.env == env + assert transformer.base_transform_job_name == base_name + assert transformer.tags == TAGS + + # _TrainingJob 'utils' def test_start_new(sagemaker_session): training_job = _TrainingJob(sagemaker_session, JOB_NAME) @@ -438,7 +577,7 @@ def test_start_new(sagemaker_session): inputs = 's3://mybucket/train' estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, - output_path='s3://bucket/prefix', sagemaker_session=sagemaker_session, + output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session, hyperparameters=hyperparameters) started_training_job = training_job.start_new(estimator, inputs) @@ -454,7 +593,7 @@ def test_start_new_not_local_mode_error(sagemaker_session): inputs = 'file://mybucket/train' estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, - output_path='s3://bucket/prefix', sagemaker_session=sagemaker_session) + output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session) with pytest.raises(ValueError) as error: training_job.start_new(estimator, inputs) assert 'File URIs are supported in local mode only. Please use a S3 URI instead.' 
== str(error) @@ -510,7 +649,7 @@ def test_unsupported_type_in_dict(): 'ChannelName': 'train' }], 'input_mode': 'File', - 'output_config': {'S3OutputPath': 's3://bucket/prefix'}, + 'output_config': {'S3OutputPath': OUTPUT_PATH}, 'resource_config': { 'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, @@ -527,7 +666,7 @@ def test_unsupported_type_in_dict(): def test_generic_to_fit_no_hps(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='s3://bucket/prefix', + e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session) e.fit({'train': 's3://bucket/training-prefix'}) @@ -544,7 +683,7 @@ def test_generic_to_fit_with_hps(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='s3://bucket/prefix', + e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session) e.set_hyperparameters(**HYPERPARAMS) @@ -563,7 +702,7 @@ def test_generic_to_deploy(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='s3://bucket/prefix', + e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session) e.set_hyperparameters(**HYPERPARAMS) @@ -621,7 +760,7 @@ def test_generic_training_job_analytics(sagemaker_session): } ) - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='s3://bucket/prefix', + e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session) with pytest.raises(ValueError) as err: # noqa: F841 @@ -661,6 +800,6 @@ def test_local_mode(session_class, local_session_class): @patch('sagemaker.estimator.LocalSession') def test_distributed_gpu_local_mode(LocalSession): with pytest.raises(RuntimeError): - Estimator(IMAGE_NAME, ROLE, 3, 'local_gpu', output_path='s3://bucket/prefix') + Estimator(IMAGE_NAME, ROLE, 3, 'local_gpu', output_path=OUTPUT_PATH) ################################################################################# diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8dec4e0cd1..00a6fd82be 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -22,7 +22,7 @@ from botocore.exceptions import ClientError -from sagemaker.session import _tuning_job_status +from sagemaker.session import _tuning_job_status, _transform_job_status REGION = 'us-west-2' @@ -162,6 +162,7 @@ def test_s3_input_all_arguments(): MAX_SIZE = 30 MAX_TIME = 3 * 60 * 60 JOB_NAME = 'jobname' +TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] DEFAULT_EXPECTED_TRAIN_JOB_ARGS = { 'OutputDataConfig': { @@ -304,18 +305,73 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): 'VolumeSizeInGB': MAX_SIZE} stop_cond = {'MaxRuntimeInSeconds': MAX_TIME} hyperparameters = {'foo': 'bar'} - tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, - hyperparameters=hyperparameters, stop_condition=stop_cond, tags=tags) + hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] assert
actual_train_args['HyperParameters'] == hyperparameters - assert actual_train_args['Tags'] == tags + assert actual_train_args['Tags'] == TAGS + + +def test_transform_pack_to_request(sagemaker_session): + model_name = 'my-model' + + in_config = { + 'CompressionType': 'None', + 'ContentType': 'text/csv', + 'SplitType': 'None', + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': S3_INPUT_URI, + }, + }, + } + + out_config = {'S3OutputPath': S3_OUTPUT} + + resource_config = { + 'InstanceCount': INSTANCE_COUNT, + 'InstanceType': INSTANCE_TYPE, + } + + expected_args = { + 'TransformJobName': JOB_NAME, + 'ModelName': model_name, + 'TransformInput': in_config, + 'TransformOutput': out_config, + 'TransformResources': resource_config, + } + + sagemaker_session.transform(job_name=JOB_NAME, model_name=model_name, strategy=None, max_concurrent_transforms=None, + max_payload=None, env=None, input_config=in_config, output_config=out_config, + resource_config=resource_config, tags=None) + + _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] + assert actual_args == expected_args + + +def test_transform_pack_to_request_with_optional_params(sagemaker_session): + strategy = 'strategy' + max_concurrent_transforms = 1 + max_payload = 0 + env = {'FOO': 'BAR'} + + sagemaker_session.transform(job_name=JOB_NAME, model_name='my-model', strategy=strategy, + max_concurrent_transforms=max_concurrent_transforms, + env=env, max_payload=max_payload, input_config={}, output_config={}, + resource_config={}, tags=TAGS) + + _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] + assert actual_args['BatchStrategy'] == strategy + assert actual_args['MaxConcurrentTransforms'] == max_concurrent_transforms + assert actual_args['MaxPayloadInMB'] == max_payload + assert actual_args['Environment'] == env + assert actual_args['Tags'] == TAGS @patch('sys.stdout', new_callable=io.BytesIO if six.PY2 else io.StringIO) @@ -563,3 +619,40 @@ def test_tune_job_status_none(sagemaker_session): result = _tuning_job_status(sagemaker_session.sagemaker_client, JOB_NAME) assert result is None + + +def test_wait_for_transform_job_completed(sagemaker_session): + transform_job_desc = {'TransformJobStatus': 'Completed'} + sagemaker_session.sagemaker_client.describe_transform_job = Mock( + name='describe_transform_job', return_value=transform_job_desc) + + assert sagemaker_session.wait_for_transform_job(JOB_NAME)['TransformJobStatus'] == 'Completed' + + +def test_wait_for_transform_job_in_progress(sagemaker_session): + transform_job_desc_in_progress = {'TransformJobStatus': 'InProgress'} + transform_job_desc_in_completed = {'TransformJobStatus': 'Completed'} + sagemaker_session.sagemaker_client.describe_transform_job = Mock( + name='describe_transform_job', side_effect=[transform_job_desc_in_progress, + transform_job_desc_in_completed]) + + assert sagemaker_session.wait_for_transform_job(JOB_NAME, 1)['TransformJobStatus'] == 'Completed' + assert 2 == sagemaker_session.sagemaker_client.describe_transform_job.call_count + + +def test_transform_job_status(sagemaker_session): + transform_job_desc = {'TransformJobStatus': 'Completed'} + sagemaker_session.sagemaker_client.describe_transform_job = Mock( + name='describe_transform_job', return_value=transform_job_desc) + + result = _transform_job_status(sagemaker_session.sagemaker_client, JOB_NAME) + assert result['TransformJobStatus'] == 'Completed' + + +def test_transform_job_status_none(sagemaker_session): + transform_job_desc = 
{'TransformJobStatus': 'InProgress'} + sagemaker_session.sagemaker_client.describe_transform_job = Mock( + name='describe_transform_job', return_value=transform_job_desc) + + result = _transform_job_status(sagemaker_session.sagemaker_client, JOB_NAME) + assert result is None diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py new file mode 100644 index 0000000000..006970fd39 --- /dev/null +++ b/tests/unit/test_transformer.py @@ -0,0 +1,303 @@ +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from mock import Mock, patch + +from sagemaker.transformer import Transformer, _TransformJob + +MODEL_NAME = 'model' +IMAGE_NAME = 'image-for-model' +JOB_NAME = 'job' + +INSTANCE_COUNT = 1 +INSTANCE_TYPE = 'ml.m4.xlarge' + +S3_DATA_TYPE = 'S3Prefix' +S3_BUCKET = 'bucket' +DATA = 's3://{}/input-data'.format(S3_BUCKET) +OUTPUT_PATH = 's3://{}/output'.format(S3_BUCKET) + +TIMESTAMP = '2018-07-12' + +INIT_PARAMS = { + 'model_name': MODEL_NAME, + 'instance_count': INSTANCE_COUNT, + 'instance_type': INSTANCE_TYPE, + 'base_transform_job_name': JOB_NAME +} + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name='boto_session') + return Mock(name='sagemaker_session', boto_session=boto_mock) + + +@pytest.fixture() +def transformer(sagemaker_session): + return Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, + output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session) + + +@patch('sagemaker.transformer._TransformJob.start_new') +def test_transform_with_all_params(start_new_job, transformer): + content_type = 'text/csv' + compression = 'Gzip' + split = 'Line' + + transformer.transform(DATA, S3_DATA_TYPE, content_type=content_type, compression_type=compression, split_type=split, + job_name=JOB_NAME) + + assert transformer._current_job_name == JOB_NAME + assert transformer.output_path == OUTPUT_PATH + start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression, split) + + +@patch('sagemaker.transformer.name_from_base') +@patch('sagemaker.transformer._TransformJob.start_new') +def test_transform_with_base_job_name(start_new_job, name_from_base, transformer): + base_name = 'base-job-name' + full_name = '{}-{}'.format(base_name, TIMESTAMP) + + transformer.base_transform_job_name = base_name + name_from_base.return_value = full_name + + transformer.transform(DATA) + name_from_base.assert_called_with(base_name) + assert transformer._current_job_name == full_name + + +@patch('sagemaker.transformer.Transformer._retrieve_image_name', return_value=IMAGE_NAME) +@patch('sagemaker.transformer.name_from_base') +@patch('sagemaker.transformer._TransformJob.start_new') +def test_transform_with_fully_generated_job_name(start_new_job, name_from_base, retrieve_image_name, transformer): + full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) + name_from_base.return_value = full_name + + transformer.transform(DATA) + + retrieve_image_name.assert_called_once() +
name_from_base.assert_called_with(IMAGE_NAME) + assert transformer._current_job_name == full_name + + +@patch('sagemaker.transformer._TransformJob.start_new') +def test_transform_with_generated_output_path(start_new_job, transformer, sagemaker_session): + transformer.output_path = None + sagemaker_session.default_bucket.return_value = S3_BUCKET + + transformer.transform(DATA, job_name=JOB_NAME) + assert transformer.output_path == 's3://{}/{}'.format(S3_BUCKET, JOB_NAME) + + +def test_transform_with_invalid_s3_uri(transformer): + with pytest.raises(ValueError) as e: + transformer.transform('not-an-s3-uri') + + assert 'Invalid S3 URI' in str(e) + + +def test_retrieve_image_name(sagemaker_session, transformer): + sage_mock = Mock(name='sagemaker_client') + sage_mock.describe_model.return_value = {'PrimaryContainer': {'Image': IMAGE_NAME}} + + sagemaker_session.sagemaker_client = sage_mock + + assert transformer._retrieve_image_name() == IMAGE_NAME + + +@patch('sagemaker.transformer.Transformer._ensure_last_transform_job') +def test_wait(ensure_last_transform_job, transformer): + transformer.latest_transform_job = Mock(name='latest_transform_job') + + transformer.wait() + + ensure_last_transform_job.assert_called_once() + transformer.latest_transform_job.wait.assert_called_once() + + +def test_ensure_last_transform_job_exists(transformer, sagemaker_session): + transformer.latest_transform_job = _TransformJob(sagemaker_session, 'some-transform-job') + transformer._ensure_last_transform_job() + + +def test_ensure_last_transform_job_none(transformer): + transformer.latest_transform_job = None + with pytest.raises(ValueError) as e: + transformer._ensure_last_transform_job() + + assert 'No transform job available' in str(e) + + +@patch('sagemaker.transformer.Transformer._prepare_init_params_from_job_description', return_value=INIT_PARAMS) +def test_attach(prepare_init_params, transformer, sagemaker_session): + sagemaker_session.sagemaker_client.describe_transform_job = Mock(name='describe_transform_job') + attached = Transformer.attach(JOB_NAME, sagemaker_session) + + prepare_init_params.assert_called_once() + assert attached.latest_transform_job.job_name == JOB_NAME + assert attached.model_name == MODEL_NAME + assert attached.instance_count == INSTANCE_COUNT + assert attached.instance_type == INSTANCE_TYPE + + +def test_prepare_init_params_from_job_description_missing_keys(transformer): + job_details = { + 'ModelName': MODEL_NAME, + 'TransformResources': { + 'InstanceCount': INSTANCE_COUNT, + 'InstanceType': INSTANCE_TYPE + }, + 'TransformOutput': { + 'S3OutputPath': None + }, + 'TransformJobName': JOB_NAME + } + + init_params = transformer._prepare_init_params_from_job_description(job_details) + + assert init_params['model_name'] == MODEL_NAME + assert init_params['instance_count'] == INSTANCE_COUNT + assert init_params['instance_type'] == INSTANCE_TYPE + + +def test_prepare_init_params_from_job_description_all_keys(transformer): + job_details = { + 'ModelName': MODEL_NAME, + 'TransformResources': { + 'InstanceCount': INSTANCE_COUNT, + 'InstanceType': INSTANCE_TYPE + }, + 'BatchStrategy': None, + 'TransformOutput': { + 'AssembleWith': None, + 'S3OutputPath': None, + 'KmsKeyId': None, + 'Accept': None + }, + 'MaxConcurrentTransforms': None, + 'MaxPayloadInMB': None, + 'TransformJobName': JOB_NAME + } + + init_params = transformer._prepare_init_params_from_job_description(job_details) + + assert init_params['model_name'] == MODEL_NAME + assert init_params['instance_count'] == INSTANCE_COUNT + assert
init_params['instance_type'] == INSTANCE_TYPE + + +# _TransformJob tests + +def test_start_new(transformer, sagemaker_session): + transformer._current_job_name = JOB_NAME + + job = _TransformJob(sagemaker_session, JOB_NAME) + started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None) + + assert started_job.sagemaker_session == sagemaker_session + sagemaker_session.transform.assert_called_once() + + +def test_load_config(transformer): + expected_config = { + 'input_config': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': S3_DATA_TYPE, + 'S3Uri': DATA, + }, + }, + }, + 'output_config': { + 'S3OutputPath': OUTPUT_PATH, + }, + 'resource_config': { + 'InstanceCount': INSTANCE_COUNT, + 'InstanceType': INSTANCE_TYPE, + }, + } + + actual_config = _TransformJob._load_config(DATA, S3_DATA_TYPE, None, None, None, transformer) + assert actual_config == expected_config + + +def test_format_inputs_to_input_config(): + expected_config = { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': S3_DATA_TYPE, + 'S3Uri': DATA, + }, + }, + } + + actual_config = _TransformJob._format_inputs_to_input_config(DATA, S3_DATA_TYPE, None, None, None) + assert actual_config == expected_config + + +def test_format_inputs_to_input_config_with_optional_params(): + compression = 'Gzip' + content_type = 'text/csv' + split = 'Line' + + expected_config = { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': S3_DATA_TYPE, + 'S3Uri': DATA, + }, + }, + 'CompressionType': compression, + 'ContentType': content_type, + 'SplitType': split, + } + + actual_config = _TransformJob._format_inputs_to_input_config(DATA, S3_DATA_TYPE, content_type, compression, split) + assert actual_config == expected_config + + +def test_prepare_output_config(): + config = _TransformJob._prepare_output_config(OUTPUT_PATH, None, None, None) + + assert config == {'S3OutputPath': OUTPUT_PATH} + + +def test_prepare_output_config_with_optional_params(): + kms_key = 'key' + assemble_with = 'Line' + accept = 'text/csv' + + expected_config = { + 'S3OutputPath': OUTPUT_PATH, + 'KmsKeyId': kms_key, + 'AssembleWith': assemble_with, + 'Accept': accept, + } + + actual_config = _TransformJob._prepare_output_config(OUTPUT_PATH, kms_key, assemble_with, accept) + assert actual_config == expected_config + + +def test_prepare_resource_config(): + config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE) + assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE} + + +def test_transform_job_wait(sagemaker_session): + job = _TransformJob(sagemaker_session, JOB_NAME) + job.wait() + + sagemaker_session.wait_for_transform_job.assert_called_once_with(JOB_NAME)
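The unit tests above exercise ``Transformer.attach`` and ``wait`` entirely through mocks. For reference, here is a sketch of the same flow against a live job, assuming default AWS credentials and a placeholder job name; ``attach()``, ``wait()``, and ``Session.wait_for_transform_job`` are all introduced by this change.

.. code:: python

    from sagemaker.session import Session
    from sagemaker.transformer import Transformer

    # 'my-existing-transform-job' is a placeholder for a real transform job name.
    session = Session()  # built from the default AWS configuration chain
    transformer = Transformer.attach('my-existing-transform-job', sagemaker_session=session)

    # wait() polls DescribeTransformJob via Session.wait_for_transform_job and
    # raises ValueError if the job ends in a failed state.
    transformer.wait()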