diff --git a/examples/cli/host/data/model.json b/examples/cli/host/data/model.json new file mode 100644 index 0000000000..8d0d57a51e --- /dev/null +++ b/examples/cli/host/data/model.json @@ -0,0 +1,126 @@ +{ + "nodes": [ + { + "op": "null", + "name": "data", + "inputs": [] + }, + { + "op": "null", + "name": "sequential0_dense0_weight", + "attr": { + "__dtype__": "0", + "__lr_mult__": "1.0", + "__shape__": "(128, 0)", + "__wd_mult__": "1.0" + }, + "inputs": [] + }, + { + "op": "null", + "name": "sequential0_dense0_bias", + "attr": { + "__dtype__": "0", + "__init__": "zeros", + "__lr_mult__": "1.0", + "__shape__": "(128,)", + "__wd_mult__": "1.0" + }, + "inputs": [] + }, + { + "op": "FullyConnected", + "name": "sequential0_dense0_fwd", + "attr": {"num_hidden": "128"}, + "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]] + }, + { + "op": "Activation", + "name": "sequential0_dense0_relu_fwd", + "attr": {"act_type": "relu"}, + "inputs": [[3, 0, 0]] + }, + { + "op": "null", + "name": "sequential0_dense1_weight", + "attr": { + "__dtype__": "0", + "__lr_mult__": "1.0", + "__shape__": "(64, 0)", + "__wd_mult__": "1.0" + }, + "inputs": [] + }, + { + "op": "null", + "name": "sequential0_dense1_bias", + "attr": { + "__dtype__": "0", + "__init__": "zeros", + "__lr_mult__": "1.0", + "__shape__": "(64,)", + "__wd_mult__": "1.0" + }, + "inputs": [] + }, + { + "op": "FullyConnected", + "name": "sequential0_dense1_fwd", + "attr": {"num_hidden": "64"}, + "inputs": [[4, 0, 0], [5, 0, 0], [6, 0, 0]] + }, + { + "op": "Activation", + "name": "sequential0_dense1_relu_fwd", + "attr": {"act_type": "relu"}, + "inputs": [[7, 0, 0]] + }, + { + "op": "null", + "name": "sequential0_dense2_weight", + "attr": { + "__dtype__": "0", + "__lr_mult__": "1.0", + "__shape__": "(10, 0)", + "__wd_mult__": "1.0" + }, + "inputs": [] + }, + { + "op": "null", + "name": "sequential0_dense2_bias", + "attr": { + "__dtype__": "0", + "__init__": "zeros", + "__lr_mult__": "1.0", + "__shape__": "(10,)", + "__wd_mult__": 
def model_fn(model_dir):
    """
    Build the Gluon network used to answer requests.

    The symbol graph is read from ``model.json`` inside ``model_dir``, a
    softmax named ``softmax_label`` is placed on top of it, and the weights
    are read from ``model.params`` onto the CPU context.

    :param model_dir: directory holding ``model.json`` and ``model.params``.
    :return: a ``gluon.SymbolBlock`` with its parameters loaded.
    """
    graph = mx.sym.load('%s/model.json' % model_dir)
    softmax_out = mx.symbol.softmax(data=graph, name='softmax_label')
    data_in = mx.sym.var('data')
    params = gluon.ParameterDict('model_')

    block = gluon.SymbolBlock(softmax_out, data_in, params)
    block.load_params('%s/model.params' % model_dir, ctx=mx.cpu())
    return block


def transform_fn(net, data, input_content_type, output_content_type):
    """
    Run one prediction request through the network.

    The payload is always decoded as JSON and the answer is always encoded
    as JSON; ``input_content_type`` is not inspected and
    ``output_content_type`` is only echoed back to the caller.

    :param net: the network returned by ``model_fn``.
    :param data: the request payload, a JSON document.
    :param input_content_type: the request content type (unused).
    :param output_content_type: the desired response content type.
    :return: tuple of (JSON list with the arg-max index along axis 1 for each
        row of the output, ``output_content_type``).
    """
    batch = mx.nd.array(json.loads(data))
    scores = net(batch)
    labels = mx.nd.argmax(scores, axis=1).asnumpy().tolist()
    return json.dumps(labels), output_content_type
logger = logging.getLogger(__name__)


def train(channel_input_dirs, hyperparameters, **kwargs):
    """
    Train the MNIST classifier on the CPU and return the trained network.

    :param channel_input_dirs: mapping of channel name to data directory;
        only the ``'training'`` entry is read.
    :param hyperparameters: mapping that may provide ``batch_size``,
        ``epochs``, ``learning_rate``, ``momentum`` and ``log_interval``;
        missing entries fall back to 100, 10, 0.1, 0.9 and 100.
    :param kwargs: accepted and ignored by this example.
    :return: the trained Gluon network.
    """
    # This example always runs on the CPU context.
    ctx = mx.cpu()

    batch_size = hyperparameters.get('batch_size', 100)
    epochs = hyperparameters.get('epochs', 10)
    learning_rate = hyperparameters.get('learning_rate', 0.1)
    momentum = hyperparameters.get('momentum', 0.9)
    log_interval = hyperparameters.get('log_interval', 100)

    # Both loaders read from the directory of the 'training' channel.
    data_dir = channel_input_dirs['training']
    train_data = get_train_data(data_dir, batch_size)
    val_data = get_val_data(data_dir, batch_size)

    net = define_network()
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

    optimizer_params = {'learning_rate': learning_rate, 'momentum': momentum}
    trainer = gluon.Trainer(net.collect_params(), 'sgd', optimizer_params)
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    for epoch in range(epochs):
        # Start every epoch with a fresh accuracy metric and timer.
        metric.reset()
        batch_start = time.time()
        for i, (data, label) in enumerate(train_data):
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)

            # Record the forward pass so the loss can be differentiated.
            with autograd.record():
                output = net(data)
                batch_loss = loss(output, label)
                batch_loss.backward()

            # The step is normalised by the number of samples in this batch.
            trainer.step(data.shape[0])
            metric.update([label], [output])

            if i % log_interval == 0 and i > 0:
                name, acc = metric.get()
                throughput = batch_size / (time.time() - batch_start)
                logger.info('[Epoch %d Batch %d] Training: %s=%f, %f samples/s' %
                            (epoch, i, name, acc, throughput))

            batch_start = time.time()

        name, acc = metric.get()
        logger.info('[Epoch %d] Training: %s=%f' % (epoch, name, acc))

        name, val_acc = test(ctx, net, val_data)
        logger.info('[Epoch %d] Validation: %s=%f' % (epoch, name, val_acc))

    return net


def save(net, model_dir):
    """
    Write the network's symbol graph and parameters into ``model_dir``.

    :param net: the trained network.
    :param model_dir: directory that receives ``model.json`` and
        ``model.params``.
    """
    symbol = net(mx.sym.var('data'))
    symbol.save('%s/model.json' % model_dir)
    net.collect_params().save('%s/model.params' % model_dir)


def define_network():
    """
    Build the three-layer dense network (128 and 64 relu units, 10 outputs).

    :return: an uninitialised ``nn.Sequential``.
    """
    net = nn.Sequential()
    with net.name_scope():
        for units, activation in ((128, 'relu'), (64, 'relu'), (10, None)):
            net.add(nn.Dense(units, activation=activation))
    return net


def input_transformer(data, label):
    """
    Flatten one sample to a float32 vector and divide its values by 255.

    :param data: the sample to transform.
    :param label: the sample's label, passed through untouched.
    :return: tuple of (transformed data, label).
    """
    flat = data.reshape((-1,))
    scaled = flat.astype(np.float32) / 255
    return scaled, label


def get_train_data(data_dir, batch_size):
    """
    Create the shuffled training loader; an incomplete last batch is dropped.

    :param data_dir: directory containing the MNIST files.
    :param batch_size: number of samples per batch.
    :return: a ``gluon.data.DataLoader`` over the training split.
    """
    dataset = gluon.data.vision.MNIST(data_dir, train=True, transform=input_transformer)
    return gluon.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                 last_batch='discard')


def get_val_data(data_dir, batch_size):
    """
    Create the unshuffled validation loader.

    :param data_dir: directory containing the MNIST files.
    :param batch_size: number of samples per batch.
    :return: a ``gluon.data.DataLoader`` over the ``train=False`` split.
    """
    dataset = gluon.data.vision.MNIST(data_dir, train=False, transform=input_transformer)
    return gluon.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)


