|
10 | 10 |
|
11 | 11 |
|
12 | 12 | def parse_arguments(args):
|
13 |
| - # common arguments |
14 |
| - common_parser = argparse.ArgumentParser(add_help=False) |
15 |
| - |
16 |
| - # image-related settings |
17 |
| - image_mtx = common_parser.add_mutually_exclusive_group(required=True) |
18 |
| - image_mtx.add_argument('--tf', help='use a TensorFlow container image', action='store_true') |
19 |
| - image_mtx.add_argument('--mx', help='use an MXNet container image', action='store_true') |
| 13 | + parser = argparse.ArgumentParser(description='Launch SageMaker training jobs or hosting endpoints') |
| 14 | + parser.set_defaults(func=lambda x: parser.print_usage()) |
20 | 15 |
|
21 |
| - # path to data and script files |
| 16 | + # common args for training/hosting/all frameworks |
| 17 | + common_parser = argparse.ArgumentParser(add_help=False) |
22 | 18 | common_parser.add_argument('--data', help='path to training data or model files', type=str, default='./data')
|
23 | 19 | common_parser.add_argument('--script', help='path to script', type=str, default='./script.py')
|
24 | 20 | common_parser.add_argument('--job-name', help='job or endpoint name', type=str, default=None)
|
25 | 21 | common_parser.add_argument('--bucket-name', help='S3 bucket', type=str, default=None)
|
26 | 22 | common_parser.add_argument('--role-name', help='SageMaker execution role name', type=str,
|
27 | 23 | default='AmazonSageMakerFullAccess')
|
| 24 | + common_parser.add_argument('--python', help='python version', type=str, default='py2') |
28 | 25 |
|
29 | 26 | instance_group = common_parser.add_argument_group('instance settings')
|
30 | 27 | instance_group.add_argument('--instance-type', type=str, help='instance type', default='ml.m4.xlarge')
|
31 | 28 | instance_group.add_argument('--instance-count', type=int, help='instance count', default=1)
|
32 | 29 |
|
33 |
| - image_group = common_parser.add_argument_group('other container image settings') |
34 |
| - image_group.add_argument('--python', help='python version (mxnet only)', type=str, default='py2') |
35 |
| - |
36 |
| - parser = argparse.ArgumentParser(description='Launch SageMaker training jobs or hosting endpoints') |
37 |
| - parser.set_defaults(func=lambda x: parser.print_usage()) |
38 |
| - |
39 |
| - log_group = parser.add_argument_group('log settings') |
| 30 | + log_group = common_parser.add_argument_group('optional log settings') |
40 | 31 | log_group.add_argument('--log-level', help='log level for this command', type=str, default='info')
|
41 | 32 | log_group.add_argument('--botocore-log-level', help='log level for botocore', type=str, default='warning')
|
42 | 33 |
|
| 34 | + # common training args |
| 35 | + common_train_parser = argparse.ArgumentParser(add_help=False) |
| 36 | + common_train_parser.add_argument('--hyperparameters', help='path to training hyperparameters file', |
| 37 | + type=str, default='./hyperparameters.json') |
| 38 | + |
| 39 | + # common hosting args |
| 40 | + common_host_parser = argparse.ArgumentParser(add_help=False) |
| 41 | + common_host_parser.add_argument('--env', help='hosting environment variable(s)', type=str, nargs='*', default=[]) |
| 42 | + |
43 | 43 | subparsers = parser.add_subparsers()
|
44 |
| - train_parser = subparsers.add_parser('train', help='start a training job', parents=[common_parser]) |
45 |
| - train_group = train_parser.add_argument_group('training settings') |
46 |
| - train_group.add_argument('--hyperparameters', help='path to training hyperparameters file', |
47 |
| - type=str, default='./hyperparameters.json') |
48 |
| - train_group.add_argument('--training-steps', |
49 |
| - help='number of training steps (tensorflow only)', type=int, default=None) |
50 |
| - train_group.add_argument('--evaluation-steps', |
51 |
| - help='number of evaluation steps (tensorflow only)', type=int, default=None) |
52 |
| - train_parser.set_defaults(mode='train') |
53 |
| - train_parser.set_defaults(func=sagemaker.cli.train) |
54 |
| - |
55 |
| - host_parser = subparsers.add_parser('host', help='start a hosting endpoint', parents=[common_parser]) |
56 |
| - host_group = host_parser.add_argument_group('hosting settings') |
57 |
| - host_group.add_argument('--env', help='hosting environment variable(s)', type=str, nargs='*', default=[]) |
58 |
| - train_parser.set_defaults(mode='host') |
59 |
| - host_parser.set_defaults(func=sagemaker.cli.host) |
| 44 | + |
| 45 | + # framework/algo subcommands |
| 46 | + mxnet_parser = subparsers.add_parser('mxnet', help='use MXNet', parents=[]) |
| 47 | + mxnet_subparsers = mxnet_parser.add_subparsers() |
| 48 | + mxnet_train_parser = mxnet_subparsers.add_parser('train', |
| 49 | + help='start a training job', |
| 50 | + parents=[common_parser, common_train_parser]) |
| 51 | + mxnet_train_parser.set_defaults(func=sagemaker.cli.mxnet.train) |
| 52 | + |
| 53 | + mxnet_host_parser = mxnet_subparsers.add_parser('host', |
| 54 | + help='start a hosting endpoint', |
| 55 | + parents=[common_parser, common_host_parser]) |
| 56 | + mxnet_host_parser.set_defaults(func=sagemaker.cli.mxnet.host) |
| 57 | + |
| 58 | + tensorflow_parser = subparsers.add_parser('tensorflow', help='use TensorFlow', parents=[]) |
| 59 | + tensorflow_subparsers = tensorflow_parser.add_subparsers() |
| 60 | + tensorflow_train_parser = tensorflow_subparsers.add_parser('train', |
| 61 | + help='start a training job', |
| 62 | + parents=[common_parser, common_train_parser]) |
| 63 | + tensorflow_train_parser.add_argument('--training-steps', |
| 64 | + help='number of training steps (tensorflow only)', type=int, default=None) |
| 65 | + tensorflow_train_parser.add_argument('--evaluation-steps', |
| 66 | + help='number of evaluation steps (tensorflow only)', type=int, default=None) |
| 67 | + tensorflow_train_parser.set_defaults(func=sagemaker.cli.tensorflow.train) |
| 68 | + |
| 69 | + tensorflow_host_parser = tensorflow_subparsers.add_parser('host', |
| 70 | + help='start a hosting endpoint', |
| 71 | + parents=[common_parser, common_host_parser]) |
| 72 | + tensorflow_host_parser.set_defaults(func=sagemaker.cli.tensorflow.host) |
60 | 73 |
|
61 | 74 | return parser.parse_args(args)
|
62 | 75 |
|
|
0 commit comments