Skip to content

Improved documentation and tests about input and output functions #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 18, 2017
13 changes: 8 additions & 5 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1358,11 +1358,11 @@ An example of ``input_fn`` for the content-type "application/python-pickle" can

import numpy as np

def input_fn(data, content_type):
"""An input_fn that loads a pickled numpy array"""
def input_fn(serialized_input, content_type):
"""An input_fn that loads a pickled object"""
if request_content_type == "application/python-pickle":
array = np.load(StringIO(request_body))
return array.reshape(model.data_shpaes[0])
deserialized_input = pickle.loads(serialized_input)
return deserialized_input
else:
# Handle other content-types here or raise an Exception
# if the content type is not supported.
Expand All @@ -1377,7 +1377,7 @@ An example of ``output_fn`` for the accept type "application/python-pickle" can

import numpy as np

def output_fn(data, accepts):
def output_fn(prediction_result, accepts):
"""An output_fn that dumps a pickled numpy as response"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change 'numpy' to 'object' for consistency with the above change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

if request_content_type == "application/python-pickle":
return np.dumps(data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should data be changed to prediction_result here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did that for consistency, look at the code block below here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant for line 1383 - sorry for the confusion

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Expand All @@ -1386,6 +1386,9 @@ An example of ``output_fn`` for the accept type "application/python-pickle" can
# if the content type is not supported.
pass

An example with the ``input_fn`` and ``output_fn`` above can be found
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

two small changes:

  • remove 'the'
  • s/find in/found

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

`here <https://github.com/aws/sagemaker-python-sdk/blob/master/tests/data/cifar_10/source/resnet_cifar_10.py#L143>`_.

SageMaker TensorFlow Docker containers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
27 changes: 18 additions & 9 deletions tests/data/cifar_10/source/resnet_cifar_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import division
from __future__ import print_function

import pickle

import resnet_model
import tensorflow as tf

Expand Down Expand Up @@ -106,21 +108,19 @@ def model_fn(features, labels, mode, params):


def serving_input_fn(hyperpameters):
feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=(32, 32, 3))}
return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()
inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)


def train_input_fn(training_dir, hyperpameters):
return input_fn(tf.estimator.ModeKeys.TRAIN,
batch_size=BATCH_SIZE, data_dir=training_dir)
def train_input_fn(training_dir, hyperparameters):
return _input_fn(tf.estimator.ModeKeys.TRAIN, batch_size=BATCH_SIZE)


def eval_input_fn(training_dir, hyperpameters):
return input_fn(tf.estimator.ModeKeys.EVAL,
batch_size=BATCH_SIZE, data_dir=training_dir)
def eval_input_fn(training_dir, hyperparameters):
return _input_fn(tf.estimator.ModeKeys.EVAL, batch_size=BATCH_SIZE)


def input_fn(mode, batch_size, data_dir):
def _input_fn(mode, batch_size):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be better if this method's name were a little more distinct from input_fn

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point.

input_shape = [batch_size, HEIGHT, WIDTH, DEPTH]
images = tf.truncated_normal(
input_shape,
Expand All @@ -138,3 +138,12 @@ def input_fn(mode, batch_size, data_dir):
labels = tf.contrib.framework.local_variable(labels, name='labels')

return {INPUT_TENSOR_NAME: images}, labels


def input_fn(serialized_data, content_type):
    """Unpickle the serialized request payload and return the resulting object.

    ``content_type`` is accepted but not inspected.
    """
    # NOTE(review): pickle.loads can execute arbitrary code while loading;
    # confirm that callers of this function are trusted.
    return pickle.loads(serialized_data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could make this a one-line method: return pickle.loads(serialized_data)



def output_fn(data, accepts):
    """Pickle ``data`` and return the serialized bytes.

    ``accepts`` is accepted but not inspected.
    """
    pickled = pickle.dumps(data)
    return pickled
34 changes: 33 additions & 1 deletion tests/integ/test_tf_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pickle

import boto3
import numpy as np
import os
import pytest

Expand All @@ -19,12 +22,34 @@
from tests.integ import DATA_DIR, REGION
from tests.integ.timeout import timeout_and_delete_endpoint, timeout

PICKLE_CONTENT_TYPE = "application/python-pickle"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use single quotes instead of double quotes for consistency with the rest of the SDK



@pytest.fixture(scope='module')
def sagemaker_session():
    """Module-scoped fixture returning a ``Session`` built on a boto3 session for ``REGION``."""
    boto_session = boto3.Session(region_name=REGION)
    return Session(boto_session=boto_session)


class PickleSerializer(object):
    """Callable that pickles request data; assigned to ``predictor.serializer`` in the test."""

    def __init__(self):
        # Content type associated with payloads produced by this serializer.
        self.content_type = PICKLE_CONTENT_TYPE

    def __call__(self, data):
        """Return ``data`` pickled with the default pickle protocol."""
        payload = pickle.dumps(data)
        return payload


class PickleDeserializer(object):
    """Callable that unpickles a response stream; assigned to ``predictor.deserializer`` in the test."""

    def __init__(self):
        # Accept type associated with responses this deserializer handles.
        self.accept = PICKLE_CONTENT_TYPE

    def __call__(self, stream, content_type):
        """Read ``stream`` to the end, unpickle its bytes, and close the stream.

        Args:
            stream: Binary file-like object providing ``read()`` and ``close()``.
            content_type: Content type of the response; not inspected here.

        Returns:
            The object reconstructed from the pickled bytes.
        """
        try:
            # Pickle payloads are binary, so hand the raw bytes to
            # pickle.loads. Decoding them to text first breaks on Python 3:
            # binary protocols are not valid UTF-8 and loads() needs bytes.
            # NOTE(review): pickle.loads can execute arbitrary code; this is
            # only appropriate because the endpoint is the one under test.
            return pickle.loads(stream.read())
        finally:
            # Always release the stream, even if unpickling fails.
            stream.close()


def test_cifar(sagemaker_session):
with timeout(minutes=15):
script_path = os.path.join(DATA_DIR, 'cifar_10', 'source')
Expand All @@ -42,4 +67,11 @@ def test_cifar(sagemaker_session):
print('job succeeded: {}'.format(estimator.latest_training_job.name))

with timeout_and_delete_endpoint(estimator=estimator, minutes=20):
estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge')
predictor.serializer = PickleSerializer()
predictor.deserializer = PickleDeserializer()

data = np.random.randn(32, 32, 3)
predict_response = predictor.predict(data)

assert len(predict_response.outputs['probabilities'].float_val) == 10