def test(ctx, net, val_data):
    """
    Measure the accuracy of ``net`` over ``val_data``.

    :param ctx: context the batches are copied to before the forward pass.
    :param net: the network to evaluate.
    :param val_data: iterable of (data, label) batches.
    :return: the ``(name, value)`` pair reported by the accuracy metric.
    """
    accuracy = mx.metric.Accuracy()
    for batch, labels in val_data:
        batch = batch.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        accuracy.update([labels], [net(batch)])
    return accuracy.get()
fname)).read() @@ -36,4 +37,8 @@ def read(fname): extras_require={ 'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock', 'tensorflow>=1.3.0', 'contextlib2']}, + + entry_points={ + 'console_scripts': ['sagemaker=sagemaker.cli.main:main'], + } ) diff --git a/src/sagemaker/cli/__init__.py b/src/sagemaker/cli/__init__.py new file mode 100644 index 0000000000..4d78ed0e5c --- /dev/null +++ b/src/sagemaker/cli/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. diff --git a/src/sagemaker/cli/common.py b/src/sagemaker/cli/common.py new file mode 100644 index 0000000000..80f6fe07ca --- /dev/null +++ b/src/sagemaker/cli/common.py @@ -0,0 +1,112 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
logger = logging.getLogger(__name__)


class HostCommand(object):
    """Package local model files and deploy them behind a SageMaker endpoint.

    Subclasses provide the framework-specific model object by overriding
    ``create_model``.
    """

    def __init__(self, args):
        """Copy the parsed CLI arguments onto the command and open a session.

        :param args: parsed arguments; reads ``job_name``, ``bucket_name``,
            ``role_name``, ``python``, ``data``, ``script``,
            ``instance_type``, ``instance_count`` and ``env``.
        :raises ValueError: if an ``env`` entry is not of the form
            ``KEY=VALUE``.
        """
        self.endpoint_name = args.job_name
        self.bucket = args.bucket_name  # may be None
        self.role_name = args.role_name
        self.python = args.python
        self.data = args.data
        self.script = args.script
        self.instance_type = args.instance_type
        self.instance_count = args.instance_count
        self.environment = self._parse_environment(args.env)

        self.session = sagemaker.Session()

    @staticmethod
    def _parse_environment(pairs):
        """Turn ``KEY=VALUE`` strings into a dict.

        Only the first ``=`` separates key from value, so values that contain
        ``=`` themselves are kept intact.

        :param pairs: iterable of ``KEY=VALUE`` strings (may be empty or None).
        :return: dict mapping each key to its value.
        :raises ValueError: if an entry contains no ``=``.
        """
        environment = {}
        for pair in pairs or ():
            key, separator, value = pair.partition('=')
            if not separator:
                raise ValueError(
                    'invalid environment variable {!r}; expected KEY=VALUE'.format(pair))
            environment[key] = value
        return environment

    def upload_model(self):
        """Archive ``self.data`` and upload the archive through the session.

        The temporary directory holding the archive is removed whether or not
        the upload succeeds.

        :return: the value returned by the session's ``upload_data``.
        """
        # NOTE(review): the CLI parser's --job-name default is None, which
        # makes this prefix the literal 'None/model' - confirm that is intended.
        prefix = '{}/model'.format(self.endpoint_name)

        archive = self.create_model_archive(self.data)
        try:
            return self.session.upload_data(path=archive, bucket=self.bucket, key_prefix=prefix)
        finally:
            shutil.rmtree(os.path.dirname(archive))

    @staticmethod
    def create_model_archive(src):
        """Create ``model.tar.gz`` from ``src`` inside a new temporary directory.

        A directory is archived as ``.`` (its contents sit at the archive
        root); a single file is archived under its base name.

        :param src: path to a model directory or a single model file.
        :return: path of the archive; the caller is responsible for removing
            the directory that contains it.
        """
        arcname = '.' if os.path.isdir(src) else os.path.basename(src)

        tmp = tempfile.mkdtemp()
        archive = os.path.join(tmp, 'model.tar.gz')
        try:
            with tarfile.open(archive, mode='w:gz') as t:
                t.add(src, arcname=arcname)
        except Exception:
            # Nobody else knows about tmp yet, so clean it up before re-raising.
            shutil.rmtree(tmp, ignore_errors=True)
            raise
        return archive

    def create_model(self, model_url):
        """Return the framework model for ``model_url``; subclasses must override."""
        raise NotImplementedError  # subclasses must override

    def start(self):
        """Upload the model, build the framework model and deploy it.

        :return: the predictor returned by the model's ``deploy``.
        """
        model_url = self.upload_model()
        model = self.create_model(model_url)
        predictor = model.deploy(initial_instance_count=self.instance_count,
                                 instance_type=self.instance_type)

        return predictor


