Skip to content

Improved documentation and tests about input and output functions #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 18, 2017
17 changes: 10 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1365,11 +1365,11 @@ An example of ``input_fn`` for the content-type "application/python-pickle" can

import numpy as np

def input_fn(data, content_type):
"""An input_fn that loads a pickled numpy array"""
def input_fn(serialized_input, content_type):
"""An input_fn that loads a pickled object"""
if content_type == "application/python-pickle":
array = np.load(StringIO(request_body))
return array.reshape(model.data_shpaes[0])
deserialized_input = pickle.loads(serialized_input)
return deserialized_input
else:
# Handle other content-types here or raise an Exception
# if the content type is not supported.
Expand All @@ -1384,15 +1384,18 @@ An example of ``output_fn`` for the accept type "application/python-pickle" can

import numpy as np

def output_fn(data, accepts):
"""An output_fn that dumps a pickled numpy as response"""
def output_fn(prediction_result, accepts):
"""An output_fn that dumps a pickled object as response"""
if accepts == "application/python-pickle":
return np.dumps(data)
return pickle.dumps(prediction_result)
else:
# Handle other content-types here or raise an Exception
# if the content type is not supported.
pass

An example with ``input_fn`` and ``output_fn`` above can be found
`here <https://github.com/aws/sagemaker-python-sdk/blob/master/tests/data/cifar_10/source/resnet_cifar_10.py#L143>`_.

SageMaker TensorFlow Docker containers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
22 changes: 13 additions & 9 deletions tests/data/cifar_10/source/resnet_cifar_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import division
from __future__ import print_function

import pickle

import resnet_model
import tensorflow as tf

Expand Down Expand Up @@ -106,21 +108,19 @@ def model_fn(features, labels, mode, params):


def serving_input_fn(hyperpameters):
feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=(32, 32, 3))}
return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()
inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)


def train_input_fn(training_dir, hyperpameters):
return input_fn(tf.estimator.ModeKeys.TRAIN,
batch_size=BATCH_SIZE, data_dir=training_dir)
def train_input_fn(training_dir, hyperparameters):
    """Return a synthetic training batch for the CIFAR-10 test model.

    ``training_dir`` and ``hyperparameters`` are part of the SageMaker
    input-fn signature but are ignored here — the data is generated,
    not read from disk.
    """
    return _generate_synthetic_data(tf.estimator.ModeKeys.TRAIN, batch_size=BATCH_SIZE)


def eval_input_fn(training_dir, hyperpameters):
return input_fn(tf.estimator.ModeKeys.EVAL,
batch_size=BATCH_SIZE, data_dir=training_dir)
def eval_input_fn(training_dir, hyperparameters):
    """Return a synthetic evaluation batch for the CIFAR-10 test model.

    ``training_dir`` and ``hyperparameters`` are part of the SageMaker
    input-fn signature but are ignored here — the data is generated,
    not read from disk.
    """
    return _generate_synthetic_data(tf.estimator.ModeKeys.EVAL, batch_size=BATCH_SIZE)


def input_fn(mode, batch_size, data_dir):
def _generate_synthetic_data(mode, batch_size):
input_shape = [batch_size, HEIGHT, WIDTH, DEPTH]
images = tf.truncated_normal(
input_shape,
Expand All @@ -138,3 +138,7 @@ def input_fn(mode, batch_size, data_dir):
labels = tf.contrib.framework.local_variable(labels, name='labels')

return {INPUT_TENSOR_NAME: images}, labels


def input_fn(serialized_data, content_type):
    """Deserialize a pickled request body.

    Args:
        serialized_data: pickled bytes sent by the client.
        content_type: declared content type of the request; ignored by
            this implementation.

    Returns:
        The deserialized Python object.
    """
    # NOTE(review): pickle.loads on request data is unsafe for untrusted
    # callers; presumably acceptable here because this is test fixture code.
    deserialized = pickle.loads(serialized_data)
    return deserialized
21 changes: 20 additions & 1 deletion tests/integ/test_tf_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pickle

import boto3
import numpy as np
import os
import pytest

Expand All @@ -19,12 +22,22 @@
from tests.integ import DATA_DIR, REGION
from tests.integ.timeout import timeout_and_delete_endpoint, timeout

PICKLE_CONTENT_TYPE = 'application/python-pickle'


@pytest.fixture(scope='module')
def sagemaker_session():
return Session(boto_session=boto3.Session(region_name=REGION))


class PickleSerializer(object):
    """Serializer that pickles request payloads for the predictor.

    Exposes a ``content_type`` attribute and is callable on the payload,
    matching the serializer interface the predictor expects.
    """

    def __init__(self):
        # Advertised content type for pickled request bodies.
        self.content_type = PICKLE_CONTENT_TYPE

    def __call__(self, data):
        # Protocol 2 keeps the payload readable by Python 2 runtimes.
        return pickle.dumps(data, protocol=2)


def test_cifar(sagemaker_session):
with timeout(minutes=15):
script_path = os.path.join(DATA_DIR, 'cifar_10', 'source')
Expand All @@ -42,4 +55,10 @@ def test_cifar(sagemaker_session):
print('job succeeded: {}'.format(estimator.latest_training_job.name))

with timeout_and_delete_endpoint(estimator=estimator, minutes=20):
estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge')
predictor.serializer = PickleSerializer()
predictor.content_type = PICKLE_CONTENT_TYPE

data = np.random.randn(32, 32, 3)
predict_response = predictor.predict(data)
assert len(predict_response['outputs']['probabilities']['floatVal']) == 10