
Commit d63deea

Update fit() to work with frameworks (aws#27)
1 parent ae1dd2f commit d63deea

5 files changed: +223 -66 lines changed

src/sagemaker/session.py (+18 -22)
@@ -12,18 +12,17 @@
 # language governing permissions and limitations under the License.
 from __future__ import print_function, absolute_import
 
+import json
 import logging
-import re
-
 import os
+import re
 import sys
 import time
 
 import boto3
-import json
+import botocore.config
 import six
 import yaml
-import botocore.config
 from botocore.exceptions import ClientError
 
 from sagemaker.user_agent import prepend_user_agent
@@ -257,22 +256,22 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
         LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
         self.sagemaker_client.create_training_job(**train_request)
 
-    def tune(self, job_name, strategy, objective, metric_name,
+    def tune(self, job_name, strategy, objective_type, objective_metric_name,
              max_jobs, max_parallel_jobs, parameter_ranges,
-             static_hp, image, input_mode, metric_definitions,
+             static_hyperparameters, image, input_mode, metric_definitions,
              role, input_config, output_config, resource_config, stop_condition):
-        """Create an Amazon SageMaker HPO job.
+        """Create an Amazon SageMaker hyperparameter tuning job
 
         Args:
            job_name (str): Name of the tuning job being created.
            strategy (str): Strategy to be used.
-            objective (str): Minimize/Maximize
-            metric_name (str): Name of the metric to use when evaluating training job.
+            objective_type (str): Minimize/Maximize
+            objective_metric_name (str): Name of the metric to use when evaluating training job.
            max_jobs (int): Maximum total number of jobs to start.
            max_parallel_jobs (int): Maximum number of parallel jobs to start.
            parameter_ranges (dict): Parameter ranges in a dictionary of types: Continuous, Integer, Categorical
-            static_hp (dict): Hyperparameters for model training. The hyperparameters are made accessible as
-                a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for
+            static_hyperparameters (dict): Hyperparameters for model training. The hyperparameters are made accessible
+                as a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for
                 keys and values, but ``str()`` will be called to convert them before training.
            image (str): Docker image containing training code.
            input_mode (str): The input mode that the algorithm supports. Valid modes:
@@ -293,30 +292,27 @@ def tune(self, job_name, strategy, objective, metric_name,
            instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
            stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
                service like ``MaxRuntimeInSeconds``.
-
-        Returns:
-
         """
-
        tune_request = {
            'HyperParameterTuningJobName': job_name,
            'HyperParameterTuningJobConfig': {
                'Strategy': strategy,
                'HyperParameterTuningJobObjective': {
-                    'Type': objective,
-                    'MetricName': metric_name,
+                    'Type': objective_type,
+                    'MetricName': objective_metric_name,
                },
                'ResourceLimits': {
                    'MaxNumberOfTrainingJobs': max_jobs,
-                    'MaxParallelTrainingJobs': max_parallel_jobs
+                    'MaxParallelTrainingJobs': max_parallel_jobs,
                },
-                'ParameterRanges': parameter_ranges
+                'ParameterRanges': parameter_ranges,
            },
            'TrainingJobDefinition': {
-                'StaticHyperParameters': static_hp,
+                'StaticHyperParameters': static_hyperparameters,
                'AlgorithmSpecification': {
                    'TrainingImage': image,
-                    'TrainingInputMode': input_mode
+                    'TrainingInputMode': input_mode,
+                    'MetricDefinitions': metric_definitions,
                },
                'RoleArn': role,
                'InputDataConfig': input_config,
@@ -329,7 +325,7 @@ def tune(self, job_name, strategy, objective, metric_name,
        if metric_definitions is not None:
            tune_request['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions
 
-        LOGGER.info('Creating tuning-job with name: {}'.format(job_name))
+        LOGGER.info('Creating hyperparameter tuning job with name: {}'.format(job_name))
        LOGGER.debug('tune request: {}'.format(json.dumps(tune_request, indent=4)))
        self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
 
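For orientation, here is a minimal sketch of calling the renamed tune() API directly on a sagemaker.Session. Only the keyword names come from this commit; the role ARN, image URI, S3 paths, metric regex, and instance settings below are placeholder values, and the request-shaped dicts follow the SageMaker CreateHyperParameterTuningJob API rather than anything added here.

import sagemaker

session = sagemaker.Session()

session.tune(job_name='mnist-tuning-example',
             strategy='Bayesian',
             objective_type='Maximize',                    # was 'objective'
             objective_metric_name='Validation-accuracy',  # was 'metric_name'
             max_jobs=4,
             max_parallel_jobs=2,
             parameter_ranges={'ContinuousParameterRanges': [
                                   {'Name': 'learning_rate', 'MinValue': '0.01', 'MaxValue': '0.2'}],
                               'IntegerParameterRanges': [],
                               'CategoricalParameterRanges': []},
             static_hyperparameters={'epochs': '25'},      # was 'static_hp'
             image='123456789012.dkr.ecr.us-west-2.amazonaws.com/example:latest',
             input_mode='File',
             metric_definitions=[{'Name': 'Validation-accuracy',
                                  'Regex': 'Validation-accuracy=([0-9\\.]+)'}],
             role='arn:aws:iam::123456789012:role/ExampleSageMakerRole',
             input_config=[{'ChannelName': 'train',
                            'DataSource': {'S3DataSource': {
                                'S3DataType': 'S3Prefix',
                                'S3Uri': 's3://example-bucket/data/train',
                                'S3DataDistributionType': 'FullyReplicated'}}}],
             output_config={'S3OutputPath': 's3://example-bucket/output'},
             resource_config={'InstanceCount': 1,
                              'InstanceType': 'ml.c4.xlarge',
                              'VolumeSizeInGB': 30},
             stop_condition={'MaxRuntimeInSeconds': 3600})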

src/sagemaker/tuner.py (+59 -27)
@@ -13,7 +13,9 @@
 from __future__ import absolute_import
 
 import inspect
+import json
 
+from sagemaker.estimator import Framework
 from sagemaker.job import _Job
 from sagemaker.utils import base_name_from_image, name_from_base
 
@@ -51,6 +53,9 @@ def as_tuning_range(self, name):
         return {'Name': name,
                 'Values': self.values}
 
+    def as_json_range(self, name):
+        return {'Name': name, 'Values': [json.dumps(v) for v in self.values]}
+
 
 class IntegerParameter(_ParameterRange):
     __name__ = 'Integer'
@@ -60,31 +65,58 @@ def __init__(self, min_value, max_value):
 
 
 class HyperparameterTuner(object):
-    __objectives__ = ['Minimize', 'Maximize']
+    SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
+    SAGEMAKER_ESTIMATOR_MODULE = 'sagemaker_estimator_module'
 
     def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metric_definitions, strategy='Bayesian',
                  objective_type='Maximize', max_jobs=1, max_parallel_jobs=1, base_tuning_job_name=None):
-        if objective_type not in HyperparameterTuner.__objectives__:
-            raise ValueError("Unsupported 'objective' values")
+        self._hyperparameter_ranges = hyperparameter_ranges
+        if self._hyperparameter_ranges is None or len(self._hyperparameter_ranges) == 0:
+            raise ValueError('Need to specify hyperparameter ranges')
 
         self.estimator = estimator
         self.objective_metric_name = objective_metric_name
-        self._hyperparameter_ranges = hyperparameter_ranges
+        self.metric_definitions = metric_definitions
+
         self.strategy = strategy
         self.objective_type = objective_type
+
         self.max_jobs = max_jobs
         self.max_parallel_jobs = max_parallel_jobs
         self.tuning_job_name = base_tuning_job_name
         self.metric_definitions = metric_definitions
         self.latest_tuning_job = None
         self._validate_parameter_ranges()
 
-    def fit(self, inputs):
-        """Create tuning job
+    def prepare_for_training(self):
+        # TODO: Change this so that it can handle unicode in Python 2
+        self.static_hyperparameters = {str(k): str(v) for (k, v) in self.estimator.hyperparameters().items()}
+        for hyperparameter_name in self._hyperparameter_ranges.keys():
+            self.static_hyperparameters.pop(hyperparameter_name, None)
+
+        # For attach() to know what estimator to use
+        self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = self.estimator.__class__.__name__
+        self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = self.estimator.__module__
+
+    def fit(self, inputs, job_name=None, **kwargs):
+        """Start a hyperparameter tuning job.
 
         Args:
-            inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
+            inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
+            job_name (str): Job name
+            **kwargs: Other arguments
         """
+        # TODO: I think I have to move RecordSet to its own file
+        from sagemaker.amazon.amazon_estimator import RecordSet
+
+        # 1P estimators require a RecordSet object
+        if isinstance(inputs, RecordSet):
+            self.estimator.prepare_for_training(inputs, **kwargs)
+            inputs = inputs.data_channel()
+        else:
+            self.estimator.prepare_for_training(**kwargs)
+
+        self.prepare_for_training()
         self.latest_tuning_job = _TuningJob.start_new(self, inputs)
 
     def stop_tuning_job(self):
@@ -101,15 +133,20 @@ def hyperparameter_ranges(self):
         """Return collections of ``ParameterRanges``
 
         Returns:
-            dict: ParameterRanges suitable for tuning job.
+            dict: ParameterRanges suitable for a hyperparameter tuning job.
         """
         hyperparameter_ranges = dict()
         for range_type in _ParameterRange.__all_types__:
-            parameter_range = []
+            parameter_ranges = []
             for parameter_name, parameter in self._hyperparameter_ranges.items():
                 if parameter is not None and parameter.__name__ == range_type:
-                    parameter_range.append(parameter.as_tuning_range(parameter_name))
-            hyperparameter_ranges[range_type + 'ParameterRanges'] = parameter_range
+                    # Categorical parameters needed to be serialized as JSON for our framework containers
+                    if isinstance(parameter, CategoricalParameter) and isinstance(self.estimator, Framework):
+                        tuning_range = parameter.as_json_range(parameter_name)
+                    else:
+                        tuning_range = parameter.as_tuning_range(parameter_name)
+                    parameter_ranges.append(tuning_range)
+            hyperparameter_ranges[range_type + 'ParameterRanges'] = parameter_ranges
         return hyperparameter_ranges
 
     def _validate_parameter_ranges(self):
@@ -138,43 +175,38 @@ def _validate_parameter_ranges(self):
 
 
 class _TuningJob(_Job):
-    SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
-    SAGEMAKER_ESTIMATOR_MODULE = 'sagemaker_estimator_module'
-
     def __init__(self, sagemaker_session, tuning_job_name):
         super(_TuningJob, self).__init__(sagemaker_session, tuning_job_name)
 
     @classmethod
     def start_new(cls, tuner, inputs):
-        """Create a new Amazon SageMaker tuning job from the HyperparameterTuner.
+        """Create a new Amazon SageMaker hyperparameter tuning job from the HyperparameterTuner.
 
         Args:
-            tuner (sagemaker.tuner.HyperparameterTuner): Tuner object created by the user.
-            inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
+            tuner (sagemaker.tuner.HyperparameterTuner): HyperparameterTuner object created by the user.
+            inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
 
         Returns:
             sagemaker.tuner._TuningJob: Constructed object that captures all information about the started job.
         """
         config = _Job._load_config(inputs, tuner.estimator)
 
-        static_hyperparameters = {str(k): str(v) for (k, v) in tuner.estimator.hyperparameters().items()}
-        for hyperparameter_name in tuner._hyperparameter_ranges.keys():
-            static_hyperparameters.pop(hyperparameter_name, None)
-
-        static_hyperparameters[cls.SAGEMAKER_ESTIMATOR_CLASS_NAME] = tuner.estimator.__class__.__name__
-        static_hyperparameters[cls.SAGEMAKER_ESTIMATOR_MODULE] = tuner.estimator.__module__
-
         base_name = tuner.estimator.base_job_name or base_name_from_image(tuner.estimator.train_image())
         tuning_job_name = name_from_base(base_name)
 
+        # TODO: Update name generation so that the base name isn't limited to so few characters
+        if len(tuning_job_name) > 32:
+            raise ValueError('Tuning job name too long - must be 32 characters or fewer: {}'.format(tuning_job_name))
+
         tuner.estimator.sagemaker_session.tune(job_name=tuning_job_name, strategy=tuner.strategy,
-                                               objective=tuner.objective_type, metric_name=tuner.objective_metric_name,
+                                               objective_type=tuner.objective_type,
+                                               objective_metric_name=tuner.objective_metric_name,
                                                max_jobs=tuner.max_jobs, max_parallel_jobs=tuner.max_parallel_jobs,
                                                parameter_ranges=tuner.hyperparameter_ranges(),
-                                               static_hp=static_hyperparameters,
+                                               static_hyperparameters=tuner.static_hyperparameters,
                                                image=tuner.estimator.train_image(),
                                                input_mode=tuner.estimator.input_mode,
-                                               metric_definitions=tuner.estimator.metric_definitions,
+                                               metric_definitions=tuner.metric_definitions,
                                                role=(config['role']), input_config=(config['input_config']),
                                                output_config=(config['output_config']),
                                                resource_config=(config['resource_config']),
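A small illustration of the categorical handling added above. The method names come from this diff; it assumes, as elsewhere in this module, that CategoricalParameter is constructed from the list of candidate values.

from sagemaker.tuner import CategoricalParameter

activation = CategoricalParameter(['relu', 'tanh'])

# Plain range, as sent for non-framework estimators:
print(activation.as_tuning_range('activation'))
# {'Name': 'activation', 'Values': ['relu', 'tanh']}

# JSON-encoded range, chosen by hyperparameter_ranges() when the estimator is a
# Framework subclass, so the selected value survives the framework containers'
# JSON hyperparameter parsing:
print(activation.as_json_range('activation'))
# {'Name': 'activation', 'Values': ['"relu"', '"tanh"']}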

tests/data/mxnet_mnist/tuning.py (+73)
@@ -0,0 +1,73 @@
+import gzip
+import logging
+import os
+import struct
+
+import mxnet as mx
+import numpy as np
+
+
+def load_data(path):
+    with gzip.open(find_file(path, "labels.gz")) as flbl:
+        struct.unpack(">II", flbl.read(8))
+        labels = np.fromstring(flbl.read(), dtype=np.int8)
+    with gzip.open(find_file(path, "images.gz")) as fimg:
+        _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
+        images = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(labels), rows, cols)
+        images = images.reshape(images.shape[0], 1, 28, 28).astype(np.float32) / 255
+    return labels, images
+
+
+def find_file(root_path, file_name):
+    for root, dirs, files in os.walk(root_path):
+        if file_name in files:
+            return os.path.join(root, file_name)
+
+
+def build_graph():
+    data = mx.sym.var('data')
+    data = mx.sym.flatten(data=data)
+    fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
+    act1 = mx.sym.Activation(data=fc1, act_type="relu")
+    fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64)
+    act2 = mx.sym.Activation(data=fc2, act_type="relu")
+    fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
+    return mx.sym.SoftmaxOutput(data=fc3, name='softmax')
+
+
+def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus):
+    (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
+    (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))
+
+    # Alternatively to splitting in memory, the data could be pre-split in S3 and use ShardedByS3Key
+    # to do parallel training.
+    shard_size = len(train_images) // len(hosts)
+    for i, host in enumerate(hosts):
+        if host == current_host:
+            start = shard_size * i
+            end = start + shard_size
+            break
+
+    batch_size = 100
+    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, shuffle=True)
+    val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
+    logging.getLogger().setLevel(logging.DEBUG)
+    kvstore = 'local' if len(hosts) == 1 else 'dist_sync'
+    mlp_model = mx.mod.Module(
+        symbol=build_graph(),
+        context=get_train_context(num_cpus, num_gpus))
+    mlp_model.fit(train_iter,
+                  eval_data=val_iter,
+                  kvstore=kvstore,
+                  optimizer='sgd',
+                  optimizer_params={'learning_rate': float(hyperparameters.get("learning_rate", 0.1))},
+                  eval_metric='acc',
+                  batch_end_callback=mx.callback.Speedometer(batch_size, 100),
+                  num_epoch=25)
+    return mlp_model
+
+
+def get_train_context(num_cpus, num_gpus):
+    if num_gpus > 0:
+        return mx.gpu()
+    return mx.cpu()
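The script above is test data; the following is a hedged sketch of how it might be driven as a framework entry point using the tuner support added in this commit. The role ARN, bucket, metric name/regex, and instance settings are placeholders, and ContinuousParameter is assumed to be the range type named in the tune() docstring.

from sagemaker.mxnet import MXNet
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

estimator = MXNet(entry_point='tuning.py',
                  source_dir='tests/data/mxnet_mnist',
                  role='arn:aws:iam::123456789012:role/ExampleSageMakerRole',
                  train_instance_count=1,
                  train_instance_type='ml.c4.xlarge',
                  base_job_name='mxnet-tune')  # kept short: the generated tuning job name must fit in 32 characters

tuner = HyperparameterTuner(estimator=estimator,
                            objective_metric_name='Validation-accuracy',
                            hyperparameter_ranges={'learning_rate': ContinuousParameter(0.01, 0.2)},
                            metric_definitions=[{'Name': 'Validation-accuracy',
                                                 'Regex': 'Validation-accuracy=([0-9\\.]+)'}],
                            max_jobs=4,
                            max_parallel_jobs=2)

# With a framework estimator, fit() takes the same channel inputs as estimator.fit();
# a RecordSet is only required for the first-party algorithms.
tuner.fit({'train': 's3://example-bucket/mxnet_mnist/train',
           'test': 's3://example-bucket/mxnet_mnist/test'})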
