|
| 1 | +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +import gzip |
| 14 | +import io |
| 15 | +import json |
| 16 | +import numpy as np |
| 17 | +import os |
| 18 | +import pickle |
| 19 | +import sys |
| 20 | + |
| 21 | +import boto3 |
| 22 | + |
| 23 | +import sagemaker |
| 24 | +from sagemaker.estimator import Estimator |
| 25 | +from sagemaker.amazon.amazon_estimator import registry |
| 26 | +from sagemaker.amazon.common import write_numpy_to_dense_tensor |
| 27 | +from sagemaker.utils import name_from_base |
| 28 | +from tests.integ import DATA_DIR, REGION |
| 29 | +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name |
| 30 | + |
| 31 | + |
| 32 | +def test_byo_estimator(): |
| 33 | + """Use Factorization Machines algorithm as an example here. |
| 34 | +
|
| 35 | + First we need to prepare data for training. We take standard data set, convert it to the |
| 36 | + format that the algorithm can process and upload it to S3. |
| 37 | + Then we create the Estimator and set hyperparamets as required by the algorithm. |
| 38 | + Next, we can call fit() with path to the S3. |
| 39 | + Later the trained model is deployed and prediction is called against the endpoint. |
| 40 | + Default predictor is updated with json serializer and deserializer. |
| 41 | +
|
| 42 | + """ |
| 43 | + image_name = registry(REGION) + "/factorization-machines:1" |
| 44 | + |
| 45 | + with timeout(minutes=15): |
| 46 | + sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION)) |
| 47 | + data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') |
| 48 | + pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} |
| 49 | + |
| 50 | + with gzip.open(data_path, 'rb') as f: |
| 51 | + train_set, _, _ = pickle.load(f, **pickle_args) |
| 52 | + |
| 53 | + # take 100 examples for faster execution |
| 54 | + vectors = np.array([t.tolist() for t in train_set[0][:100]]).astype('float32') |
| 55 | + labels = np.where(np.array([t.tolist() for t in train_set[1][:100]]) == 0, 1.0, 0.0).astype('float32') |
| 56 | + |
| 57 | + buf = io.BytesIO() |
| 58 | + write_numpy_to_dense_tensor(buf, vectors, labels) |
| 59 | + buf.seek(0) |
| 60 | + |
| 61 | + bucket = sagemaker_session.default_bucket() |
| 62 | + prefix = 'test_byo_estimator' |
| 63 | + key = 'recordio-pb-data' |
| 64 | + boto3.resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train', key)).upload_fileobj(buf) |
| 65 | + s3_train_data = 's3://{}/{}/train/{}'.format(bucket, prefix, key) |
| 66 | + |
| 67 | + estimator = Estimator(image_name=image_name, |
| 68 | + role='SageMakerRole', train_instance_count=1, |
| 69 | + train_instance_type='ml.c4.xlarge', |
| 70 | + sagemaker_session=sagemaker_session, base_job_name='test-byo') |
| 71 | + |
| 72 | + estimator.set_hyperparameters(num_factors=10, |
| 73 | + feature_dim=784, |
| 74 | + mini_batch_size=100, |
| 75 | + predictor_type='binary_classifier') |
| 76 | + |
| 77 | + # training labels must be 'float32' |
| 78 | + estimator.fit({'train': s3_train_data}) |
| 79 | + |
| 80 | + endpoint_name = name_from_base('byo') |
| 81 | + |
| 82 | + def fm_serializer(data): |
| 83 | + js = {'instances': []} |
| 84 | + for row in data: |
| 85 | + js['instances'].append({'features': row.tolist()}) |
| 86 | + return json.dumps(js) |
| 87 | + |
| 88 | + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): |
| 89 | + model = estimator.create_model() |
| 90 | + predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) |
| 91 | + predictor.serializer = fm_serializer |
| 92 | + predictor.content_type = 'application/json' |
| 93 | + predictor.deserializer = sagemaker.predictor.json_deserializer |
| 94 | + |
| 95 | + result = predictor.predict(train_set[0][:10]) |
| 96 | + |
| 97 | + assert len(result['predictions']) == 10 |
| 98 | + for prediction in result['predictions']: |
| 99 | + assert prediction['score'] is not None |
0 commit comments