diff --git a/tests/data/marketplace/training/iris.csv b/tests/data/marketplace/training/iris.csv new file mode 100644 index 0000000000..6abe4af5f3 --- /dev/null +++ b/tests/data/marketplace/training/iris.csv @@ -0,0 +1,150 @@ +setosa,5.1,3.5,1.4,0.2 +setosa,4.9,3,1.4,0.2 +setosa,4.7,3.2,1.3,0.2 +setosa,4.6,3.1,1.5,0.2 +setosa,5,3.6,1.4,0.2 +setosa,5.4,3.9,1.7,0.4 +setosa,4.6,3.4,1.4,0.3 +setosa,5,3.4,1.5,0.2 +setosa,4.4,2.9,1.4,0.2 +setosa,4.9,3.1,1.5,0.1 +setosa,5.4,3.7,1.5,0.2 +setosa,4.8,3.4,1.6,0.2 +setosa,4.8,3,1.4,0.1 +setosa,4.3,3,1.1,0.1 +setosa,5.8,4,1.2,0.2 +setosa,5.7,4.4,1.5,0.4 +setosa,5.4,3.9,1.3,0.4 +setosa,5.1,3.5,1.4,0.3 +setosa,5.7,3.8,1.7,0.3 +setosa,5.1,3.8,1.5,0.3 +setosa,5.4,3.4,1.7,0.2 +setosa,5.1,3.7,1.5,0.4 +setosa,4.6,3.6,1,0.2 +setosa,5.1,3.3,1.7,0.5 +setosa,4.8,3.4,1.9,0.2 +setosa,5,3,1.6,0.2 +setosa,5,3.4,1.6,0.4 +setosa,5.2,3.5,1.5,0.2 +setosa,5.2,3.4,1.4,0.2 +setosa,4.7,3.2,1.6,0.2 +setosa,4.8,3.1,1.6,0.2 +setosa,5.4,3.4,1.5,0.4 +setosa,5.2,4.1,1.5,0.1 +setosa,5.5,4.2,1.4,0.2 +setosa,4.9,3.1,1.5,0.2 +setosa,5,3.2,1.2,0.2 +setosa,5.5,3.5,1.3,0.2 +setosa,4.9,3.6,1.4,0.1 +setosa,4.4,3,1.3,0.2 +setosa,5.1,3.4,1.5,0.2 +setosa,5,3.5,1.3,0.3 +setosa,4.5,2.3,1.3,0.3 +setosa,4.4,3.2,1.3,0.2 +setosa,5,3.5,1.6,0.6 +setosa,5.1,3.8,1.9,0.4 +setosa,4.8,3,1.4,0.3 +setosa,5.1,3.8,1.6,0.2 +setosa,4.6,3.2,1.4,0.2 +setosa,5.3,3.7,1.5,0.2 +setosa,5,3.3,1.4,0.2 +versicolor,7,3.2,4.7,1.4 +versicolor,6.4,3.2,4.5,1.5 +versicolor,6.9,3.1,4.9,1.5 +versicolor,5.5,2.3,4,1.3 +versicolor,6.5,2.8,4.6,1.5 +versicolor,5.7,2.8,4.5,1.3 +versicolor,6.3,3.3,4.7,1.6 +versicolor,4.9,2.4,3.3,1 +versicolor,6.6,2.9,4.6,1.3 +versicolor,5.2,2.7,3.9,1.4 +versicolor,5,2,3.5,1 +versicolor,5.9,3,4.2,1.5 +versicolor,6,2.2,4,1 +versicolor,6.1,2.9,4.7,1.4 +versicolor,5.6,2.9,3.6,1.3 +versicolor,6.7,3.1,4.4,1.4 +versicolor,5.6,3,4.5,1.5 +versicolor,5.8,2.7,4.1,1 +versicolor,6.2,2.2,4.5,1.5 +versicolor,5.6,2.5,3.9,1.1 +versicolor,5.9,3.2,4.8,1.8 +versicolor,6.1,2.8,4,1.3 +versicolor,6.3,2.5,4.9,1.5 +versicolor,6.1,2.8,4.7,1.2 +versicolor,6.4,2.9,4.3,1.3 +versicolor,6.6,3,4.4,1.4 +versicolor,6.8,2.8,4.8,1.4 +versicolor,6.7,3,5,1.7 +versicolor,6,2.9,4.5,1.5 +versicolor,5.7,2.6,3.5,1 +versicolor,5.5,2.4,3.8,1.1 +versicolor,5.5,2.4,3.7,1 +versicolor,5.8,2.7,3.9,1.2 +versicolor,6,2.7,5.1,1.6 +versicolor,5.4,3,4.5,1.5 +versicolor,6,3.4,4.5,1.6 +versicolor,6.7,3.1,4.7,1.5 +versicolor,6.3,2.3,4.4,1.3 +versicolor,5.6,3,4.1,1.3 +versicolor,5.5,2.5,4,1.3 +versicolor,5.5,2.6,4.4,1.2 +versicolor,6.1,3,4.6,1.4 +versicolor,5.8,2.6,4,1.2 +versicolor,5,2.3,3.3,1 +versicolor,5.6,2.7,4.2,1.3 +versicolor,5.7,3,4.2,1.2 +versicolor,5.7,2.9,4.2,1.3 +versicolor,6.2,2.9,4.3,1.3 +versicolor,5.1,2.5,3,1.1 +versicolor,5.7,2.8,4.1,1.3 +virginica,6.3,3.3,6,2.5 +virginica,5.8,2.7,5.1,1.9 +virginica,7.1,3,5.9,2.1 +virginica,6.3,2.9,5.6,1.8 +virginica,6.5,3,5.8,2.2 +virginica,7.6,3,6.6,2.1 +virginica,4.9,2.5,4.5,1.7 +virginica,7.3,2.9,6.3,1.8 +virginica,6.7,2.5,5.8,1.8 +virginica,7.2,3.6,6.1,2.5 +virginica,6.5,3.2,5.1,2 +virginica,6.4,2.7,5.3,1.9 +virginica,6.8,3,5.5,2.1 +virginica,5.7,2.5,5,2 +virginica,5.8,2.8,5.1,2.4 +virginica,6.4,3.2,5.3,2.3 +virginica,6.5,3,5.5,1.8 +virginica,7.7,3.8,6.7,2.2 +virginica,7.7,2.6,6.9,2.3 +virginica,6,2.2,5,1.5 +virginica,6.9,3.2,5.7,2.3 +virginica,5.6,2.8,4.9,2 +virginica,7.7,2.8,6.7,2 +virginica,6.3,2.7,4.9,1.8 +virginica,6.7,3.3,5.7,2.1 +virginica,7.2,3.2,6,1.8 +virginica,6.2,2.8,4.8,1.8 +virginica,6.1,3,4.9,1.8 +virginica,6.4,2.8,5.6,2.1 +virginica,7.2,3,5.8,1.6 +virginica,7.4,2.8,6.1,1.9 +virginica,7.9,3.8,6.4,2 +virginica,6.4,2.8,5.6,2.2 +virginica,6.3,2.8,5.1,1.5 +virginica,6.1,2.6,5.6,1.4 +virginica,7.7,3,6.1,2.3 +virginica,6.3,3.4,5.6,2.4 +virginica,6.4,3.1,5.5,1.8 +virginica,6,3,4.8,1.8 +virginica,6.9,3.1,5.4,2.1 +virginica,6.7,3.1,5.6,2.4 +virginica,6.9,3.1,5.1,2.3 +virginica,5.8,2.7,5.1,1.9 +virginica,6.8,3.2,5.9,2.3 +virginica,6.7,3.3,5.7,2.5 +virginica,6.7,3,5.2,2.3 +virginica,6.3,2.5,5,1.9 +virginica,6.5,3,5.2,2 +virginica,6.2,3.4,5.4,2.3 +virginica,5.9,3,5.1,1.8 diff --git a/tests/data/marketplace/transform/batchtransform_test.csv b/tests/data/marketplace/transform/batchtransform_test.csv new file mode 100644 index 0000000000..e76f2f2670 --- /dev/null +++ b/tests/data/marketplace/transform/batchtransform_test.csv @@ -0,0 +1,150 @@ +5.1,3.5,1.4,0.2 +4.9,3.0,1.4,0.2 +4.7,3.2,1.3,0.2 +4.6,3.1,1.5,0.2 +5.0,3.6,1.4,0.2 +5.4,3.9,1.7,0.4 +4.6,3.4,1.4,0.3 +5.0,3.4,1.5,0.2 +4.4,2.9,1.4,0.2 +4.9,3.1,1.5,0.1 +5.4,3.7,1.5,0.2 +4.8,3.4,1.6,0.2 +4.8,3.0,1.4,0.1 +4.3,3.0,1.1,0.1 +5.8,4.0,1.2,0.2 +5.7,4.4,1.5,0.4 +5.4,3.9,1.3,0.4 +5.1,3.5,1.4,0.3 +5.7,3.8,1.7,0.3 +5.1,3.8,1.5,0.3 +5.4,3.4,1.7,0.2 +5.1,3.7,1.5,0.4 +4.6,3.6,1.0,0.2 +5.1,3.3,1.7,0.5 +4.8,3.4,1.9,0.2 +5.0,3.0,1.6,0.2 +5.0,3.4,1.6,0.4 +5.2,3.5,1.5,0.2 +5.2,3.4,1.4,0.2 +4.7,3.2,1.6,0.2 +4.8,3.1,1.6,0.2 +5.4,3.4,1.5,0.4 +5.2,4.1,1.5,0.1 +5.5,4.2,1.4,0.2 +4.9,3.1,1.5,0.2 +5.0,3.2,1.2,0.2 +5.5,3.5,1.3,0.2 +4.9,3.6,1.4,0.1 +4.4,3.0,1.3,0.2 +5.1,3.4,1.5,0.2 +5.0,3.5,1.3,0.3 +4.5,2.3,1.3,0.3 +4.4,3.2,1.3,0.2 +5.0,3.5,1.6,0.6 +5.1,3.8,1.9,0.4 +4.8,3.0,1.4,0.3 +5.1,3.8,1.6,0.2 +4.6,3.2,1.4,0.2 +5.3,3.7,1.5,0.2 +5.0,3.3,1.4,0.2 +7.0,3.2,4.7,1.4 +6.4,3.2,4.5,1.5 +6.9,3.1,4.9,1.5 +5.5,2.3,4.0,1.3 +6.5,2.8,4.6,1.5 +5.7,2.8,4.5,1.3 +6.3,3.3,4.7,1.6 +4.9,2.4,3.3,1.0 +6.6,2.9,4.6,1.3 +5.2,2.7,3.9,1.4 +5.0,2.0,3.5,1.0 +5.9,3.0,4.2,1.5 +6.0,2.2,4.0,1.0 +6.1,2.9,4.7,1.4 +5.6,2.9,3.6,1.3 +6.7,3.1,4.4,1.4 +5.6,3.0,4.5,1.5 +5.8,2.7,4.1,1.0 +6.2,2.2,4.5,1.5 +5.6,2.5,3.9,1.1 +5.9,3.2,4.8,1.8 +6.1,2.8,4.0,1.3 +6.3,2.5,4.9,1.5 +6.1,2.8,4.7,1.2 +6.4,2.9,4.3,1.3 +6.6,3.0,4.4,1.4 +6.8,2.8,4.8,1.4 +6.7,3.0,5.0,1.7 +6.0,2.9,4.5,1.5 +5.7,2.6,3.5,1.0 +5.5,2.4,3.8,1.1 +5.5,2.4,3.7,1.0 +5.8,2.7,3.9,1.2 +6.0,2.7,5.1,1.6 +5.4,3.0,4.5,1.5 +6.0,3.4,4.5,1.6 +6.7,3.1,4.7,1.5 +6.3,2.3,4.4,1.3 +5.6,3.0,4.1,1.3 +5.5,2.5,4.0,1.3 +5.5,2.6,4.4,1.2 +6.1,3.0,4.6,1.4 +5.8,2.6,4.0,1.2 +5.0,2.3,3.3,1.0 +5.6,2.7,4.2,1.3 +5.7,3.0,4.2,1.2 +5.7,2.9,4.2,1.3 +6.2,2.9,4.3,1.3 +5.1,2.5,3.0,1.1 +5.7,2.8,4.1,1.3 +6.3,3.3,6.0,2.5 +5.8,2.7,5.1,1.9 +7.1,3.0,5.9,2.1 +6.3,2.9,5.6,1.8 +6.5,3.0,5.8,2.2 +7.6,3.0,6.6,2.1 +4.9,2.5,4.5,1.7 +7.3,2.9,6.3,1.8 +6.7,2.5,5.8,1.8 +7.2,3.6,6.1,2.5 +6.5,3.2,5.1,2.0 +6.4,2.7,5.3,1.9 +6.8,3.0,5.5,2.1 +5.7,2.5,5.0,2.0 +5.8,2.8,5.1,2.4 +6.4,3.2,5.3,2.3 +6.5,3.0,5.5,1.8 +7.7,3.8,6.7,2.2 +7.7,2.6,6.9,2.3 +6.0,2.2,5.0,1.5 +6.9,3.2,5.7,2.3 +5.6,2.8,4.9,2.0 +7.7,2.8,6.7,2.0 +6.3,2.7,4.9,1.8 +6.7,3.3,5.7,2.1 +7.2,3.2,6.0,1.8 +6.2,2.8,4.8,1.8 +6.1,3.0,4.9,1.8 +6.4,2.8,5.6,2.1 +7.2,3.0,5.8,1.6 +7.4,2.8,6.1,1.9 +7.9,3.8,6.4,2.0 +6.4,2.8,5.6,2.2 +6.3,2.8,5.1,1.5 +6.1,2.6,5.6,1.4 +7.7,3.0,6.1,2.3 +6.3,3.4,5.6,2.4 +6.4,3.1,5.5,1.8 +6.0,3.0,4.8,1.8 +6.9,3.1,5.4,2.1 +6.7,3.1,5.6,2.4 +6.9,3.1,5.1,2.3 +5.8,2.7,5.1,1.9 +6.8,3.2,5.9,2.3 +6.7,3.3,5.7,2.5 +6.7,3.0,5.2,2.3 +6.3,2.5,5.0,1.9 +6.5,3.0,5.2,2.0 +6.2,3.4,5.4,2.3 +5.9,3.0,5.1,1.8 diff --git a/tests/integ/test_marketplace.py b/tests/integ/test_marketplace.py new file mode 100644 index 0000000000..096c97a89f --- /dev/null +++ b/tests/integ/test_marketplace.py @@ -0,0 +1,215 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import itertools +import os +import time + +import pandas + +import sagemaker +from sagemaker import AlgorithmEstimator, ModelPackage +from sagemaker.tuner import IntegerParameter, HyperparameterTuner +from sagemaker.utils import sagemaker_timestamp +from tests.integ import DATA_DIR +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name + + +# All these tests require a manual 1 time subscription to the following Marketplace items: +# Algorithm: Scikit Decision Trees +# https://aws.amazon.com/marketplace/pp/prodview-ha4f3kqugba3u +# +# Pre-Trained Model: Scikit Decision Trees - Pretrained Model +# https://aws.amazon.com/marketplace/pp/prodview-7qop4x5ahrdhe +# +# Both are written by Amazon and are free to subscribe. + +ALGORITHM_ARN = 'arn:aws:sagemaker:%s:594846645681:algorithm/scikit-decision-trees-' \ + '15423055-57b73412d2e93e9239e4e16f83298b8f' + +MODEL_PACKAGE_ARN = 'arn:aws:sagemaker:%s:594846645681:model-package/scikit-iris-detector-' \ + '154230595-8f00905c1f927a512b73ea29dd09ae30' + + +def test_marketplace_estimator(sagemaker_session): + with timeout(minutes=15): + data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + + algo = AlgorithmEstimator( + algorithm_arn=(ALGORITHM_ARN % sagemaker_session.boto_region_name), + role='SageMakerRole', + train_instance_count=1, + train_instance_type='ml.c4.xlarge', + sagemaker_session=sagemaker_session) + + train_input = algo.sagemaker_session.upload_data( + path=data_path, key_prefix='integ-test-data/marketplace/train') + + algo.fit({'training': train_input}) + + endpoint_name = 'test-marketplace-estimator{}'.format(sagemaker_timestamp()) + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): + predictor = algo.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) + shape = pandas.read_csv(os.path.join(data_path, 'iris.csv'), header=None) + + a = [50 * i for i in range(3)] + b = [40 + i for i in range(10)] + indices = [i + j for i, j in itertools.product(a, b)] + + test_data = shape.iloc[indices[:-1]] + test_x = test_data.iloc[:, 1:] + + print(predictor.predict(test_x.values).decode('utf-8')) + + +def test_marketplace_attach(sagemaker_session): + with timeout(minutes=15): + data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + + mktplace = AlgorithmEstimator( + algorithm_arn=(ALGORITHM_ARN % sagemaker_session.boto_region_name), + role='SageMakerRole', + train_instance_count=1, + train_instance_type='ml.c4.xlarge', + sagemaker_session=sagemaker_session, + base_job_name='test-marketplace') + + train_input = mktplace.sagemaker_session.upload_data( + path=data_path, key_prefix='integ-test-data/marketplace/train') + + mktplace.fit({'training': train_input}, wait=False) + training_job_name = mktplace.latest_training_job.name + + print('Waiting to re-attach to the training job: %s' % training_job_name) + time.sleep(20) + endpoint_name = 'test-marketplace-estimator{}'.format(sagemaker_timestamp()) + + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): + print('Re-attaching now to: %s' % training_job_name) + estimator = AlgorithmEstimator.attach(training_job_name=training_job_name, + sagemaker_session=sagemaker_session) + predictor = estimator.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, + serializer=sagemaker.predictor.csv_serializer) + shape = pandas.read_csv(os.path.join(data_path, 'iris.csv'), header=None) + a = [50 * i for i in range(3)] + b = [40 + i for i in range(10)] + indices = [i + j for i, j in itertools.product(a, b)] + + test_data = shape.iloc[indices[:-1]] + test_x = test_data.iloc[:, 1:] + + print(predictor.predict(test_x.values).decode('utf-8')) + + +def test_marketplace_model(sagemaker_session): + + def predict_wrapper(endpoint, session): + return sagemaker.RealTimePredictor( + endpoint, session, serializer=sagemaker.predictor.csv_serializer + ) + + model = ModelPackage(role='SageMakerRole', + model_package_arn=(MODEL_PACKAGE_ARN % sagemaker_session.boto_region_name), + sagemaker_session=sagemaker_session, + predictor_cls=predict_wrapper) + + endpoint_name = 'test-marketplace-model-endpoint{}'.format(sagemaker_timestamp()) + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): + predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) + data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + shape = pandas.read_csv(os.path.join(data_path, 'iris.csv'), header=None) + a = [50 * i for i in range(3)] + b = [40 + i for i in range(10)] + indices = [i + j for i, j in itertools.product(a, b)] + + test_data = shape.iloc[indices[:-1]] + test_x = test_data.iloc[:, 1:] + + print(predictor.predict(test_x.values).decode('utf-8')) + + +def test_marketplace_tuning_job(sagemaker_session): + data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + + mktplace = AlgorithmEstimator( + algorithm_arn=(ALGORITHM_ARN % sagemaker_session.boto_region_name), + role='SageMakerRole', + train_instance_count=1, + train_instance_type='ml.c4.xlarge', + sagemaker_session=sagemaker_session, + base_job_name='test-marketplace') + + train_input = mktplace.sagemaker_session.upload_data( + path=data_path, key_prefix='integ-test-data/marketplace/train') + + mktplace.set_hyperparameters(max_leaf_nodes=10) + + hyperparameter_ranges = {'max_leaf_nodes': IntegerParameter(1, 100000)} + + tuner = HyperparameterTuner(estimator=mktplace, base_tuning_job_name='byo', + objective_metric_name='validation:accuracy', + hyperparameter_ranges=hyperparameter_ranges, + max_jobs=2, max_parallel_jobs=2) + + tuner.fit({'training': train_input}, include_cls_metadata=False) + time.sleep(15) + tuner.wait() + + +def test_marketplace_transform_job(sagemaker_session): + data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + + algo = AlgorithmEstimator( + algorithm_arn=(ALGORITHM_ARN % sagemaker_session.boto_region_name), + role='SageMakerRole', + train_instance_count=1, + train_instance_type='ml.c4.xlarge', + sagemaker_session=sagemaker_session, + base_job_name='test-marketplace') + + train_input = algo.sagemaker_session.upload_data( + path=data_path, key_prefix='integ-test-data/marketplace/train') + + shape = pandas.read_csv(data_path + '/iris.csv', header=None).drop([0], axis=1) + + transform_workdir = DATA_DIR + '/marketplace/transform' + shape.to_csv(transform_workdir + '/batchtransform_test.csv', index=False, header=False) + transform_input = algo.sagemaker_session.upload_data( + transform_workdir, + key_prefix='integ-test-data/marketplace/transform') + + algo.fit({'training': train_input}) + + transformer = algo.transformer(1, 'ml.m4.xlarge') + transformer.transform(transform_input, content_type='text/csv') + transformer.wait() + + +def test_marketplace_transform_job_from_model_package(sagemaker_session): + data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + shape = pandas.read_csv(data_path + '/iris.csv', header=None).drop([0], axis=1) + + TRANSFORM_WORKDIR = DATA_DIR + '/marketplace/transform' + shape.to_csv(TRANSFORM_WORKDIR + '/batchtransform_test.csv', index=False, header=False) + transform_input = sagemaker_session.upload_data( + TRANSFORM_WORKDIR, + key_prefix='integ-test-data/marketplace/transform') + + model = ModelPackage(role='SageMakerRole', + model_package_arn=(MODEL_PACKAGE_ARN % sagemaker_session.boto_region_name), + sagemaker_session=sagemaker_session) + + transformer = model.transformer(1, 'ml.m4.xlarge') + transformer.transform(transform_input, content_type='text/csv') + transformer.wait()