Skip to content

Commit 7be9781

Browse files
committed
SageMaker Python SDK tensorflow notebooks
1 parent 6084a45 commit 7be9781

26 files changed

+11249
-52
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
import os
3+
import tensorflow as tf
4+
from tensorflow.contrib.keras.python.keras.layers import Dense
5+
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
6+
from tensorflow.python.estimator.export.export_output import PredictOutput
7+
8+
INPUT_TENSOR_NAME = "inputs"
9+
SIGNATURE_NAME = "serving_default"
10+
LEARNING_RATE = 0.001
11+
12+
13+
def model_fn(features, labels, mode, params):
14+
"""Model function for Estimator.
15+
# Logic to do the following:
16+
# 1. Configure the model via Keras functional api
17+
# 2. Define the loss function for training/evaluation using Tensorflow.
18+
# 3. Define the training operation/optimizer using Tensorflow operation/optimizer.
19+
# 4. Generate predictions as Tensorflow tensors.
20+
# 5. Generate necessary evaluation metrics.
21+
# 6. Return predictions/loss/train_op/eval_metric_ops in EstimatorSpec object"""
22+
23+
# 1. Configure the model via Keras functional api
24+
25+
first_hidden_layer = Dense(10, activation='relu', name='first-layer')(features[INPUT_TENSOR_NAME])
26+
second_hidden_layer = Dense(10, activation='relu')(first_hidden_layer)
27+
output_layer = Dense(1, activation='linear')(second_hidden_layer)
28+
29+
predictions = tf.reshape(output_layer, [-1])
30+
31+
# Provide an estimator spec for `ModeKeys.PREDICT`.
32+
if mode == tf.estimator.ModeKeys.PREDICT:
33+
return tf.estimator.EstimatorSpec(
34+
mode=mode,
35+
predictions={"ages": predictions},
36+
export_outputs={SIGNATURE_NAME: PredictOutput({"ages": predictions})})
37+
38+
# 2. Define the loss function for training/evaluation using Tensorflow.
39+
loss = tf.losses.mean_squared_error(labels, predictions)
40+
41+
# 3. Define the training operation/optimizer using Tensorflow operation/optimizer.
42+
train_op = tf.contrib.layers.optimize_loss(
43+
loss=loss,
44+
global_step=tf.contrib.framework.get_global_step(),
45+
learning_rate=params["learning_rate"],
46+
optimizer="SGD")
47+
48+
# 4. Generate predictions as Tensorflow tensors.
49+
predictions_dict = {"ages": predictions}
50+
51+
# 5. Generate necessary evaluation metrics.
52+
# Calculate root mean squared error as additional eval metric
53+
eval_metric_ops = {
54+
"rmse": tf.metrics.root_mean_squared_error(
55+
tf.cast(labels, tf.float32), predictions)
56+
}
57+
58+
# Provide an estimator spec for `ModeKeys.EVAL` and `ModeKeys.TRAIN` modes.
59+
return tf.estimator.EstimatorSpec(
60+
mode=mode,
61+
loss=loss,
62+
train_op=train_op,
63+
eval_metric_ops=eval_metric_ops)
64+
65+
66+
def serving_input_fn(params):
67+
tensor = tf.placeholder(tf.float32, shape=[1, 7])
68+
return build_raw_serving_input_receiver_fn({INPUT_TENSOR_NAME: tensor})()
69+
70+
71+
params = {"learning_rate": LEARNING_RATE}
72+
73+
74+
def train_input_fn(training_dir, params):
75+
return _input_fn(training_dir, 'abalone_train.csv')
76+
77+
78+
def eval_input_fn(training_dir, params):
79+
return _input_fn(training_dir, 'abalone_test.csv')
80+
81+
82+
def _input_fn(training_dir, training_filename):
83+
training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
84+
filename=os.path.join(training_dir, training_filename), target_dtype=np.int, features_dtype=np.float32)
85+
86+
return tf.estimator.inputs.numpy_input_fn(
87+
x={INPUT_TENSOR_NAME: np.array(training_set.data)},
88+
y=np.array(training_set.target),
89+
num_epochs=None,
90+
shuffle=True)()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
0.17,0.105,0.035,0.034,0.012,0.0085,0.005,4
2+
0.465,0.39,0.14,0.5555,0.213,0.1075,0.215,15
3+
0.365,0.28,0.09,0.196,0.0865,0.036,0.0605,7
4+
0.525,0.405,0.16,0.658,0.2655,0.1125,0.225,12
5+
0.585,0.45,0.175,1.1275,0.4925,0.262,0.335,11
6+
0.44,0.34,0.14,0.482,0.186,0.1085,0.16,9
7+
0.555,0.445,0.135,0.836,0.336,0.1625,0.275,13

0 commit comments

Comments
 (0)