Skip to content

Commit ee741a5

Browse files
author
Jonathan Esterhazy
committed
remove default role name
1 parent 3dd7a14 commit ee741a5

File tree

2 files changed

+17
-19
lines changed

2 files changed

+17
-19
lines changed

src/sagemaker/cli/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@ def parse_arguments(args):
1717

1818
# common args for training/hosting/all frameworks
1919
common_parser = argparse.ArgumentParser(add_help=False)
20+
common_parser.add_argument('--role-name', help='SageMaker execution role name', type=str, required=True)
2021
common_parser.add_argument('--data', help='path to training data or model files', type=str, default='./data')
2122
common_parser.add_argument('--script', help='path to script', type=str, default='./script.py')
2223
common_parser.add_argument('--job-name', help='job or endpoint name', type=str, default=None)
2324
common_parser.add_argument('--bucket-name', help='S3 bucket for training/model data and script files',
2425
type=str, default=None)
25-
common_parser.add_argument('--role-name', help='SageMaker execution role name', type=str,
26-
default='AmazonSageMakerFullAccess')
2726
common_parser.add_argument('--python', help='python version', type=str, default='py2')
2827

2928
instance_group = common_parser.add_argument_group('instance settings')

tests/unit/test_cli.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sagemaker.cli.main as cli
33
from mock import patch
44

5-
COMMON_ARGS = '--data mydata --script myscript --job-name myjob --bucket-name mybucket --role-name myrole ' + \
5+
COMMON_ARGS = '--role-name myrole --data mydata --script myscript --job-name myjob --bucket-name mybucket ' + \
66
'--python py3 --instance-type myinstance --instance-count 2'
77

88
TRAIN_ARGS = '--hyperparameters myhyperparameters.json'
@@ -17,7 +17,6 @@ def assert_common_defaults(args):
1717
assert args.script == './script.py'
1818
assert args.job_name is None
1919
assert args.bucket_name is None
20-
assert args.role_name is 'AmazonSageMakerFullAccess'
2120
assert args.python == 'py2'
2221
assert args.instance_type == 'ml.m4.xlarge'
2322
assert args.instance_count == 1
@@ -55,15 +54,15 @@ def assert_host_non_defaults(args):
5554

5655

5756
def test_args_mxnet_train_defaults():
58-
args = cli.parse_arguments('mxnet train'.split())
57+
args = cli.parse_arguments('mxnet train --role-name role'.split())
5958
assert_common_defaults(args)
6059
assert_train_defaults(args)
6160
assert args.func.__module__ == 'sagemaker.cli.mxnet'
6261
assert args.func.__name__ == 'train'
6362

6463

