Skip to content

Commit 80c4e92

Browse files
author
Jonathan Esterhazy
committed
address PR comments
1 parent b2f7e9e commit 80c4e92

File tree

7 files changed

+347
-127
lines changed

7 files changed

+347
-127
lines changed

src/sagemaker/cli/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sagemaker.cli.host import host
2-
from sagemaker.cli.train import train
1+
import sagemaker.cli.mxnet
2+
import sagemaker.cli.tensorflow
33

4-
__all__ = [host, train]
4+
__all__ = [mxnet, tensorflow]

src/sagemaker/cli/host.py renamed to src/sagemaker/cli/common.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
11
from __future__ import absolute_import
22

3+
import json
4+
import logging
35
import os
46
import shutil
57
import tarfile
68
import tempfile
79

810
import sagemaker
911

12+
logger = logging.getLogger(__name__)
1013

11-
def host(args):
12-
return HostingCommand(args).start()
1314

14-
15-
class HostingCommand(object):
15+
class HostCommand(object):
1616
def __init__(self, args):
1717
self.endpoint_name = args.job_name
1818
self.bucket = args.bucket_name # may be None
1919
self.role_name = args.role_name
20+
self.python = args.python
2021
self.data = args.data
2122
self.script = args.script
22-
self.python = args.python
2323
self.instance_type = args.instance_type
2424
self.instance_count = args.instance_count
25-
self.framework = 'tensorflow' if args.tf else 'mxnet' if args.mx else 'undefined'
2625
self.environment = {k: v for k, v in (kv.split('=') for kv in args.env)}
2726

2827
self.session = sagemaker.Session()
@@ -51,16 +50,7 @@ def create_model_archive(src):
5150
return archive
5251

5352
def create_model(self, model_url):
54-
if self.framework == 'tensorflow':
55-
from sagemaker.tensorflow.model import TensorFlowModel
56-
return TensorFlowModel(model_data=model_url, role=self.role_name, entry_point=self.script,
57-
name=self.endpoint_name, env=self.environment)
58-
elif self.framework == 'mxnet':
59-
from sagemaker.mxnet.model import MXNetModel
60-
return MXNetModel(model_data=model_url, role=self.role_name, entry_point=self.script,
61-
py_version=self.python, name=self.endpoint_name, env=self.environment)
62-
else:
63-
raise ValueError('unsupported framework value: {}'.format(self.framework))
53+
raise NotImplementedError # subclasses must override
6454

6555
def start(self):
6656
model_url = self.upload_model()
@@ -69,3 +59,42 @@ def start(self):
6959
instance_type=self.instance_type)
7060

7161
return predictor
62+
63+
64+
class TrainCommand(object):
65+
def __init__(self, args):
66+
self.job_name = args.job_name
67+
self.bucket = args.bucket_name # may be None
68+
self.role_name = args.role_name
69+
self.python = args.python
70+
self.data = args.data
71+
self.script = args.script
72+
self.instance_type = args.instance_type
73+
self.instance_count = args.instance_count
74+
self.hyperparameters = self.load_hyperparameters(args.hyperparameters)
75+
76+
self.session = sagemaker.Session()
77+
78+
@staticmethod
79+
def load_hyperparameters(src):
80+
hp = {}
81+
if src and os.path.exists(src):
82+
with open(src, 'r') as f:
83+
hp = json.load(f)
84+
return hp
85+
86+
def upload_training_data(self):
87+
prefix = '{}/data'.format(self.job_name)
88+
data_url = self.session.upload_data(path=self.data, bucket=self.bucket, key_prefix=prefix)
89+
return data_url
90+
91+
def create_estimator(self):
92+
raise NotImplementedError # subclasses must override
93+
94+
def start(self):
95+
data_url = self.upload_training_data()
96+
estimator = self.create_estimator()
97+
estimator.fit(data_url)
98+
logger.debug('code location: {}'.format(estimator.uploaded_code.s3_prefix))
99+
logger.debug('model location: {}{}/output/model.tar.gz'.format(estimator.output_path,
100+
estimator._current_job_name))

src/sagemaker/cli/main.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,53 +10,66 @@
1010

1111

