Commit 4bac185

Add support for Hyperparameter Tuning Early Stopping (#550)
1 parent 3db7807 commit 4bac185

10 files changed: +91 -18 lines changed

CHANGELOG.rst

+3 -2

@@ -2,8 +2,8 @@
 CHANGELOG
 =========
 
-1.16.2.dev
-==========
+1.16.2
+======
 
 * enhancement: Check for S3 paths being passed as entry point
 * feature: Add support for AugmentedManifestFile and ShuffleConfig
@@ -15,6 +15,7 @@ CHANGELOG
 * bug-fix: Update PyYAML version to avoid conflicts with docker-compose
 * doc-fix: Correct the numbered list in the table of contents
 * doc-fix: Add Airflow API documentation
+* feature: HyperparameterTuner: add Early Stopping support
 
 1.16.1.post1
 ============

README.rst

+16

@@ -614,6 +614,22 @@ A hyperparameter range can be one of three types: continuous, integer, or catego
 The SageMaker Python SDK provides corresponding classes for defining these different types.
 You can define up to 20 hyperparameters to search over, but each value of a categorical hyperparameter range counts against that limit.
 
+By default, training job early stopping is turned off. To enable early stopping for the tuning job, you need to set the ``early_stopping_type`` parameter to ``Auto``:
+
+.. code:: python
+
+    # Enable early stopping
+    my_tuner = HyperparameterTuner(estimator=my_estimator,  # previously-configured Estimator object
+                                   objective_metric_name='validation-accuracy',
+                                   hyperparameter_ranges={'learning-rate': ContinuousParameter(0.05, 0.06)},
+                                   metric_definitions=[{'Name': 'validation-accuracy', 'Regex': 'validation-accuracy=(\d\.\d+)'}],
+                                   max_jobs=100,
+                                   max_parallel_jobs=10,
+                                   early_stopping_type='Auto')
+
+When early stopping is turned on, Amazon SageMaker will automatically stop a training job if it appears unlikely to produce a model of better quality than other jobs.
+If not using built-in Amazon SageMaker algorithms, note that, for early stopping to be effective, the objective metric should be emitted at epoch level.
+
 If you are using an Amazon SageMaker built-in algorithm, you don't need to pass in anything for ``metric_definitions``.
 In addition, the ``fit()`` call uses a list of ``RecordSet`` objects instead of a dictionary:
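A note on the README addition above: for custom (non-built-in) algorithms, emitting the objective metric "at epoch level" simply means the training script should write one matching line to its log per epoch, since the metric is scraped from the job's log output using the ``Regex`` supplied in ``metric_definitions``. A minimal sketch of such a training loop, using a dummy accuracy value purely for illustration:

.. code:: python

    # Hypothetical per-epoch metric logging for a custom training script.
    # Each printed line matches the 'validation-accuracy=(\d\.\d+)' regex shown above.
    num_epochs = 10

    for epoch in range(num_epochs):
        # ... train for one epoch, then compute validation accuracy ...
        accuracy = 0.050 + 0.001 * epoch  # stand-in value, not a real measurement
        print('validation-accuracy={:.4f}'.format(accuracy))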

doc/conf.py

+1 -1

@@ -32,7 +32,7 @@ def __getattr__(cls, name):
                 'numpy', 'scipy', 'scipy.sparse']
 sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
 
-version = '1.16.1.post1'
+version = '1.16.2'
 project = u'sagemaker'
 
 # Add any Sphinx extension module names here, as strings. They can be extensions

setup.py

+1 -1

@@ -33,7 +33,7 @@ def read(fname):
 
 
 # Declare minimal set for installation
-required_packages = ['boto3>=1.9.55', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
+required_packages = ['boto3>=1.9.64', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
                      'urllib3>=1.21', 'PyYAML>=3.2, <4', 'protobuf3-to-dict>=0.1.5',
                      'docker-compose>=1.23.0', 'requests>=2.20.0, <2.21']

src/sagemaker/__init__.py

+1 -1

@@ -39,4 +39,4 @@
 from sagemaker.session import s3_input  # noqa: F401
 from sagemaker.session import get_execution_role  # noqa: F401
 
-__version__ = '1.16.1.post1'
+__version__ = '1.16.2'

src/sagemaker/session.py

+6 -1

@@ -350,7 +350,8 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
              max_jobs, max_parallel_jobs, parameter_ranges,
              static_hyperparameters, input_mode, metric_definitions,
              role, input_config, output_config, resource_config, stop_condition, tags,
-             warm_start_config, enable_network_isolation=False, image=None, algorithm_arn=None):
+             warm_start_config, enable_network_isolation=False, image=None, algorithm_arn=None,
+             early_stopping_type='Off'):
         """Create an Amazon SageMaker hyperparameter tuning job
 
         Args:
@@ -396,6 +397,9 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
             warm_start_config (dict): Configuration defining the type of warm start and
                 other required configurations.
+            early_stopping_type (str): Specifies whether early stopping is enabled for the job.
+                Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be attempted.
+                If set to 'Auto', early stopping of some training jobs may happen, but is not guaranteed to.
         """
         tune_request = {
             'HyperParameterTuningJobName': job_name,
@@ -410,6 +414,7 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
                     'MaxParallelTrainingJobs': max_parallel_jobs,
                 },
                 'ParameterRanges': parameter_ranges,
+                'TrainingJobEarlyStoppingType': early_stopping_type,
             },
             'TrainingJobDefinition': {
                 'StaticHyperParameters': static_hyperparameters,
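For context on the change above: ``early_stopping_type`` is passed through unchanged into the ``HyperParameterTuningJobConfig`` part of the request that ``tune()`` builds for CreateHyperParameterTuningJob. A rough sketch of the resulting fragment (the job name, limits, and ranges below are placeholder values, and the training job definition is omitted):

.. code:: python

    tune_request = {
        'HyperParameterTuningJobName': 'my-tuning-job',  # placeholder name
        'HyperParameterTuningJobConfig': {
            'ResourceLimits': {
                'MaxNumberOfTrainingJobs': 100,   # placeholder
                'MaxParallelTrainingJobs': 10,    # placeholder
            },
            'ParameterRanges': {},                # elided in this sketch
            # New in this commit, forwarded from the early_stopping_type argument:
            'TrainingJobEarlyStoppingType': 'Auto',
        },
        # 'TrainingJobDefinition': {...}          # unchanged by this commit
    }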

src/sagemaker/tuner.py

+8 -2

@@ -165,7 +165,7 @@ class HyperparameterTuner(object):
 
     def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metric_definitions=None,
                  strategy='Bayesian', objective_type='Maximize', max_jobs=1, max_parallel_jobs=1,
-                 tags=None, base_tuning_job_name=None, warm_start_config=None):
+                 tags=None, base_tuning_job_name=None, warm_start_config=None, early_stopping_type='Off'):
         """Initialize a ``HyperparameterTuner``. It takes an estimator to obtain configuration information
         for training jobs that are created as the result of a hyperparameter tuning job.
@@ -194,6 +194,9 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
                 a default job name is generated, based on the training image name and current timestamp.
             warm_start_config (sagemaker.tuner.WarmStartConfig): A ``WarmStartConfig`` object that has been initialized
                 with the configuration defining the nature of warm start tuning job.
+            early_stopping_type (str): Specifies whether early stopping is enabled for the job.
+                Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted.
+                If set to 'Auto', early stopping of some training jobs may happen, but is not guaranteed to.
         """
         self._hyperparameter_ranges = hyperparameter_ranges
         if self._hyperparameter_ranges is None or len(self._hyperparameter_ranges) == 0:
@@ -214,6 +217,7 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
         self._current_job_name = None
         self.latest_tuning_job = None
         self.warm_start_config = warm_start_config
+        self.early_stopping_type = early_stopping_type
 
     def _prepare_for_training(self, job_name=None, include_cls_metadata=True):
         if job_name is not None:
@@ -445,7 +449,8 @@ def _prepare_init_params_from_job_description(cls, job_details):
             'strategy': tuning_config['Strategy'],
             'max_jobs': tuning_config['ResourceLimits']['MaxNumberOfTrainingJobs'],
             'max_parallel_jobs': tuning_config['ResourceLimits']['MaxParallelTrainingJobs'],
-            'warm_start_config': WarmStartConfig.from_job_desc(job_details.get('WarmStartConfig', None))
+            'warm_start_config': WarmStartConfig.from_job_desc(job_details.get('WarmStartConfig', None)),
+            'early_stopping_type': tuning_config['TrainingJobEarlyStoppingType']
         }
 
     @classmethod
@@ -625,6 +630,7 @@ def start_new(cls, tuner, inputs):
             tuner_args['metric_definitions'] = tuner.metric_definitions
             tuner_args['tags'] = tuner.tags
             tuner_args['warm_start_config'] = warm_start_config_req
+            tuner_args['early_stopping_type'] = tuner.early_stopping_type
 
             del tuner_args['vpc_config']
             if isinstance(tuner.estimator, sagemaker.algorithm.AlgorithmEstimator):
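Because ``_prepare_init_params_from_job_description()`` now reads ``TrainingJobEarlyStoppingType`` back out of the tuning job description, the setting survives a round trip through ``HyperparameterTuner.attach()``. A small usage sketch (the job name is hypothetical):

.. code:: python

    from sagemaker.tuner import HyperparameterTuner

    # Hypothetical name of an existing tuning job in your account.
    tuning_job_name = 'my-tuning-job'

    # attach() rebuilds the tuner from the DescribeHyperParameterTuningJob response,
    # so the original job's early stopping setting is restored on the new object.
    attached_tuner = HyperparameterTuner.attach(tuning_job_name)
    print(attached_tuner.early_stopping_type)  # 'Auto' or 'Off'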

tests/integ/test_tuner.py

+21 -9

@@ -83,15 +83,16 @@ def hyperparameter_ranges():
 
 def _tune_and_deploy(kmeans_estimator, kmeans_train_set, sagemaker_session,
                      hyperparameter_ranges=None, job_name=None,
-                     warm_start_config=None):
+                     warm_start_config=None, early_stopping_type='Off'):
     tuner = _tune(kmeans_estimator, kmeans_train_set,
                   hyperparameter_ranges=hyperparameter_ranges, warm_start_config=warm_start_config,
-                  job_name=job_name)
-    _deploy(kmeans_train_set, sagemaker_session, tuner)
+                  job_name=job_name, early_stopping_type=early_stopping_type)
+    _deploy(kmeans_train_set, sagemaker_session, tuner, early_stopping_type)
 
 
-def _deploy(kmeans_train_set, sagemaker_session, tuner):
+def _deploy(kmeans_train_set, sagemaker_session, tuner, early_stopping_type):
     best_training_job = tuner.best_training_job()
+    assert tuner.early_stopping_type == early_stopping_type
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
         predictor = tuner.deploy(1, 'ml.c4.xlarge')
@@ -105,7 +106,7 @@ def _deploy(kmeans_train_set, sagemaker_session, tuner):
 
 def _tune(kmeans_estimator, kmeans_train_set, tuner=None,
           hyperparameter_ranges=None, job_name=None, warm_start_config=None,
-          wait_till_terminal=True, max_jobs=2, max_parallel_jobs=2):
+          wait_till_terminal=True, max_jobs=2, max_parallel_jobs=2, early_stopping_type='Off'):
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
 
         if not tuner:
@@ -115,7 +116,8 @@ def _tune(kmeans_estimator, kmeans_train_set, tuner=None,
                                             objective_type='Minimize',
                                             max_jobs=max_jobs,
                                             max_parallel_jobs=max_parallel_jobs,
-                                            warm_start_config=warm_start_config)
+                                            warm_start_config=warm_start_config,
+                                            early_stopping_type=early_stopping_type)
 
         records = kmeans_estimator.record_set(kmeans_train_set[0][:100])
         test_record_set = kmeans_estimator.record_set(kmeans_train_set[0][:100], channel='test')
@@ -332,16 +334,23 @@ def test_tuning_lda(sagemaker_session):
     tuner = HyperparameterTuner(estimator=lda, objective_metric_name=objective_metric_name,
                                 hyperparameter_ranges=hyperparameter_ranges,
                                 objective_type='Maximize', max_jobs=2,
-                                max_parallel_jobs=2)
+                                max_parallel_jobs=2,
+                                early_stopping_type='Auto')
 
     tuning_job_name = unique_name_from_base('test-lda', max_length=32)
     tuner.fit([record_set, test_record_set], mini_batch_size=1, job_name=tuning_job_name)
 
-    print('Started hyperparameter tuning job with name:' + tuner.latest_tuning_job.name)
+    latest_tuning_job_name = tuner.latest_tuning_job.name
+
+    print('Started hyperparameter tuning job with name:' + latest_tuning_job_name)
 
     time.sleep(15)
     tuner.wait()
 
+    desc = tuner.latest_tuning_job.sagemaker_session.sagemaker_client \
+        .describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=latest_tuning_job_name)
+    assert desc['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] == 'Auto'
+
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
         predictor = tuner.deploy(1, 'ml.c4.xlarge')
@@ -555,7 +564,8 @@ def test_attach_tuning_pytorch(sagemaker_session):
 
     tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges,
                                 metric_definitions,
-                                max_jobs=2, max_parallel_jobs=2)
+                                max_jobs=2, max_parallel_jobs=2,
+                                early_stopping_type='Auto')
 
     training_data = estimator.sagemaker_session.upload_data(
         path=os.path.join(mnist_dir, 'training'),
@@ -571,6 +581,8 @@ def test_attach_tuning_pytorch(sagemaker_session):
 
     attached_tuner = HyperparameterTuner.attach(tuning_job_name,
                                                 sagemaker_session=sagemaker_session)
+    assert attached_tuner.early_stopping_type == 'Auto'
+
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
         predictor = attached_tuner.deploy(1, 'ml.c4.xlarge')

tests/unit/test_session.py

+1

@@ -301,6 +301,7 @@ def test_train_pack_to_request(sagemaker_session):
             'MaxParallelTrainingJobs': 5,
         },
         'ParameterRanges': SAMPLE_PARAM_RANGES,
+        'TrainingJobEarlyStoppingType': 'Off'
     },
     'TrainingJobDefinition': {
         'StaticHyperParameters': STATIC_HPs,

tests/unit/test_tuner.py

+33 -1

@@ -74,7 +74,8 @@
                 'MinValue': '10',
             },
         ]
-    }
+    },
+    'TrainingJobEarlyStoppingType': 'Off'
     },
     'HyperParameterTuningJobName': JOB_NAME,
     'TrainingJobDefinition': {
@@ -241,9 +242,26 @@ def test_fit_pca(sagemaker_session, tuner):
     assert len(tune_kwargs['parameter_ranges']['IntegerParameterRanges']) == 1
     assert tune_kwargs['job_name'].startswith('pca')
     assert tune_kwargs['tags'] == tags
+    assert tune_kwargs['early_stopping_type'] == 'Off'
     assert tuner.estimator.mini_batch_size == 9999
 
 
+def test_fit_pca_with_early_stopping(sagemaker_session, tuner):
+    pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
+              base_job_name='pca', sagemaker_session=sagemaker_session)
+
+    tuner.estimator = pca
+    tuner.early_stopping_type = 'Auto'
+
+    records = RecordSet(s3_data=INPUTS, num_records=1, feature_dim=1)
+    tuner.fit(records, mini_batch_size=9999)
+
+    _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
+
+    assert tune_kwargs['job_name'].startswith('pca')
+    assert tune_kwargs['early_stopping_type'] == 'Auto'
+
+
 def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session):
     job_details = copy.deepcopy(TUNING_JOB_DETAILS)
     sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job',
@@ -257,6 +275,7 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session
     assert tuner.metric_definitions == METRIC_DEFINTIONS
     assert tuner.strategy == 'Bayesian'
     assert tuner.objective_type == 'Minimize'
+    assert tuner.early_stopping_type == 'Off'
 
     assert isinstance(tuner.estimator, PCA)
     assert tuner.estimator.role == ROLE
@@ -270,6 +289,19 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session
     assert tuner.estimator.hyperparameters()['num_components'] == '1'
 
 
+def test_attach_tuning_job_with_estimator_from_hyperparameters_with_early_stopping(sagemaker_session):
+    job_details = copy.deepcopy(TUNING_JOB_DETAILS)
+    job_details['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] = 'Auto'
+    sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job',
+                                                                                  return_value=job_details)
+    tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session)
+
+    assert tuner.latest_tuning_job.name == JOB_NAME
+    assert tuner.early_stopping_type == 'Auto'
+
+    assert isinstance(tuner.estimator, PCA)
+
+
 def test_attach_tuning_job_with_job_details(sagemaker_session):
     job_details = copy.deepcopy(TUNING_JOB_DETAILS)
     HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session, job_details=job_details)