6564
def test_args_mxnet_train_non_defaults():
66-
args = cli.parse_arguments('mxnet train {} {} {}'
65+
args = cli.parse_arguments('mxnet train --role-name role {} {} {}'
6766
.format(COMMON_ARGS, TRAIN_ARGS, LOG_ARGS)
6867
.split())
6968
assert_common_non_defaults(args)
@@ -73,15 +72,15 @@ def test_args_mxnet_train_non_defaults():
7372

7473

7574
def test_args_mxnet_host_defaults():
76-
args = cli.parse_arguments('mxnet host'.split())
75+
args = cli.parse_arguments('mxnet host --role-name role'.split())
7776
assert_common_defaults(args)
7877
assert_host_defaults(args)
7978
assert args.func.__module__ == 'sagemaker.cli.mxnet'
8079
assert args.func.__name__ == 'host'
8180

8281

8382
def test_args_mxnet_host_non_defaults():
84-
args = cli.parse_arguments('mxnet host {} {} {}'
83+
args = cli.parse_arguments('mxnet host --role-name role {} {} {}'
8584
.format(COMMON_ARGS, HOST_ARGS, LOG_ARGS)
8685
.split())
8786
assert_common_non_defaults(args)
@@ -91,7 +90,7 @@ def test_args_mxnet_host_non_defaults():
9190

9291

9392
def test_args_tensorflow_train_defaults():
94-
args = cli.parse_arguments('tensorflow train'.split())
93+
args = cli.parse_arguments('tensorflow train --role-name role'.split())
9594
assert_common_defaults(args)
9695
assert_train_defaults(args)
9796
assert args.training_steps is None
@@ -101,7 +100,7 @@ def test_args_tensorflow_train_defaults():
101100

102101

103102
def test_args_tensorflow_train_non_defaults():
104-
args = cli.parse_arguments('tensorflow train --training-steps 10 --evaluation-steps 5 {} {} {}'
103+
args = cli.parse_arguments('tensorflow train --role-name role --training-steps 10 --evaluation-steps 5 {} {} {}'
105104
.format(COMMON_ARGS, TRAIN_ARGS, LOG_ARGS)
106105
.split())
107106
assert_common_non_defaults(args)
@@ -113,15 +112,15 @@ def test_args_tensorflow_train_non_defaults():
113112

114113

115114
def test_args_tensorflow_host_defaults():
116-
args = cli.parse_arguments('tensorflow host'.split())
115+
args = cli.parse_arguments('tensorflow host --role-name role'.split())
117116
assert_common_defaults(args)
118117
assert_host_defaults(args)
119118
assert args.func.__module__ == 'sagemaker.cli.tensorflow'
120119
assert args.func.__name__ == 'host'
121120

122121

123122
def test_args_tensorflow_host_non_defaults():
124-
args = cli.parse_arguments('tensorflow host {} {} {}'
123+
args = cli.parse_arguments('tensorflow host --role-name role {} {} {}'
125124
.format(COMMON_ARGS, HOST_ARGS, LOG_ARGS)
126125
.split())
127126
assert_common_non_defaults(args)
@@ -132,7 +131,7 @@ def test_args_tensorflow_host_non_defaults():
132131

133132
def test_args_invalid_framework():
134133
with pytest.raises(SystemExit):
135-
cli.parse_arguments('fakeframework train'.split())
134+
cli.parse_arguments('fakeframework train --role-name role'.split())
136135

137136

138137
def test_args_invalid_subcommand():
@@ -142,28 +141,28 @@ def test_args_invalid_subcommand():
142141

143142
def test_args_invalid_args():
144143
with pytest.raises(SystemExit):
145-
cli.parse_arguments('tensorflow train --notdata foo'.split())
144+
cli.parse_arguments('tensorflow train --role-name role --notdata foo'.split())
146145

147146

148147
def test_args_invalid_mxnet_python():
149148
with pytest.raises(SystemExit):
150-
cli.parse_arguments('mxnet train nython py2'.split())
149+
cli.parse_arguments('mxnet train --role-name role nython py2'.split())
151150

152151

153152
def test_args_invalid_host_args_in_train():
154153
with pytest.raises(SystemExit):
155-
cli.parse_arguments('mxnet train --env FOO=bar'.split())
154+
cli.parse_arguments('mxnet train --role-name role --env FOO=bar'.split())
156155

157156

158157
def test_args_invalid_train_args_in_host():
159158
with pytest.raises(SystemExit):
160-
cli.parse_arguments('tensorflow host --hyperparameters foo.json'.split())
159+
cli.parse_arguments('tensorflow host --role-name role --hyperparameters foo.json'.split())
161160

162161

163162
@patch('sagemaker.mxnet.estimator.MXNet')
164163
@patch('sagemaker.Session')
165164
def test_mxnet_train(session, estimator):
166-
args = cli.parse_arguments('mxnet train'.split())
165+
args = cli.parse_arguments('mxnet train --role-name role'.split())
167166
args.func(args)
168167
session.return_value.upload_data.assert_called()
169168
estimator.assert_called()
@@ -174,7 +173,7 @@ def test_mxnet_train(session, estimator):
174173
@patch('sagemaker.cli.common.HostCommand.upload_model')
175174
@patch('sagemaker.Session')
176175
def test_mxnet_host(session, upload_model, model):
177-
args = cli.parse_arguments('mxnet host'.split())
176+
args = cli.parse_arguments('mxnet host --role-name role'.split())
178177
args.func(args)
179178
session.assert_called()
180179
upload_model.assert_called()

0 commit comments

Comments
 (0)