1212
def parse_arguments(args):
13-
# common arguments
14-
common_parser = argparse.ArgumentParser(add_help=False)
15-
16-
# image-related settings
17-
image_mtx = common_parser.add_mutually_exclusive_group(required=True)
18-
image_mtx.add_argument('--tf', help='use a TensorFlow container image', action='store_true')
19-
image_mtx.add_argument('--mx', help='use an MXNet container image', action='store_true')
13+
parser = argparse.ArgumentParser(description='Launch SageMaker training jobs or hosting endpoints')
14+
parser.set_defaults(func=lambda x: parser.print_usage())
2015

21-
# path to data and script files
16+
# common args for training/hosting/all frameworks
17+
common_parser = argparse.ArgumentParser(add_help=False)
2218
common_parser.add_argument('--data', help='path to training data or model files', type=str, default='./data')
2319
common_parser.add_argument('--script', help='path to script', type=str, default='./script.py')
2420
common_parser.add_argument('--job-name', help='job or endpoint name', type=str, default=None)
2521
common_parser.add_argument('--bucket-name', help='S3 bucket', type=str, default=None)
2622
common_parser.add_argument('--role-name', help='SageMaker execution role name', type=str,
2723
default='AmazonSageMakerFullAccess')
24+
common_parser.add_argument('--python', help='python version', type=str, default='py2')
2825

2926
instance_group = common_parser.add_argument_group('instance settings')
3027
instance_group.add_argument('--instance-type', type=str, help='instance type', default='ml.m4.xlarge')
3128
instance_group.add_argument('--instance-count', type=int, help='instance count', default=1)
3229

33-
image_group = common_parser.add_argument_group('other container image settings')
34-
image_group.add_argument('--python', help='python version (mxnet only)', type=str, default='py2')
35-
36-
parser = argparse.ArgumentParser(description='Launch SageMaker training jobs or hosting endpoints')
37-
parser.set_defaults(func=lambda x: parser.print_usage())
38-
39-
log_group = parser.add_argument_group('log settings')
30+
log_group = common_parser.add_argument_group('optional log settings')
4031
log_group.add_argument('--log-level', help='log level for this command', type=str, default='info')
4132
log_group.add_argument('--botocore-log-level', help='log level for botocore', type=str, default='warning')
4233

34+
# common training args
35+
common_train_parser = argparse.ArgumentParser(add_help=False)
36+
common_train_parser.add_argument('--hyperparameters', help='path to training hyperparameters file',
37+
type=str, default='./hyperparameters.json')
38+
39+
# common hosting args
40+
common_host_parser = argparse.ArgumentParser(add_help=False)
41+
common_host_parser.add_argument('--env', help='hosting environment variable(s)', type=str, nargs='*', default=[])
42+
4343
subparsers = parser.add_subparsers()
44-
train_parser = subparsers.add_parser('train', help='start a training job', parents=[common_parser])
45-
train_group = train_parser.add_argument_group('training settings')
46-
train_group.add_argument('--hyperparameters', help='path to training hyperparameters file',
47-
type=str, default='./hyperparameters.json')
48-
train_group.add_argument('--training-steps',
49-
help='number of training steps (tensorflow only)', type=int, default=None)
50-
train_group.add_argument('--evaluation-steps',
51-
help='number of evaluation steps (tensorflow only)', type=int, default=None)
52-
train_parser.set_defaults(mode='train')
53-
train_parser.set_defaults(func=sagemaker.cli.train)
54-
55-
host_parser = subparsers.add_parser('host', help='start a hosting endpoint', parents=[common_parser])
56-
host_group = host_parser.add_argument_group('hosting settings')
57-
host_group.add_argument('--env', help='hosting environment variable(s)', type=str, nargs='*', default=[])
58-
train_parser.set_defaults(mode='host')
59-
host_parser.set_defaults(func=sagemaker.cli.host)
44+
45+
# framework/algo subcommands
46+
mxnet_parser = subparsers.add_parser('mxnet', help='use MXNet', parents=[])
47+
mxnet_subparsers = mxnet_parser.add_subparsers()
48+
mxnet_train_parser = mxnet_subparsers.add_parser('train',
49+
help='start a training job',
50+
parents=[common_parser, common_train_parser])
51+
mxnet_train_parser.set_defaults(func=sagemaker.cli.mxnet.train)
52+
53+
mxnet_host_parser = mxnet_subparsers.add_parser('host',
54+
help='start a hosting endpoint',
55+
parents=[common_parser, common_host_parser])
56+
mxnet_host_parser.set_defaults(func=sagemaker.cli.mxnet.host)
57+
58+
tensorflow_parser = subparsers.add_parser('tensorflow', help='use TensorFlow', parents=[])
59+
tensorflow_subparsers = tensorflow_parser.add_subparsers()
60+
tensorflow_train_parser = tensorflow_subparsers.add_parser('train',
61+
help='start a training job',
62+
parents=[common_parser, common_train_parser])
63+
tensorflow_train_parser.add_argument('--training-steps',
64+
help='number of training steps (tensorflow only)', type=int, default=None)
65+
tensorflow_train_parser.add_argument('--evaluation-steps',
66+
help='number of evaluation steps (tensorflow only)', type=int, default=None)
67+
tensorflow_train_parser.set_defaults(func=sagemaker.cli.tensorflow.train)
68+
69+
tensorflow_host_parser = tensorflow_subparsers.add_parser('host',
70+
help='start a hosting endpoint',
71+
parents=[common_parser, common_host_parser])
72+
tensorflow_host_parser.set_defaults(func=sagemaker.cli.tensorflow.host)
6073