class TrainCommand(object):
    """Upload local training data and run a SageMaker training job.

    Subclasses provide the framework-specific estimator by overriding
    ``create_estimator``.
    """

    def __init__(self, args):
        """Copy the parsed CLI arguments onto the command and open a session.

        :param args: parsed arguments; reads ``job_name``, ``bucket_name``,
            ``role_name``, ``python``, ``data``, ``script``,
            ``instance_type``, ``instance_count`` and ``hyperparameters``.
        """
        self.job_name = args.job_name
        self.bucket = args.bucket_name  # may be None
        self.role_name = args.role_name
        self.python = args.python
        self.data = args.data
        self.script = args.script
        self.instance_type = args.instance_type
        self.instance_count = args.instance_count
        self.hyperparameters = self.load_hyperparameters(args.hyperparameters)

        self.session = sagemaker.Session()

    @staticmethod
    def load_hyperparameters(src):
        """Load hyperparameters from a JSON file.

        :param src: path to a JSON file; may be None or point to a file that
            does not exist.
        :return: the decoded JSON content, or an empty dict when ``src`` is
            empty or the file is missing.
        """
        hp = {}
        if src and os.path.exists(src):
            with open(src, 'r') as f:
                hp = json.load(f)
        return hp

    def upload_training_data(self):
        """Upload ``self.data`` under the ``<job_name>/data`` key prefix.

        :return: the value returned by the session's ``upload_data``.
        """
        prefix = '{}/data'.format(self.job_name)
        data_url = self.session.upload_data(path=self.data, bucket=self.bucket, key_prefix=prefix)
        return data_url

    def create_estimator(self):
        """Return the framework estimator; subclasses must override."""
        raise NotImplementedError  # subclasses must override

    def start(self):
        """Upload the data, fit the estimator and log code and model locations."""
        data_url = self.upload_training_data()
        estimator = self.create_estimator()
        estimator.fit(data_url)
        logger.debug('code location: {}'.format(estimator.uploaded_code.s3_prefix))
        logger.debug('model location: {}{}/output/model.tar.gz'.format(estimator.output_path,
                                                                       estimator._current_job_name))
