Skip to content

Commit ea0c5f9

Browse files
authored
Add test for BYO estimator using Factorization Machines algorithm as an example. (#50)
1 parent a23028a commit ea0c5f9

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

tests/integ/test_byo_estimator.py

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import gzip
14+
import io
15+
import json
16+
import numpy as np
17+
import os
18+
import pickle
19+
import sys
20+
21+
import boto3
22+
23+
import sagemaker
24+
from sagemaker.estimator import Estimator
25+
from sagemaker.amazon.amazon_estimator import registry
26+
from sagemaker.amazon.common import write_numpy_to_dense_tensor
27+
from sagemaker.utils import name_from_base
28+
from tests.integ import DATA_DIR, REGION
29+
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
30+
31+
32+
def test_byo_estimator():
    """Train and deploy Factorization Machines through the generic ``Estimator``.

    Factorization Machines is used here as the example "bring your own" algorithm.

    Steps:

    1. Prepare the training data: take a standard data set, convert it to the
       format that the algorithm can process and upload it to S3.
    2. Create the Estimator and set the hyperparameters required by the algorithm.
    3. Call fit() with the S3 path of the training data.
    4. Deploy the trained model and call predict against the endpoint. The default
       predictor is updated with a JSON serializer and deserializer.
    """
    image_name = registry(REGION) + "/factorization-machines:1"

    with timeout(minutes=15):
        # One region-pinned boto session is shared by the SageMaker session and
        # the S3 upload, so the data is uploaded with the same region settings
        # that own the default bucket.
        boto_session = boto3.Session(region_name=REGION)
        sagemaker_session = sagemaker.Session(boto_session=boto_session)
        data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
        # On Python 3 the pickle is read with the 'latin1' encoding; Python 2
        # needs no extra arguments.
        pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

        with gzip.open(data_path, 'rb') as f:
            train_set, _, _ = pickle.load(f, **pickle_args)

        # take 100 examples for faster execution
        vectors = np.asarray(train_set[0][:100], dtype='float32')
        # Binary target: 1.0 where the original label equals 0, otherwise 0.0.
        # training labels must be 'float32'
        labels = np.where(np.asarray(train_set[1][:100]) == 0, 1.0, 0.0).astype('float32')

        buf = io.BytesIO()
        write_numpy_to_dense_tensor(buf, vectors, labels)
        buf.seek(0)  # rewind so the upload reads the buffer from the start

        bucket = sagemaker_session.default_bucket()
        prefix = 'test_byo_estimator'
        key = 'recordio-pb-data'
        # S3 keys always use '/', so build the key explicitly instead of with
        # os.path.join, which would insert '\\' on Windows and no longer match
        # the URI handed to fit().
        s3_key = '{}/train/{}'.format(prefix, key)
        boto_session.resource('s3').Bucket(bucket).Object(s3_key).upload_fileobj(buf)
        s3_train_data = 's3://{}/{}'.format(bucket, s3_key)

        estimator = Estimator(image_name=image_name,
                              role='SageMakerRole', train_instance_count=1,
                              train_instance_type='ml.c4.xlarge',
                              sagemaker_session=sagemaker_session, base_job_name='test-byo')

        estimator.set_hyperparameters(num_factors=10,
                                      feature_dim=784,
                                      mini_batch_size=100,
                                      predictor_type='binary_classifier')

        estimator.fit({'train': s3_train_data})

    endpoint_name = name_from_base('byo')

    def fm_serializer(data):
        """Serialize an iterable of numpy rows into the JSON request body.

        Produces ``{"instances": [{"features": [...]}, ...]}`` with one entry
        per row of ``data``.
        """
        return json.dumps({'instances': [{'features': row.tolist()} for row in data]})

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20):
        model = estimator.create_model()
        predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
        # Replace the default predictor's (de)serialization with JSON.
        predictor.serializer = fm_serializer
        predictor.content_type = 'application/json'
        predictor.deserializer = sagemaker.predictor.json_deserializer

        result = predictor.predict(train_set[0][:10])

        # One prediction is expected per submitted row, each carrying a score.
        assert len(result['predictions']) == 10
        for prediction in result['predictions']:
            assert prediction['score'] is not None

0 commit comments

Comments
 (0)