Skip to content
This repository was archived by the owner on Oct 4, 2024. It is now read-only.

Commit b314b73

Browse files
committed
workflow with sm pipelines: training script file
1 parent 34dabe4 commit b314b73

File tree

1 file changed

+79
-0
lines changed
  • tf-2-workflow-smpipelines/train_model

1 file changed

+79
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import argparse
2+
import numpy as np
3+
import os
4+
import tensorflow as tf
5+
6+
from model_def import get_model
7+
8+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
9+
10+
11+
def parse_args():
12+
13+
parser = argparse.ArgumentParser()
14+
15+
# hyperparameters sent by the client are passed as command-line arguments to the script
16+
parser.add_argument('--epochs', type=int, default=1)
17+
parser.add_argument('--batch_size', type=int, default=64)
18+
parser.add_argument('--learning_rate', type=float, default=0.1)
19+
20+
# data directories
21+
parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
22+
parser.add_argument('--test', type=str, default=os.environ.get('SM_CHANNEL_TEST'))
23+
24+
# model directory
25+
parser.add_argument('--sm-model-dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
26+
27+
return parser.parse_known_args()
28+
29+
30+
def get_train_data(train_dir):
31+
32+
x_train = np.load(os.path.join(train_dir, 'x_train.npy'))
33+
y_train = np.load(os.path.join(train_dir, 'y_train.npy'))
34+
print('x train', x_train.shape,'y train', y_train.shape)
35+
36+
return x_train, y_train
37+
38+
39+
def get_test_data(test_dir):
40+
41+
x_test = np.load(os.path.join(test_dir, 'x_test.npy'))
42+
y_test = np.load(os.path.join(test_dir, 'y_test.npy'))
43+
print('x test', x_test.shape,'y test', y_test.shape)
44+
45+
return x_test, y_test
46+
47+
48+
if __name__ == "__main__":
49+
50+
args, _ = parse_args()
51+
52+
print('Training data location: {}'.format(args.train))
53+
print('Test data location: {}'.format(args.test))
54+
x_train, y_train = get_train_data(args.train)
55+
x_test, y_test = get_test_data(args.test)
56+
57+
device = '/cpu:0'
58+
print(device)
59+
batch_size = args.batch_size
60+
epochs = args.epochs
61+
learning_rate = args.learning_rate
62+
print('batch_size = {}, epochs = {}, learning rate = {}'.format(batch_size, epochs, learning_rate))
63+
64+
with tf.device(device):
65+
66+
model = get_model()
67+
optimizer = tf.keras.optimizers.SGD(learning_rate)
68+
model.compile(optimizer=optimizer, loss='mse')
69+
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
70+
validation_data=(x_test, y_test))
71+
72+
# evaluate on test set
73+
scores = model.evaluate(x_test, y_test, batch_size, verbose=2)
74+
print("\nTest MSE :", scores)
75+
76+
# save model
77+
model.save(args.sm_model_dir + '/1')
78+
79+

0 commit comments

Comments
 (0)