logger = logging.getLogger(__name__)

DEFAULT_LOG_LEVEL = 'info'
DEFAULT_BOTOCORE_LOG_LEVEL = 'warning'


def _log_level(value):
    """Validate a log level name given on the command line.

    Used as an argparse ``type`` so that an unknown level is reported as a
    normal usage error instead of surfacing later as a ``ValueError`` from
    the logging module.

    :param value: level name in any letter case, e.g. ``debug``.
    :return: ``value`` unchanged.
    :raises argparse.ArgumentTypeError: if the logging module does not know
        the level name.
    """
    # getLevelName maps a known name to its number; for an unknown name it
    # returns a 'Level ...' string instead.
    if not isinstance(logging.getLevelName(value.upper()), int):
        raise argparse.ArgumentTypeError('unknown log level: {!r}'.format(value))
    return value


def parse_arguments(args):
    """Parse the ``sagemaker`` command line.

    The command has the shape ``[log options] <framework> <train|host>
    [options]`` with ``mxnet`` and ``tensorflow`` as frameworks. The returned
    namespace carries a ``func`` attribute holding the handler to call; when
    no subcommand is selected ``func`` prints the usage text.

    :param args: list of argument strings, without the program name.
    :return: the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Launch SageMaker training jobs or hosting endpoints')
    parser.set_defaults(func=lambda x: parser.print_usage())

    # common args for training/hosting/all frameworks
    common_parser = argparse.ArgumentParser(add_help=False)
    common_parser.add_argument('--role-name', help='SageMaker execution role name', type=str, required=True)
    common_parser.add_argument('--data', help='path to training data or model files', type=str, default='./data')
    common_parser.add_argument('--script', help='path to script', type=str, default='./script.py')
    common_parser.add_argument('--job-name', help='job or endpoint name', type=str, default=None)
    common_parser.add_argument('--bucket-name', help='S3 bucket for training/model data and script files',
                               type=str, default=None)
    common_parser.add_argument('--python', help='python version', type=str, default='py2')

    instance_group = common_parser.add_argument_group('instance settings')
    instance_group.add_argument('--instance-type', type=str, help='instance type', default='ml.m4.xlarge')
    instance_group.add_argument('--instance-count', type=int, help='instance count', default=1)

    # common training args
    common_train_parser = argparse.ArgumentParser(add_help=False)
    common_train_parser.add_argument('--hyperparameters', help='path to training hyperparameters file',
                                     type=str, default='./hyperparameters.json')

    # common hosting args
    common_host_parser = argparse.ArgumentParser(add_help=False)
    common_host_parser.add_argument('--env', help='hosting environment variable(s)', type=str, nargs='*', default=[])

    subparsers = parser.add_subparsers()

    # framework/algo subcommands
    mxnet_parser = subparsers.add_parser('mxnet', help='use MXNet', parents=[])
    mxnet_subparsers = mxnet_parser.add_subparsers()
    mxnet_train_parser = mxnet_subparsers.add_parser('train',
                                                     help='start a training job',
                                                     parents=[common_parser, common_train_parser])
    mxnet_train_parser.set_defaults(func=sagemaker.cli.mxnet.train)

    mxnet_host_parser = mxnet_subparsers.add_parser('host',
                                                    help='start a hosting endpoint',
                                                    parents=[common_parser, common_host_parser])
    mxnet_host_parser.set_defaults(func=sagemaker.cli.mxnet.host)

    tensorflow_parser = subparsers.add_parser('tensorflow', help='use TensorFlow', parents=[])
    tensorflow_subparsers = tensorflow_parser.add_subparsers()
    tensorflow_train_parser = tensorflow_subparsers.add_parser('train',
                                                               help='start a training job',
                                                               parents=[common_parser, common_train_parser])
    tensorflow_train_parser.add_argument('--training-steps',
                                         help='number of training steps (tensorflow only)', type=int, default=None)
    tensorflow_train_parser.add_argument('--evaluation-steps',
                                         help='number of evaluation steps (tensorflow only)', type=int, default=None)
    tensorflow_train_parser.set_defaults(func=sagemaker.cli.tensorflow.train)

    tensorflow_host_parser = tensorflow_subparsers.add_parser('host',
                                                              help='start a hosting endpoint',
                                                              parents=[common_parser, common_host_parser])
    tensorflow_host_parser.set_defaults(func=sagemaker.cli.tensorflow.host)

    # Log options live on the top-level parser, so they are given before the
    # framework name. Level names are validated here by _log_level.
    log_group = parser.add_argument_group('optional log settings')
    log_group.add_argument('--log-level', help='log level for this command', type=_log_level,
                           default=DEFAULT_LOG_LEVEL)
    log_group.add_argument('--botocore-log-level', help='log level for botocore', type=_log_level,
                           default=DEFAULT_BOTOCORE_LOG_LEVEL)

    return parser.parse_args(args)


def configure_logging(args):
    """Configure root logging and the ``botocore`` logger level.

    :param args: namespace with ``log_level`` and ``botocore_log_level``
        level names (any letter case).
    """
    log_format = '%(asctime)s %(levelname)s %(name)s: %(message)s'
    log_level = logging.getLevelName(args.log_level.upper())
    logging.basicConfig(format=log_format, level=log_level)
    logging.getLogger("botocore").setLevel(args.botocore_log_level.upper())


def main():
    """Entry point of the ``sagemaker`` console script."""
    args = parse_arguments(sys.argv[1:])
    configure_logging(args)
    logger.debug('args: {}'.format(args))
    args.func(args)


if __name__ == '__main__':
    main()
def train(args):
    """Run an MXNet training job described by the parsed CLI arguments."""
    command = MXNetTrainCommand(args)
    command.start()


def host(args):
    """Deploy an MXNet hosting endpoint described by the parsed CLI arguments."""
    command = MXNetHostCommand(args)
    command.start()


class MXNetTrainCommand(TrainCommand):
    """``TrainCommand`` that fits an MXNet estimator."""

    def __init__(self, args):
        super(MXNetTrainCommand, self).__init__(args)

    def create_estimator(self):
        """Build the MXNet estimator from the command's settings.

        :return: an ``MXNet`` estimator whose entry point is ``self.script``.
        """
        # Imported lazily so the framework module is only loaded when used.
        from sagemaker.mxnet.estimator import MXNet

        settings = {
            'role': self.role_name,
            'base_job_name': self.job_name,
            'train_instance_count': self.instance_count,
            'train_instance_type': self.instance_type,
            'hyperparameters': self.hyperparameters,
            'py_version': self.python,
        }
        return MXNet(self.script, **settings)


class MXNetHostCommand(HostCommand):
    """``HostCommand`` that deploys an MXNet model."""

    def __init__(self, args):
        super(MXNetHostCommand, self).__init__(args)

    def create_model(self, model_url):
        """Build the MXNet model for the uploaded archive.

        :param model_url: location of the uploaded model archive.
        :return: an ``MXNetModel`` named after the endpoint.
        """
        # Imported lazily so the framework module is only loaded when used.
        from sagemaker.mxnet.model import MXNetModel

        settings = {
            'model_data': model_url,
            'role': self.role_name,
            'entry_point': self.script,
            'py_version': self.python,
            'name': self.endpoint_name,
            'env': self.environment,
        }
        return MXNetModel(**settings)
def train(args):
    """Run a TensorFlow training job described by the parsed CLI arguments."""
    command = TensorFlowTrainCommand(args)
    command.start()


def host(args):
    """Deploy a TensorFlow hosting endpoint described by the parsed CLI arguments."""
    command = TensorFlowHostCommand(args)
    command.start()


class TensorFlowTrainCommand(TrainCommand):
    """``TrainCommand`` that fits a TensorFlow estimator."""

    def __init__(self, args):
        """Store the common settings plus the TensorFlow-only step counts.

        :param args: parsed arguments; additionally reads ``training_steps``
            and ``evaluation_steps``.
        """
        super(TensorFlowTrainCommand, self).__init__(args)
        self.training_steps = args.training_steps
        self.evaluation_steps = args.evaluation_steps

    def create_estimator(self):
        """Build the TensorFlow estimator from the command's settings.

        :return: a ``TensorFlow`` estimator whose entry point is
            ``self.script``.
        """
        # Imported lazily so the framework module is only loaded when used.
        from sagemaker.tensorflow import TensorFlow

        settings = {
            'training_steps': self.training_steps,
            'evaluation_steps': self.evaluation_steps,
            'py_version': self.python,
            'entry_point': self.script,
            'role': self.role_name,
            'base_job_name': self.job_name,
            'train_instance_count': self.instance_count,
            'train_instance_type': self.instance_type,
            'hyperparameters': self.hyperparameters,
        }
        return TensorFlow(**settings)


class TensorFlowHostCommand(HostCommand):
    """``HostCommand`` that deploys a TensorFlow model."""

    def __init__(self, args):
        super(TensorFlowHostCommand, self).__init__(args)

    def create_model(self, model_url):
        """Build the TensorFlow model for the uploaded archive.

        :param model_url: location of the uploaded model archive.
        :return: a ``TensorFlowModel`` named after the endpoint.
        """
        # Imported lazily so the framework module is only loaded when used.
        from sagemaker.tensorflow.model import TensorFlowModel

        settings = {
            'model_data': model_url,
            'role': self.role_name,
            'entry_point': self.script,
            'py_version': self.python,
            'name': self.endpoint_name,
            'env': self.environment,
        }
        return TensorFlowModel(**settings)
COMMON_ARGS = ('--role-name myrole --data mydata --script myscript --job-name myjob --bucket-name mybucket '
               '--python py3 --instance-type myinstance --instance-count 2')

TRAIN_ARGS = '--hyperparameters myhyperparameters.json'

LOG_ARGS = '--log-level debug --botocore-log-level debug'

HOST_ARGS = '--env ENV1=env1 ENV2=env2'


def _parse(template, *parts):
    """Fill ``template`` with ``parts``, split on whitespace and parse it."""
    return cli.parse_arguments(template.format(*parts).split())


def _assert_handler(args, module, name):
    """Check that the selected handler is ``module.name``."""
    assert args.func.__module__ == module
    assert args.func.__name__ == name


def _assert_rejected(command_line):
    """Check that argparse exits when given ``command_line``."""
    with pytest.raises(SystemExit):
        cli.parse_arguments(command_line.split())


def assert_common_defaults(args):
    """Check every shared option (except the required role) holds its default."""
    assert args.data == './data'
    assert args.script == './script.py'
    assert args.job_name is None
    assert args.bucket_name is None
    assert args.python == 'py2'
    assert args.instance_type == 'ml.m4.xlarge'
    assert args.instance_count == 1
    assert args.log_level == 'info'
    assert args.botocore_log_level == 'warning'


def assert_common_non_defaults(args):
    """Check every shared option holds the value from COMMON_ARGS/LOG_ARGS."""
    assert args.data == 'mydata'
    assert args.script == 'myscript'
    assert args.job_name == 'myjob'
    assert args.bucket_name == 'mybucket'
    assert args.role_name == 'myrole'
    assert args.python == 'py3'
    assert args.instance_type == 'myinstance'
    assert args.instance_count == 2
    assert args.log_level == 'debug'
    assert args.botocore_log_level == 'debug'


def assert_train_defaults(args):
    """Check the training-only option holds its default."""
    assert args.hyperparameters == './hyperparameters.json'


def assert_train_non_defaults(args):
    """Check the training-only option holds the value from TRAIN_ARGS."""
    assert args.hyperparameters == 'myhyperparameters.json'


def assert_host_defaults(args):
    """Check the hosting-only option holds its default."""
    assert args.env == []


def assert_host_non_defaults(args):
    """Check the hosting-only option holds the values from HOST_ARGS."""
    assert args.env == ['ENV1=env1', 'ENV2=env2']


def test_args_mxnet_train_defaults():
    args = _parse('mxnet train --role-name role')
    assert_common_defaults(args)
    assert_train_defaults(args)
    _assert_handler(args, 'sagemaker.cli.mxnet', 'train')


def test_args_mxnet_train_non_defaults():
    args = _parse('{} mxnet train --role-name role {} {}', LOG_ARGS, COMMON_ARGS, TRAIN_ARGS)
    assert_common_non_defaults(args)
    assert_train_non_defaults(args)
    _assert_handler(args, 'sagemaker.cli.mxnet', 'train')


def test_args_mxnet_host_defaults():
    args = _parse('mxnet host --role-name role')
    assert_common_defaults(args)
    assert_host_defaults(args)
    _assert_handler(args, 'sagemaker.cli.mxnet', 'host')


def test_args_mxnet_host_non_defaults():
    args = _parse('{} mxnet host --role-name role {} {}', LOG_ARGS, COMMON_ARGS, HOST_ARGS)
    assert_common_non_defaults(args)
    assert_host_non_defaults(args)
    _assert_handler(args, 'sagemaker.cli.mxnet', 'host')


def test_args_tensorflow_train_defaults():
    args = _parse('tensorflow train --role-name role')
    assert_common_defaults(args)
    assert_train_defaults(args)
    assert args.training_steps is None
    assert args.evaluation_steps is None
    _assert_handler(args, 'sagemaker.cli.tensorflow', 'train')


def test_args_tensorflow_train_non_defaults():
    args = _parse('{} tensorflow train --role-name role --training-steps 10 --evaluation-steps 5 {} {}',
                  LOG_ARGS, COMMON_ARGS, TRAIN_ARGS)
    assert_common_non_defaults(args)
    assert_train_non_defaults(args)
    assert args.training_steps == 10
    assert args.evaluation_steps == 5
    _assert_handler(args, 'sagemaker.cli.tensorflow', 'train')


def test_args_tensorflow_host_defaults():
    args = _parse('tensorflow host --role-name role')
    assert_common_defaults(args)
    assert_host_defaults(args)
    _assert_handler(args, 'sagemaker.cli.tensorflow', 'host')


def test_args_tensorflow_host_non_defaults():
    args = _parse('{} tensorflow host --role-name role {} {}', LOG_ARGS, COMMON_ARGS, HOST_ARGS)
    assert_common_non_defaults(args)
    assert_host_non_defaults(args)
    _assert_handler(args, 'sagemaker.cli.tensorflow', 'host')


def test_args_invalid_framework():
    _assert_rejected('fakeframework train --role-name role')


def test_args_invalid_subcommand():
    _assert_rejected('mxnet drain')


def test_args_invalid_args():
    _assert_rejected('tensorflow train --role-name role --notdata foo')


def test_args_invalid_mxnet_python():
    _assert_rejected('mxnet train --role-name role nython py2')


def test_args_invalid_host_args_in_train():
    _assert_rejected('mxnet train --role-name role --env FOO=bar')


def test_args_invalid_train_args_in_host():
    _assert_rejected('tensorflow host --role-name role --hyperparameters foo.json')


@patch('sagemaker.mxnet.estimator.MXNet')
@patch('sagemaker.Session')
def test_mxnet_train(session, estimator):
    """The train handler uploads data, builds the estimator and fits it."""
    args = _parse('mxnet train --role-name role')
    args.func(args)
    session.return_value.upload_data.assert_called()
    estimator.assert_called()
    estimator.return_value.fit.assert_called()


@patch('sagemaker.mxnet.model.MXNetModel')
@patch('sagemaker.cli.common.HostCommand.upload_model')
@patch('sagemaker.Session')
def test_mxnet_host(session, upload_model, model):
    """The host handler opens a session, uploads, builds the model and deploys."""
    args = _parse('mxnet host --role-name role')
    args.func(args)
    session.assert_called()
    upload_model.assert_called()
    model.assert_called()
    model.return_value.deploy.assert_called()