2
2
import sagemaker .cli .main as cli
3
3
from mock import patch
4
4
5
- COMMON_ARGS = '--data mydata --script myscript --job-name myjob --bucket-name mybucket --role-name myrole ' + \
5
+ COMMON_ARGS = '--role-name myrole -- data mydata --script myscript --job-name myjob --bucket-name mybucket ' + \
6
6
'--python py3 --instance-type myinstance --instance-count 2'
7
7
8
8
TRAIN_ARGS = '--hyperparameters myhyperparameters.json'
@@ -17,7 +17,6 @@ def assert_common_defaults(args):
17
17
assert args .script == './script.py'
18
18
assert args .job_name is None
19
19
assert args .bucket_name is None
20
- assert args .role_name is 'AmazonSageMakerFullAccess'
21
20
assert args .python == 'py2'
22
21
assert args .instance_type == 'ml.m4.xlarge'
23
22
assert args .instance_count == 1
@@ -55,15 +54,15 @@ def assert_host_non_defaults(args):
55
54
56
55
57
56
def test_args_mxnet_train_defaults ():
58
- args = cli .parse_arguments ('mxnet train' .split ())
57
+ args = cli .parse_arguments ('mxnet train --role-name role ' .split ())
59
58
assert_common_defaults (args )
60
59
assert_train_defaults (args )
61
60
assert args .func .__module__ == 'sagemaker.cli.mxnet'
62
61
assert args .func .__name__ == 'train'
63
62
64
63
65
64
def test_args_mxnet_train_non_defaults ():
66
- args = cli .parse_arguments ('mxnet train {} {} {}'
65
+ args = cli .parse_arguments ('mxnet train --role-name role {} {} {}'
67
66
.format (COMMON_ARGS , TRAIN_ARGS , LOG_ARGS )
68
67
.split ())
69
68
assert_common_non_defaults (args )
@@ -73,15 +72,15 @@ def test_args_mxnet_train_non_defaults():
73
72
74
73
75
74
def test_args_mxnet_host_defaults ():
76
- args = cli .parse_arguments ('mxnet host' .split ())
75
+ args = cli .parse_arguments ('mxnet host --role-name role ' .split ())
77
76
assert_common_defaults (args )
78
77
assert_host_defaults (args )
79
78
assert args .func .__module__ == 'sagemaker.cli.mxnet'
80
79
assert args .func .__name__ == 'host'
81
80
82
81
83
82
def test_args_mxnet_host_non_defaults ():
84
- args = cli .parse_arguments ('mxnet host {} {} {}'
83
+ args = cli .parse_arguments ('mxnet host --role-name role {} {} {}'
85
84
.format (COMMON_ARGS , HOST_ARGS , LOG_ARGS )
86
85
.split ())
87
86
assert_common_non_defaults (args )
@@ -91,7 +90,7 @@ def test_args_mxnet_host_non_defaults():
91
90
92
91
93
92
def test_args_tensorflow_train_defaults ():
94
- args = cli .parse_arguments ('tensorflow train' .split ())
93
+ args = cli .parse_arguments ('tensorflow train --role-name role ' .split ())
95
94
assert_common_defaults (args )
96
95
assert_train_defaults (args )
97
96
assert args .training_steps is None
@@ -101,7 +100,7 @@ def test_args_tensorflow_train_defaults():
101
100
102
101
103
102
def test_args_tensorflow_train_non_defaults ():
104
- args = cli .parse_arguments ('tensorflow train --training-steps 10 --evaluation-steps 5 {} {} {}'
103
+ args = cli .parse_arguments ('tensorflow train --role-name role -- training-steps 10 --evaluation-steps 5 {} {} {}'
105
104
.format (COMMON_ARGS , TRAIN_ARGS , LOG_ARGS )
106
105
.split ())
107
106
assert_common_non_defaults (args )
@@ -113,15 +112,15 @@ def test_args_tensorflow_train_non_defaults():
113
112
114
113
115
114
def test_args_tensorflow_host_defaults ():
116
- args = cli .parse_arguments ('tensorflow host' .split ())
115
+ args = cli .parse_arguments ('tensorflow host --role-name role ' .split ())
117
116
assert_common_defaults (args )
118
117
assert_host_defaults (args )
119
118
assert args .func .__module__ == 'sagemaker.cli.tensorflow'
120
119
assert args .func .__name__ == 'host'
121
120
122
121
123
122
def test_args_tensorflow_host_non_defaults ():
124
- args = cli .parse_arguments ('tensorflow host {} {} {}'
123
+ args = cli .parse_arguments ('tensorflow host --role-name role {} {} {}'
125
124
.format (COMMON_ARGS , HOST_ARGS , LOG_ARGS )
126
125
.split ())
127
126
assert_common_non_defaults (args )
@@ -132,7 +131,7 @@ def test_args_tensorflow_host_non_defaults():
132
131
133
132
def test_args_invalid_framework ():
134
133
with pytest .raises (SystemExit ):
135
- cli .parse_arguments ('fakeframework train' .split ())
134
+ cli .parse_arguments ('fakeframework train --role-name role ' .split ())
136
135
137
136
138
137
def test_args_invalid_subcommand ():
@@ -142,28 +141,28 @@ def test_args_invalid_subcommand():
142
141
143
142
def test_args_invalid_args ():
144
143
with pytest .raises (SystemExit ):
145
- cli .parse_arguments ('tensorflow train --notdata foo' .split ())
144
+ cli .parse_arguments ('tensorflow train --role-name role -- notdata foo' .split ())
146
145
147
146
148
147
def test_args_invalid_mxnet_python ():
149
148
with pytest .raises (SystemExit ):
150
- cli .parse_arguments ('mxnet train nython py2' .split ())
149
+ cli .parse_arguments ('mxnet train --role-name role nython py2' .split ())
151
150
152
151
153
152
def test_args_invalid_host_args_in_train ():
154
153
with pytest .raises (SystemExit ):
155
- cli .parse_arguments ('mxnet train --env FOO=bar' .split ())
154
+ cli .parse_arguments ('mxnet train --role-name role -- env FOO=bar' .split ())
156
155
157
156
158
157
def test_args_invalid_train_args_in_host ():
159
158
with pytest .raises (SystemExit ):
160
- cli .parse_arguments ('tensorflow host --hyperparameters foo.json' .split ())
159
+ cli .parse_arguments ('tensorflow host --role-name role -- hyperparameters foo.json' .split ())
161
160
162
161
163
162
@patch ('sagemaker.mxnet.estimator.MXNet' )
164
163
@patch ('sagemaker.Session' )
165
164
def test_mxnet_train (session , estimator ):
166
- args = cli .parse_arguments ('mxnet train' .split ())
165
+ args = cli .parse_arguments ('mxnet train --role-name role ' .split ())
167
166
args .func (args )
168
167
session .return_value .upload_data .assert_called ()
169
168
estimator .assert_called ()
@@ -174,7 +173,7 @@ def test_mxnet_train(session, estimator):
174
173
@patch ('sagemaker.cli.common.HostCommand.upload_model' )
175
174
@patch ('sagemaker.Session' )
176
175
def test_mxnet_host (session , upload_model , model ):
177
- args = cli .parse_arguments ('mxnet host' .split ())
176
+ args = cli .parse_arguments ('mxnet host --role-name role ' .split ())
178
177
args .func (args )
179
178
session .assert_called ()
180
179
upload_model .assert_called ()
0 commit comments