-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Improved documentation and tests about input and output functions #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
a466bf0
281d873
4fb4bf0
2f5f927
8de091b
96165dd
3af07a5
f022f2b
1da65a9
bf6cc68
1f65d6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1358,11 +1358,11 @@ An example of ``input_fn`` for the content-type "application/python-pickle" can | |
|
||
import numpy as np | ||
|
||
def input_fn(data, content_type): | ||
"""An input_fn that loads a pickled numpy array""" | ||
def input_fn(serialized_input, content_type): | ||
"""An input_fn that loads a pickled object""" | ||
if request_content_type == "application/python-pickle": | ||
array = np.load(StringIO(request_body)) | ||
return array.reshape(model.data_shpaes[0]) | ||
deserialized_input = pickle.loads(serialized_input) | ||
return deserialized_input | ||
else: | ||
# Handle other content-types here or raise an Exception | ||
# if the content type is not supported. | ||
|
@@ -1377,7 +1377,7 @@ An example of ``output_fn`` for the accept type "application/python-pickle" can | |
|
||
import numpy as np | ||
|
||
def output_fn(data, accepts): | ||
def output_fn(prediction_result, accepts): | ||
"""An output_fn that dumps a pickled numpy as response""" | ||
if request_content_type == "application/python-pickle": | ||
return np.dumps(data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did that for consistency, look at the code block below here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant for line 1383 - sorry for the confusion There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
|
@@ -1386,6 +1386,9 @@ An example of ``output_fn`` for the accept type "application/python-pickle" can | |
# if the content type is not supported. | ||
pass | ||
|
||
A example with the ``input_fn`` and ``output_fn`` above can be find in | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. two small changes:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
`here <https://github.com/aws/sagemaker-python-sdk/blob/master/tests/data/cifar_10/source/resnet_cifar_10.py#L143>`_. | ||
|
||
SageMaker TensorFlow Docker containers | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import pickle | ||
|
||
import resnet_model | ||
import tensorflow as tf | ||
|
||
|
@@ -106,21 +108,19 @@ def model_fn(features, labels, mode, params): | |
|
||
|
||
def serving_input_fn(hyperpameters): | ||
feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=(32, 32, 3))} | ||
return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)() | ||
inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])} | ||
return tf.estimator.export.ServingInputReceiver(inputs, inputs) | ||
|
||
|
||
def train_input_fn(training_dir, hyperpameters): | ||
return input_fn(tf.estimator.ModeKeys.TRAIN, | ||
batch_size=BATCH_SIZE, data_dir=training_dir) | ||
def train_input_fn(training_dir, hyperparameters): | ||
return _input_fn(tf.estimator.ModeKeys.TRAIN, batch_size=BATCH_SIZE) | ||
|
||
|
||
def eval_input_fn(training_dir, hyperpameters): | ||
return input_fn(tf.estimator.ModeKeys.EVAL, | ||
batch_size=BATCH_SIZE, data_dir=training_dir) | ||
def eval_input_fn(training_dir, hyperparameters): | ||
return _input_fn(tf.estimator.ModeKeys.EVAL, batch_size=BATCH_SIZE) | ||
|
||
|
||
def input_fn(mode, batch_size, data_dir): | ||
def _input_fn(mode, batch_size): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it'd be better if this method's name were a little more distinct from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point. |
||
input_shape = [batch_size, HEIGHT, WIDTH, DEPTH] | ||
images = tf.truncated_normal( | ||
input_shape, | ||
|
@@ -138,3 +138,12 @@ def input_fn(mode, batch_size, data_dir): | |
labels = tf.contrib.framework.local_variable(labels, name='labels') | ||
|
||
return {INPUT_TENSOR_NAME: images}, labels | ||
|
||
|
||
def input_fn(serialized_data, content_type): | ||
data = pickle.loads(serialized_data) | ||
return data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you could make this a one-line method: |
||
|
||
|
||
def output_fn(data, accepts): | ||
return pickle.dumps(data) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,10 @@ | |
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import pickle | ||
|
||
import boto3 | ||
import numpy as np | ||
import os | ||
import pytest | ||
|
||
|
@@ -19,12 +22,34 @@ | |
from tests.integ import DATA_DIR, REGION | ||
from tests.integ.timeout import timeout_and_delete_endpoint, timeout | ||
|
||
PICKLE_CONTENT_TYPE = "application/python-pickle" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use single quotes instead of double quotes for consistency with the rest of the SDK |
||
|
||
|
||
@pytest.fixture(scope='module') | ||
def sagemaker_session(): | ||
return Session(boto_session=boto3.Session(region_name=REGION)) | ||
|
||
|
||
class PickleSerializer(object): | ||
def __init__(self): | ||
self.content_type = PICKLE_CONTENT_TYPE | ||
|
||
def __call__(self, data): | ||
return pickle.dumps(data) | ||
|
||
|
||
class PickleDeserializer(object): | ||
def __init__(self): | ||
self.accept = PICKLE_CONTENT_TYPE | ||
|
||
def __call__(self, stream, content_type): | ||
try: | ||
data = stream.read().decode() | ||
return pickle.loads(data) | ||
finally: | ||
stream.close() | ||
|
||
|
||
def test_cifar(sagemaker_session): | ||
with timeout(minutes=15): | ||
script_path = os.path.join(DATA_DIR, 'cifar_10', 'source') | ||
|
@@ -42,4 +67,11 @@ def test_cifar(sagemaker_session): | |
print('job succeeded: {}'.format(estimator.latest_training_job.name)) | ||
|
||
with timeout_and_delete_endpoint(estimator=estimator, minutes=20): | ||
estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge') | ||
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge') | ||
predictor.serializer = PickleSerializer() | ||
predictor.deserializer = PickleDeserializer() | ||
|
||
data = np.random.randn(32, 32, 3) | ||
predict_response = predictor.predict(data) | ||
|
||
assert len(predict_response.outputs['probabilities'].float_val) == 10 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change 'numpy' to 'object' for consistency with the above change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure