Skip to content

Commit 0da5bd2

Browse files
author
Dan Choi
committed
Update copyright and add more links to docs
1 parent 0c4c63e commit 0da5bd2

File tree

7 files changed

+96
-27
lines changed

7 files changed

+96
-27
lines changed
Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,30 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
# For more information on creating a Dockerfile
115
# https://docs.docker.com/compose/gettingstarted/#step-2-create-a-dockerfile
216
FROM tensorflow/tensorflow:1.8.0-py3
317

418
RUN apt-get update && apt-get install -y --no-install-recommends nginx curl
519

620
# Download tensorflow serving
21+
# https://www.tensorflow.org/serving/setup#installing_the_modelserver
722
RUN echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list
823
RUN curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
924
RUN apt-get update && apt-get install tensorflow-model-server
1025

1126
ENV PATH="/opt/ml/code:${PATH}"
1227

28+
# /opt/ml and all of its subdirectories are used by SageMaker; we use the /opt/ml/code subdirectory to store our user code.
1329
COPY /cifar10 /opt/ml/code
14-
WORKDIR /opt/ml/code
30+
WORKDIR /opt/ml/code

advanced_functionality/tensorflow_bring_your_own/container/cifar10/cifar10.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
import argparse
1818
import functools
19-
2019
import os
2120

22-
import resnet_model
2321
import tensorflow as tf
2422

23+
import resnet_model
24+
2525
INPUT_TENSOR_NAME = "inputs"
2626
SIGNATURE_NAME = "serving_default"
2727

@@ -48,7 +48,10 @@
4848

4949

5050
def model_fn(features, labels, mode):
51-
"""Model function for CIFAR-10."""
51+
"""
52+
Model function for CIFAR-10.
53+
For more information: https://www.tensorflow.org/guide/custom_estimators#write_a_model_function
54+
"""
5255
inputs = features[INPUT_TENSOR_NAME]
5356
tf.summary.image('images', inputs, max_outputs=6)
5457

@@ -113,6 +116,10 @@ def model_fn(features, labels, mode):
113116

114117

115118
def serving_input_fn():
119+
"""
120+
Serving input function for CIFAR-10. Specifies the input format the caller of predict() will have to provide.
121+
For more information: https://www.tensorflow.org/guide/saved_model#build_and_load_a_savedmodel
122+
"""
116123
inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
117124
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
118125

@@ -187,23 +194,30 @@ def main(model_dir, data_dir, train_steps):
187194

188195
if __name__ == '__main__':
189196
args_parser = argparse.ArgumentParser()
197+
# For more information:
198+
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
190199
args_parser.add_argument(
191200
'--data-dir',
192201
default='/opt/ml/input/data/training',
193202
type=str,
194-
# required=True,
195-
help='The directory where the CIFAR-10 input data is stored.')
203+
help='The directory where the CIFAR-10 input data is stored. Default: /opt/ml/input/data/training. This '
204+
'directory corresponds to the SageMaker channel named \'training\', which was specified when creating '
205+
'our training job on SageMaker')
206+
207+
# For more information:
208+
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html
196209
args_parser.add_argument(
197210
'--model-dir',
198211
default='/opt/ml/model',
199212
type=str,
200-
# required=True,
201-
help='The directory where the model will be stored.')
213+
help='The directory where the model will be stored. Default: /opt/ml/model. This directory should contain all '
214+
'final model artifacts as Amazon SageMaker copies all data within this directory as a single object in '
215+
'compressed tar format.')
216+
202217
args_parser.add_argument(
203218
'--train-steps',
204219
type=int,
205220
default=100,
206-
# required=True,
207221
help='The number of steps to use for training.')
208222
args = args_parser.parse_args()
209223
main(**vars(args))

advanced_functionality/tensorflow_bring_your_own/container/cifar10/nginx.conf

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ events {
88
http {
99
server {
1010
# configures the server to listen on port 8080
11-
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
11+
# Amazon SageMaker sends inference requests to port 8080.
12+
# For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
1213
listen 8080 deferred;
1314

1415
# redirects requests from SageMaker to TF Serving

advanced_functionality/tensorflow_bring_your_own/container/cifar10/resnet_model.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
66
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
7+
# http://aws.amazon.com/apache2.0/
88
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
# ==============================================================================
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
# ResNet model definitions used by the CIFAR-10 sample (imported by cifar10.py).
15+
# This file only defines the network; it is not a training entry point.
16+
# See the module docstring below for details and references.
1517
"""Contains definitions for the preactivation form of Residual Networks.
1618
1719
Residual networks (ResNets) were originally proposed in:
@@ -30,6 +32,7 @@
3032
from __future__ import absolute_import
3133
from __future__ import division
3234
from __future__ import print_function
35+
3336
import tensorflow as tf
3437

3538
_BATCH_NORM_DECAY = 0.997

advanced_functionality/tensorflow_bring_your_own/container/cifar10/serve

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
#!/usr/bin/env python
22

3-
# This file implements the hosting solution, which just starts TensorFlow Model Serving.
4-
5-
from __future__ import print_function
3+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You
6+
# may not use this file except in compliance with the License. A copy of
7+
# the License is located at
8+
#
9+
# http://aws.amazon.com/apache2.0/
10+
#
11+
# or in the "license" file accompanying this file. This file is
12+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
13+
# ANY KIND, either express or implied. See the License for the specific
14+
# language governing permissions and limitations under the License.
615

16+
# This file implements the hosting solution, which just starts TensorFlow Model Serving.
717
import subprocess
818

919

@@ -14,7 +24,13 @@ def start_server():
1424
subprocess.check_call(['ln', '-sf', '/dev/stdout', '/var/log/nginx/access.log'])
1525
subprocess.check_call(['ln', '-sf', '/dev/stderr', '/var/log/nginx/error.log'])
1626

27+
# start nginx server
1728
nginx = subprocess.Popen(['nginx', '-c', '/opt/ml/code/nginx.conf'])
29+
30+
# start TensorFlow Serving
31+
# https://www.tensorflow.org/serving/api_rest#start_modelserver_with_the_rest_api_endpoint
32+
# SageMaker copies the model artifact from our training job into /opt/ml/model.
33+
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-load-artifacts
1834
tf_model_server = subprocess.call(['tensorflow_model_server',
1935
'--rest_api_port=8501',
2036
'--model_name=cifar10_model',

advanced_functionality/tensorflow_bring_your_own/container/cifar10/train

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
#!/usr/bin/env python
22

3+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You
6+
# may not use this file except in compliance with the License. A copy of
7+
# the License is located at
8+
#
9+
# http://aws.amazon.com/apache2.0/
10+
#
11+
# or in the "license" file accompanying this file. This file is
12+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
13+
# ANY KIND, either express or implied. See the License for the specific
14+
# language governing permissions and limitations under the License.
15+
316
# A sample training entry point that invokes the training algorithm in a subprocess.
417
# Hyperparameters are loaded from the JSON file provided by SageMaker and converted to command-line arguments.
518
# Input data handling is left to the training algorithm (cifar10.py reads it from its --data-dir argument).
@@ -31,6 +44,7 @@ default_params = ['--model-dir', str(model_path)]
3144

3245
# Execute your training algorithm.
3346
def _run(cmd):
47+
"""Invokes your training algorithm."""
3448
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=os.environ)
3549
stdout, stderr = process.communicate()
3650

@@ -41,6 +55,10 @@ def _run(cmd):
4155

4256

4357
def _hyperparameters_to_cmd_args(hyperparameters):
58+
"""
59+
Converts our hyperparameters, in JSON format, into key-value pairs suitable for passing to our training
60+
algorithm.
61+
"""
4462
cmd_args_list = []
4563

4664
for key, value in hyperparameters.items():
@@ -52,6 +70,9 @@ def _hyperparameters_to_cmd_args(hyperparameters):
5270

5371
if __name__ == '__main__':
5472
try:
73+
# Amazon SageMaker makes our specified hyperparameters available within the
74+
# /opt/ml/input/config/hyperparameters.json file.
75+
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container
5576
with open(param_path, 'r') as tc:
5677
training_params = json.load(tc)
5778

advanced_functionality/tensorflow_bring_your_own/utils/generate_cifar10_tfrecords.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,16 @@
1818
python version of the CIFAR-10 dataset downloaded from
1919
https://www.cs.toronto.edu/~kriz/cifar.html.
2020
"""
21-
2221
from __future__ import absolute_import
2322
from __future__ import division
2423
from __future__ import print_function
2524

2625
import argparse
2726
import os
27+
import shutil
2828
import sys
29-
3029
import tarfile
3130

32-
import shutil
3331
from six.moves import cPickle as pickle
3432
from six.moves import xrange # pylint: disable=redefined-builtin
3533
import tensorflow as tf

0 commit comments

Comments
 (0)