diff --git a/examples/cli/host/data/model.json b/examples/cli/host/data/model.json deleted file mode 100644 index 8d0d57a51e..0000000000 --- a/examples/cli/host/data/model.json +++ /dev/null @@ -1,126 +0,0 @@ -{ - "nodes": [ - { - "op": "null", - "name": "data", - "inputs": [] - }, - { - "op": "null", - "name": "sequential0_dense0_weight", - "attr": { - "__dtype__": "0", - "__lr_mult__": "1.0", - "__shape__": "(128, 0)", - "__wd_mult__": "1.0" - }, - "inputs": [] - }, - { - "op": "null", - "name": "sequential0_dense0_bias", - "attr": { - "__dtype__": "0", - "__init__": "zeros", - "__lr_mult__": "1.0", - "__shape__": "(128,)", - "__wd_mult__": "1.0" - }, - "inputs": [] - }, - { - "op": "FullyConnected", - "name": "sequential0_dense0_fwd", - "attr": {"num_hidden": "128"}, - "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]] - }, - { - "op": "Activation", - "name": "sequential0_dense0_relu_fwd", - "attr": {"act_type": "relu"}, - "inputs": [[3, 0, 0]] - }, - { - "op": "null", - "name": "sequential0_dense1_weight", - "attr": { - "__dtype__": "0", - "__lr_mult__": "1.0", - "__shape__": "(64, 0)", - "__wd_mult__": "1.0" - }, - "inputs": [] - }, - { - "op": "null", - "name": "sequential0_dense1_bias", - "attr": { - "__dtype__": "0", - "__init__": "zeros", - "__lr_mult__": "1.0", - "__shape__": "(64,)", - "__wd_mult__": "1.0" - }, - "inputs": [] - }, - { - "op": "FullyConnected", - "name": "sequential0_dense1_fwd", - "attr": {"num_hidden": "64"}, - "inputs": [[4, 0, 0], [5, 0, 0], [6, 0, 0]] - }, - { - "op": "Activation", - "name": "sequential0_dense1_relu_fwd", - "attr": {"act_type": "relu"}, - "inputs": [[7, 0, 0]] - }, - { - "op": "null", - "name": "sequential0_dense2_weight", - "attr": { - "__dtype__": "0", - "__lr_mult__": "1.0", - "__shape__": "(10, 0)", - "__wd_mult__": "1.0" - }, - "inputs": [] - }, - { - "op": "null", - "name": "sequential0_dense2_bias", - "attr": { - "__dtype__": "0", - "__init__": "zeros", - "__lr_mult__": "1.0", - "__shape__": "(10,)", - "__wd_mult__": "1.0" - }, - "inputs": [] - }, - { - "op": "FullyConnected", - "name": "sequential0_dense2_fwd", - "attr": {"num_hidden": "10"}, - "inputs": [[8, 0, 0], [9, 0, 0], [10, 0, 0]] - } - ], - "arg_nodes": [0, 1, 2, 5, 6, 9, 10], - "node_row_ptr": [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12 - ], - "heads": [[11, 0, 0]], - "attrs": {"mxnet_version": ["int", 1100]} -} \ No newline at end of file diff --git a/examples/cli/host/data/model.params b/examples/cli/host/data/model.params deleted file mode 100644 index 3757d543c8..0000000000 Binary files a/examples/cli/host/data/model.params and /dev/null differ diff --git a/examples/cli/host/run_hosting_example.sh b/examples/cli/host/run_hosting_example.sh deleted file mode 100644 index b6d7e92d4d..0000000000 --- a/examples/cli/host/run_hosting_example.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -sagemaker mxnet host --role-name diff --git a/examples/cli/host/script.py b/examples/cli/host/script.py deleted file mode 100644 index a5a549a04b..0000000000 --- a/examples/cli/host/script.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import print_function - -import json -import mxnet as mx -from mxnet import gluon - - -def model_fn(model_dir): - """Load the gluon model. Called once when hosting service starts. - - Args: - model_dir: The directory where model files are stored. - - Returns: - a model (in this case a Gluon network) - """ - symbol = mx.sym.load("%s/model.json" % model_dir) - outputs = mx.symbol.softmax(data=symbol, name="softmax_label") - inputs = mx.sym.var("data") - param_dict = gluon.ParameterDict("model_") - net = gluon.SymbolBlock(outputs, inputs, param_dict) - net.load_params("%s/model.params" % model_dir, ctx=mx.cpu()) - return net - - -def transform_fn(net, data, input_content_type, output_content_type): - """Transform a request using the Gluon model. Called once per request. - - Args: - net: The Gluon model. - data: The request payload. - input_content_type: The request content type. - output_content_type: The (desired) response content type. - - Returns: - response payload and content type. - """ - # we can use content types to vary input/output handling, but - # here we just assume json for both - parsed = json.loads(data) - nda = mx.nd.array(parsed) - output = net(nda) - prediction = mx.nd.argmax(output, axis=1) - response_body = json.dumps(prediction.asnumpy().tolist()) - return response_body, output_content_type diff --git a/examples/cli/train/data/training/t10k-images-idx3-ubyte.gz b/examples/cli/train/data/training/t10k-images-idx3-ubyte.gz deleted file mode 100644 index 5ace8ea93f..0000000000 Binary files a/examples/cli/train/data/training/t10k-images-idx3-ubyte.gz and /dev/null differ diff --git a/examples/cli/train/data/training/t10k-labels-idx1-ubyte.gz b/examples/cli/train/data/training/t10k-labels-idx1-ubyte.gz deleted file mode 100644 index a7e141541c..0000000000 Binary files a/examples/cli/train/data/training/t10k-labels-idx1-ubyte.gz and /dev/null differ diff --git a/examples/cli/train/data/training/train-images-idx3-ubyte.gz b/examples/cli/train/data/training/train-images-idx3-ubyte.gz deleted file mode 100644 index b50e4b6bcc..0000000000 Binary files a/examples/cli/train/data/training/train-images-idx3-ubyte.gz and /dev/null differ diff --git a/examples/cli/train/data/training/train-labels-idx1-ubyte.gz b/examples/cli/train/data/training/train-labels-idx1-ubyte.gz deleted file mode 100644 index 707a576bb5..0000000000 Binary files a/examples/cli/train/data/training/train-labels-idx1-ubyte.gz and /dev/null differ diff --git a/examples/cli/train/download_training_data.py b/examples/cli/train/download_training_data.py deleted file mode 100644 index 2bc97d9588..0000000000 --- a/examples/cli/train/download_training_data.py +++ /dev/null @@ -1,10 +0,0 @@ -from mxnet import gluon - - -def download_training_data(): - gluon.data.vision.MNIST("./data/training", train=True) - gluon.data.vision.MNIST("./data/training", train=False) - - -if __name__ == "__main__": - download_training_data() diff --git a/examples/cli/train/hyperparameters.json b/examples/cli/train/hyperparameters.json deleted file mode 100644 index 01c3269250..0000000000 --- a/examples/cli/train/hyperparameters.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "batch_size": 100, - "epochs": 10, - "learning_rate": 0.1, - "momentum": 0.9, - "log_interval": 100 -} diff --git a/examples/cli/train/run_training_example.sh b/examples/cli/train/run_training_example.sh deleted file mode 100755 index 10176920d4..0000000000 --- a/examples/cli/train/run_training_example.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -python ./download_training_data.py -sagemaker mxnet train --role-name diff --git a/examples/cli/train/script.py b/examples/cli/train/script.py deleted file mode 100644 index d97a364f85..0000000000 --- a/examples/cli/train/script.py +++ /dev/null @@ -1,158 +0,0 @@ -import logging -import time - -import mxnet as mx -import numpy as np -from mxnet import gluon, autograd -from mxnet.gluon import nn - -logger = logging.getLogger(__name__) - - -def train(channel_input_dirs, hyperparameters, **kwargs): - # SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to - # the current container environment, but here we just use simple cpu context. - """ - Args: - channel_input_dirs: - hyperparameters: - **kwargs: - """ - ctx = mx.cpu() - - # retrieve the hyperparameters we set in notebook (with some defaults) - batch_size = hyperparameters.get("batch_size", 100) - epochs = hyperparameters.get("epochs", 10) - learning_rate = hyperparameters.get("learning_rate", 0.1) - momentum = hyperparameters.get("momentum", 0.9) - log_interval = hyperparameters.get("log_interval", 100) - - training_data = channel_input_dirs["training"] - - # load training and validation data - # we use the gluon.data.vision.MNIST class because of its built in mnist pre-processing logic, - # but point it at the location where SageMaker placed the data files, so it doesn't download them again. - train_data = get_train_data(training_data, batch_size) - val_data = get_val_data(training_data, batch_size) - - # define the network - net = define_network() - - # Collect all parameters from net and its children, then initialize them. - net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) - # Trainer is for updating parameters with gradient. - trainer = gluon.Trainer( - net.collect_params(), "sgd", {"learning_rate": learning_rate, "momentum": momentum} - ) - metric = mx.metric.Accuracy() - loss = gluon.loss.SoftmaxCrossEntropyLoss() - - for epoch in range(epochs): - # reset data iterator and metric at begining of epoch. - metric.reset() - btic = time.time() - for i, (data, label) in enumerate(train_data): - # Copy data to ctx if necessary - data = data.as_in_context(ctx) - label = label.as_in_context(ctx) - # Start recording computation graph with record() section. - # Recorded graphs can then be differentiated with backward. - with autograd.record(): - output = net(data) - L = loss(output, label) - L.backward() - # take a gradient step with batch_size equal to data.shape[0] - trainer.step(data.shape[0]) - # update metric at last. - metric.update([label], [output]) - - if i % log_interval == 0 and i > 0: - name, acc = metric.get() - logger.info( - "[Epoch %d Batch %d] Training: %s=%f, %f samples/s" - % (epoch, i, name, acc, batch_size / (time.time() - btic)) - ) - - btic = time.time() - - name, acc = metric.get() - logger.info("[Epoch %d] Training: %s=%f" % (epoch, name, acc)) - - name, val_acc = test(ctx, net, val_data) - logger.info("[Epoch %d] Validation: %s=%f" % (epoch, name, val_acc)) - - return net - - -def save(net, model_dir): - # save the model - """ - Args: - net: - model_dir: - """ - y = net(mx.sym.var("data")) - y.save("%s/model.json" % model_dir) - net.collect_params().save("%s/model.params" % model_dir) - - -def define_network(): - net = nn.Sequential() - with net.name_scope(): - net.add(nn.Dense(128, activation="relu")) - net.add(nn.Dense(64, activation="relu")) - net.add(nn.Dense(10)) - return net - - -def input_transformer(data, label): - """ - Args: - data: - label: - """ - data = data.reshape((-1,)).astype(np.float32) / 255 - return data, label - - -def get_train_data(data_dir, batch_size): - """ - Args: - data_dir: - batch_size: - """ - return gluon.data.DataLoader( - gluon.data.vision.MNIST(data_dir, train=True, transform=input_transformer), - batch_size=batch_size, - shuffle=True, - last_batch="discard", - ) - - -def get_val_data(data_dir, batch_size): - """ - Args: - data_dir: - batch_size: - """ - return gluon.data.DataLoader( - gluon.data.vision.MNIST(data_dir, train=False, transform=input_transformer), - batch_size=batch_size, - shuffle=False, - ) - - -def test(ctx, net, val_data): - """ - Args: - ctx: - net: - val_data: - """ - metric = mx.metric.Accuracy() - for data, label in val_data: - data = data.as_in_context(ctx) - label = label.as_in_context(ctx) - output = net(data) - metric.update([label], [output]) - return metric.get() diff --git a/setup.py b/setup.py index 7cc1c766de..bb007ff82c 100644 --- a/setup.py +++ b/setup.py @@ -103,7 +103,6 @@ def read_version(): extras_require=extras, entry_points={ "console_scripts": [ - "sagemaker=sagemaker.cli.main:main", "sagemaker-upgrade-v2=sagemaker.cli.compatibility.v2.sagemaker_upgrade_v2:main", ] }, diff --git a/src/sagemaker/cli/__init__.py b/src/sagemaker/cli/__init__.py deleted file mode 100644 index 77f4efcba5..0000000000 --- a/src/sagemaker/cli/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. diff --git a/src/sagemaker/cli/common.py b/src/sagemaker/cli/common.py deleted file mode 100644 index 7256937bc7..0000000000 --- a/src/sagemaker/cli/common.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Placeholder docstring""" -from __future__ import absolute_import - -import json -import logging -import os -import shutil -import tarfile -import tempfile - -import sagemaker - -logger = logging.getLogger(__name__) - - -class HostCommand(object): - """Placeholder docstring""" - - def __init__(self, args): - """ - Args: - args: - """ - self.endpoint_name = args.job_name - self.bucket = args.bucket_name # may be None - self.role_name = args.role_name - self.python = args.python - self.data = args.data - self.script = args.script - self.instance_type = args.instance_type - self.instance_count = args.instance_count - self.environment = dict((kv.split("=") for kv in args.env)) - - self.session = sagemaker.Session() - - def upload_model(self): - """Placeholder docstring""" - prefix = "{}/model".format(self.endpoint_name) - - archive = self.create_model_archive(self.data) - model_uri = self.session.upload_data(path=archive, bucket=self.bucket, key_prefix=prefix) - shutil.rmtree(os.path.dirname(archive)) - - return model_uri - - @staticmethod - def create_model_archive(src): - """ - Args: - src: - """ - if os.path.isdir(src): - arcname = "." - else: - arcname = os.path.basename(src) - - tmp = tempfile.mkdtemp() - archive = os.path.join(tmp, "model.tar.gz") - - with tarfile.open(archive, mode="w:gz") as t: - t.add(src, arcname=arcname) - return archive - - def create_model(self, model_url): - """ - Args: - model_url: - """ - raise NotImplementedError # subclasses must override - - def start(self): - """Placeholder docstring""" - model_url = self.upload_model() - model = self.create_model(model_url) - predictor = model.deploy( - initial_instance_count=self.instance_count, instance_type=self.instance_type - ) - - return predictor - - -class TrainCommand(object): - """Placeholder docstring""" - - def __init__(self, args): - """ - Args: - args: - """ - self.job_name = args.job_name - self.bucket = args.bucket_name # may be None - self.role_name = args.role_name - self.python = args.python - self.data = args.data - self.script = args.script - self.instance_type = args.instance_type - self.instance_count = args.instance_count - self.hyperparameters = self.load_hyperparameters(args.hyperparameters) - - self.session = sagemaker.Session() - - @staticmethod - def load_hyperparameters(src): - """ - Args: - src: - """ - hp = {} - if src and os.path.exists(src): - with open(src, "r") as f: - hp = json.load(f) - return hp - - def upload_training_data(self): - """Placeholder docstring""" - prefix = "{}/data".format(self.job_name) - data_url = self.session.upload_data(path=self.data, bucket=self.bucket, key_prefix=prefix) - return data_url - - def create_estimator(self): - """Placeholder docstring""" - raise NotImplementedError # subclasses must override - - def start(self): - """Placeholder docstring""" - data_url = self.upload_training_data() - estimator = self.create_estimator() - estimator.fit(data_url) - logger.debug("code location: %s", estimator.uploaded_code.s3_prefix) - logger.debug( - "model location: %s%s/output/model.tar.gz", - estimator.output_path, - estimator._current_job_name, - ) diff --git a/src/sagemaker/cli/main.py b/src/sagemaker/cli/main.py deleted file mode 100644 index 5328596ba6..0000000000 --- a/src/sagemaker/cli/main.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Placeholder docstring""" -from __future__ import absolute_import - -import argparse -import logging -import sys - -import sagemaker -import sagemaker.cli.mxnet -import sagemaker.cli.tensorflow - -logger = logging.getLogger(__name__) - -DEFAULT_LOG_LEVEL = "info" -DEFAULT_BOTOCORE_LOG_LEVEL = "warning" - - -def parse_arguments(args): - """ - Args: - args: - """ - parser = argparse.ArgumentParser( - description="Launch SageMaker training jobs or hosting endpoints" - ) - parser.set_defaults(func=lambda x: parser.print_usage()) - - # common args for training/hosting/all frameworks - common_parser = argparse.ArgumentParser(add_help=False) - common_parser.add_argument( - "--role-name", help="SageMaker execution role name", type=str, required=True - ) - common_parser.add_argument( - "--data", help="path to training data or model files", type=str, default="./data" - ) - common_parser.add_argument("--script", help="path to script", type=str, default="./script.py") - common_parser.add_argument("--job-name", help="job or endpoint name", type=str, default=None) - common_parser.add_argument( - "--bucket-name", - help="S3 bucket for training/model data and script files", - type=str, - default=None, - ) - common_parser.add_argument("--python", help="python version", type=str, default="py2") - - instance_group = common_parser.add_argument_group("instance settings") - instance_group.add_argument( - "--instance-type", type=str, help="instance type", default="ml.m4.xlarge" - ) - instance_group.add_argument("--instance-count", type=int, help="instance count", default=1) - - # common training args - common_train_parser = argparse.ArgumentParser(add_help=False) - common_train_parser.add_argument( - "--hyperparameters", - help="path to training hyperparameters file", - type=str, - default="./hyperparameters.json", - ) - - # common hosting args - common_host_parser = argparse.ArgumentParser(add_help=False) - common_host_parser.add_argument( - "--env", help="hosting environment variable(s)", type=str, nargs="*", default=[] - ) - - subparsers = parser.add_subparsers() - - # framework/algo subcommands - mxnet_parser = subparsers.add_parser("mxnet", help="use MXNet", parents=[]) - mxnet_subparsers = mxnet_parser.add_subparsers() - mxnet_train_parser = mxnet_subparsers.add_parser( - "train", help="start a training job", parents=[common_parser, common_train_parser] - ) - mxnet_train_parser.set_defaults(func=sagemaker.cli.mxnet.train) - - mxnet_host_parser = mxnet_subparsers.add_parser( - "host", help="start a hosting endpoint", parents=[common_parser, common_host_parser] - ) - mxnet_host_parser.set_defaults(func=sagemaker.cli.mxnet.host) - - tensorflow_parser = subparsers.add_parser("tensorflow", help="use TensorFlow", parents=[]) - tensorflow_subparsers = tensorflow_parser.add_subparsers() - tensorflow_train_parser = tensorflow_subparsers.add_parser( - "train", help="start a training job", parents=[common_parser, common_train_parser] - ) - tensorflow_train_parser.add_argument( - "--training-steps", - help="number of training steps (tensorflow only)", - type=int, - default=None, - ) - tensorflow_train_parser.add_argument( - "--evaluation-steps", - help="number of evaluation steps (tensorflow only)", - type=int, - default=None, - ) - tensorflow_train_parser.set_defaults(func=sagemaker.cli.tensorflow.train) - - tensorflow_host_parser = tensorflow_subparsers.add_parser( - "host", help="start a hosting endpoint", parents=[common_parser, common_host_parser] - ) - tensorflow_host_parser.set_defaults(func=sagemaker.cli.tensorflow.host) - - log_group = parser.add_argument_group("optional log settings") - log_group.add_argument( - "--log-level", help="log level for this command", type=str, default=DEFAULT_LOG_LEVEL - ) - log_group.add_argument( - "--botocore-log-level", - help="log level for botocore", - type=str, - default=DEFAULT_BOTOCORE_LOG_LEVEL, - ) - - return parser.parse_args(args) - - -def configure_logging(args): - """ - Args: - args: - """ - log_format = "%(asctime)s %(levelname)s %(name)s: %(message)s" - log_level = logging.getLevelName(args.log_level.upper()) - logging.basicConfig(format=log_format, level=log_level) - logging.getLogger("botocore").setLevel(args.botocore_log_level.upper()) - - -def main(): - """Placeholder docstring""" - args = parse_arguments(sys.argv[1:]) - configure_logging(args) - logger.debug("args: %s", args) - args.func(args) - - -if __name__ == "__main__": - main() diff --git a/src/sagemaker/cli/mxnet.py b/src/sagemaker/cli/mxnet.py deleted file mode 100644 index e7548a4178..0000000000 --- a/src/sagemaker/cli/mxnet.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Placeholder docstring""" -from __future__ import absolute_import - -from sagemaker.cli.common import HostCommand, TrainCommand - -MXNET_VERSION = "1.2" - - -def train(args): - """ - Args: - args: - """ - MXNetTrainCommand(args).start() - - -def host(args): - """ - Args: - args: - """ - MXNetHostCommand(args).start() - - -class MXNetTrainCommand(TrainCommand): - """Placeholder docstring""" - - def create_estimator(self): - """Placeholder docstring""" - from sagemaker.mxnet.estimator import MXNet - - return MXNet( - entry_point=self.script, - framework_version=MXNET_VERSION, - py_version=self.python, - role=self.role_name, - base_job_name=self.job_name, - instance_count=self.instance_count, - instance_type=self.instance_type, - hyperparameters=self.hyperparameters, - ) - - -class MXNetHostCommand(HostCommand): - """Placeholder docstring""" - - def create_model(self, model_url): - """ - Args: - model_url: - """ - from sagemaker.mxnet.model import MXNetModel - - return MXNetModel( - model_data=model_url, - role=self.role_name, - entry_point=self.script, - framework_version=MXNET_VERSION, - py_version=self.python, - name=self.endpoint_name, - env=self.environment, - ) diff --git a/src/sagemaker/cli/tensorflow.py b/src/sagemaker/cli/tensorflow.py deleted file mode 100644 index 008e08ef53..0000000000 --- a/src/sagemaker/cli/tensorflow.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Placeholder docstring""" -from __future__ import absolute_import - -from sagemaker.cli.common import HostCommand, TrainCommand - - -def train(args): - """ - Args: - args: - """ - TensorFlowTrainCommand(args).start() - - -def host(args): - """ - Args: - args: - """ - TensorFlowHostCommand(args).start() - - -class TensorFlowTrainCommand(TrainCommand): - """Placeholder docstring""" - - def __init__(self, args): - """ - Args: - args: - """ - super(TensorFlowTrainCommand, self).__init__(args) - self.training_steps = args.training_steps - self.evaluation_steps = args.evaluation_steps - - def create_estimator(self): - from sagemaker.tensorflow import TensorFlow - - return TensorFlow( - training_steps=self.training_steps, - evaluation_steps=self.evaluation_steps, - py_version=self.python, - entry_point=self.script, - role=self.role_name, - base_job_name=self.job_name, - instance_count=self.instance_count, - instance_type=self.instance_type, - hyperparameters=self.hyperparameters, - ) - - -class TensorFlowHostCommand(HostCommand): - """Placeholder docstring""" - - def create_model(self, model_url): - """ - Args: - model_url: - """ - from sagemaker.tensorflow.model import TensorFlowModel - - return TensorFlowModel( - model_data=model_url, - role=self.role_name, - entry_point=self.script, - name=self.endpoint_name, - env=self.environment, - ) diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py deleted file mode 100644 index 28ae7e453a..0000000000 --- a/tests/unit/test_cli.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -from __future__ import absolute_import - -import pytest -import sagemaker.cli.main as cli -from mock import patch - -COMMON_ARGS = ( - "--role-name myrole --data mydata --script myscript --job-name myjob --bucket-name mybucket " - + "--python py3 --instance-type myinstance --instance-count 2" -) - -TRAIN_ARGS = "--hyperparameters myhyperparameters.json" - -LOG_ARGS = "--log-level debug --botocore-log-level debug" - -HOST_ARGS = "--env ENV1=env1 ENV2=env2" - - -def assert_common_defaults(args): - assert args.data == "./data" - assert args.script == "./script.py" - assert args.job_name is None - assert args.bucket_name is None - assert args.python == "py2" - assert args.instance_type == "ml.m4.xlarge" - assert args.instance_count == 1 - assert args.log_level == "info" - assert args.botocore_log_level == "warning" - - -def assert_common_non_defaults(args): - assert args.data == "mydata" - assert args.script == "myscript" - assert args.job_name == "myjob" - assert args.bucket_name == "mybucket" - assert args.role_name == "myrole" - assert args.python == "py3" - assert args.instance_type == "myinstance" - assert args.instance_count == 2 - assert args.log_level == "debug" - assert args.botocore_log_level == "debug" - - -def assert_train_defaults(args): - assert args.hyperparameters == "./hyperparameters.json" - - -def assert_train_non_defaults(args): - assert args.hyperparameters == "myhyperparameters.json" - - -def assert_host_defaults(args): - assert args.env == [] - - -def assert_host_non_defaults(args): - assert args.env == ["ENV1=env1", "ENV2=env2"] - - -def test_args_mxnet_train_defaults(): - args = cli.parse_arguments("mxnet train --role-name role".split()) - assert_common_defaults(args) - assert_train_defaults(args) - assert args.func.__module__ == "sagemaker.cli.mxnet" - assert args.func.__name__ == "train" - - -def test_args_mxnet_train_non_defaults(): - args = cli.parse_arguments( - "{} mxnet train --role-name role {} {}".format(LOG_ARGS, COMMON_ARGS, TRAIN_ARGS).split() - ) - assert_common_non_defaults(args) - assert_train_non_defaults(args) - assert args.func.__module__ == "sagemaker.cli.mxnet" - assert args.func.__name__ == "train" - - -def test_args_mxnet_host_defaults(): - args = cli.parse_arguments("mxnet host --role-name role".split()) - assert_common_defaults(args) - assert_host_defaults(args) - assert args.func.__module__ == "sagemaker.cli.mxnet" - assert args.func.__name__ == "host" - - -def test_args_mxnet_host_non_defaults(): - args = cli.parse_arguments( - "{} mxnet host --role-name role {} {}".format(LOG_ARGS, COMMON_ARGS, HOST_ARGS).split() - ) - assert_common_non_defaults(args) - assert_host_non_defaults(args) - assert args.func.__module__ == "sagemaker.cli.mxnet" - assert args.func.__name__ == "host" - - -def test_args_tensorflow_train_defaults(): - args = cli.parse_arguments("tensorflow train --role-name role".split()) - assert_common_defaults(args) - assert_train_defaults(args) - assert args.training_steps is None - assert args.evaluation_steps is None - assert args.func.__module__ == "sagemaker.cli.tensorflow" - assert args.func.__name__ == "train" - - -def test_args_tensorflow_train_non_defaults(): - args = cli.parse_arguments( - "{} tensorflow train --role-name role --training-steps 10 --evaluation-steps 5 {} {}".format( - LOG_ARGS, COMMON_ARGS, TRAIN_ARGS - ).split() - ) - assert_common_non_defaults(args) - assert_train_non_defaults(args) - assert args.training_steps == 10 - assert args.evaluation_steps == 5 - assert args.func.__module__ == "sagemaker.cli.tensorflow" - assert args.func.__name__ == "train" - - -def test_args_tensorflow_host_defaults(): - args = cli.parse_arguments("tensorflow host --role-name role".split()) - assert_common_defaults(args) - assert_host_defaults(args) - assert args.func.__module__ == "sagemaker.cli.tensorflow" - assert args.func.__name__ == "host" - - -def test_args_tensorflow_host_non_defaults(): - args = cli.parse_arguments( - "{} tensorflow host --role-name role {} {}".format(LOG_ARGS, COMMON_ARGS, HOST_ARGS).split() - ) - assert_common_non_defaults(args) - assert_host_non_defaults(args) - assert args.func.__module__ == "sagemaker.cli.tensorflow" - assert args.func.__name__ == "host" - - -def test_args_invalid_framework(): - with pytest.raises(SystemExit): - cli.parse_arguments("fakeframework train --role-name role".split()) - - -def test_args_invalid_subcommand(): - with pytest.raises(SystemExit): - cli.parse_arguments("mxnet drain".split()) - - -def test_args_invalid_args(): - with pytest.raises(SystemExit): - cli.parse_arguments("tensorflow train --role-name role --notdata foo".split()) - - -def test_args_invalid_mxnet_python(): - with pytest.raises(SystemExit): - cli.parse_arguments("mxnet train --role-name role nython py2".split()) - - -def test_args_invalid_host_args_in_train(): - with pytest.raises(SystemExit): - cli.parse_arguments("mxnet train --role-name role --env FOO=bar".split()) - - -def test_args_invalid_train_args_in_host(): - with pytest.raises(SystemExit): - cli.parse_arguments("tensorflow host --role-name role --hyperparameters foo.json".split()) - - -@patch("sagemaker.mxnet.estimator.MXNet") -@patch("sagemaker.Session") -def test_mxnet_train(session, estimator): - args = cli.parse_arguments("mxnet train --role-name role".split()) - args.func(args) - session.return_value.upload_data.assert_called() - estimator.assert_called() - estimator.return_value.fit.assert_called() - - -@patch("sagemaker.mxnet.model.MXNetModel") -@patch("sagemaker.cli.common.HostCommand.upload_model") -@patch("sagemaker.Session") -def test_mxnet_host(session, upload_model, model): - args = cli.parse_arguments("mxnet host --role-name role".split()) - args.func(args) - session.assert_called() - upload_model.assert_called() - model.assert_called() - model.return_value.deploy.assert_called()