6174
return parser.parse_args(args)
6275

src/sagemaker/cli/mxnet.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from sagemaker.cli.common import HostCommand, TrainCommand
2+
3+
4+
def train(args):
5+
MXNetTrainCommand(args).start()
6+
7+
8+
def host(args):
9+
MXNetHostCommand(args).start()
10+
11+
12+
class MXNetTrainCommand(TrainCommand):
13+
def __init__(self, args):
14+
super(MXNetTrainCommand, self).__init__(args)
15+
16+
def create_estimator(self):
17+
from sagemaker.mxnet.estimator import MXNet
18+
return MXNet(self.script,
19+
role=self.role_name,
20+
base_job_name=self.job_name,
21+
train_instance_count=self.instance_count,
22+
train_instance_type=self.instance_type,
23+
hyperparameters=self.hyperparameters,
24+
py_version=self.python)
25+
26+
27+
class MXNetHostCommand(HostCommand):
28+
def __init__(self, args):
29+
super(MXNetHostCommand, self).__init__(args)
30+
31+
def create_model(self, model_url):
32+
from sagemaker.mxnet.model import MXNetModel
33+
return MXNetModel(model_data=model_url, role=self.role_name, entry_point=self.script,
34+
py_version=self.python, name=self.endpoint_name, env=self.environment)

src/sagemaker/cli/tensorflow.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from sagemaker.cli.common import HostCommand, TrainCommand
2+
3+
4+
def train(args):
5+
TensorFlowTrainCommand(args).start()
6+
7+
8+
def host(args):
9+
TensorFlowHostCommand(args).start()
10+
11+
12+
class TensorFlowTrainCommand(TrainCommand):
13+
def __init__(self, args):
14+
super(TensorFlowTrainCommand, self).__init__(args)
15+
self.training_steps = args.training_steps
16+
self.evaluation_steps = args.evaluation_steps
17+
18+
def create_estimator(self):
19+
from sagemaker.tensorflow import TensorFlow
20+
return TensorFlow(training_steps=self.training_steps,
21+
evaluation_steps=self.evaluation_steps,
22+
py_version=self.python,
23+
entry_point=self.script,
24+
role=self.role_name,
25+
base_job_name=self.job_name,
26+
train_instance_count=self.instance_count,
27+
train_instance_type=self.instance_type,
28+
hyperparameters=self.hyperparameters)
29+
30+
31+
class TensorFlowHostCommand(HostCommand):
32+
def __init__(self, args):
33+
super(TensorFlowHostCommand, self).__init__(args)
34+
35+
def create_model(self, model_url):
36+
from sagemaker.tensorflow.model import TensorFlowModel
37+
return TensorFlowModel(model_data=model_url, role=self.role_name, entry_point=self.script,
38+
py_version=self.python, name=self.endpoint_name, env=self.environment)

src/sagemaker/cli/train.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

0 commit comments

Comments
 (0)