
Commit 6729e82

Author: Anton Repushko (committed)

feature: support the Hyperband strategy with the StrategyConfig

1 parent ecb4ac2 commit 6729e82

File tree: 2 files changed (+300, -22 lines)


src/sagemaker/tuner.py

Lines changed: 231 additions & 12 deletions
@@ -33,7 +33,10 @@
 from sagemaker.estimator import Framework, EstimatorBase
 from sagemaker.inputs import TrainingInput, FileSystemInput
 from sagemaker.job import _Job
-from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
+from sagemaker.jumpstart.utils import (
+    add_jumpstart_tags,
+    get_jumpstart_base_name_if_jumpstart_model,
+)
 from sagemaker.parameter import (
     CategoricalParameter,
     ContinuousParameter,
@@ -44,7 +47,12 @@
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
 
 from sagemaker.session import Session
-from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, to_string
+from sagemaker.utils import (
+    base_from_name,
+    base_name_from_image,
+    name_from_base,
+    to_string,
+)
 
 AMAZON_ESTIMATOR_MODULE = "sagemaker"
 AMAZON_ESTIMATOR_CLS_NAMES = {
@@ -60,6 +68,9 @@
 HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName"
 PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs"
 WARM_START_TYPE = "WarmStartType"
+HYPERBAND_STRATEGY_CONFIG = "HyperbandStrategyConfig"
+HYPERBAND_MIN_RESOURCE = "MinResource"
+HYPERBAND_MAX_RESOURCE = "MaxResource"
 
 logger = logging.getLogger(__name__)
 
@@ -165,7 +176,8 @@ def from_job_desc(cls, warm_start_config):
                 parents.append(parent[HYPERPARAMETER_TUNING_JOB_NAME])
 
         return cls(
-            warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]), parents=parents
+            warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]),
+            parents=parents,
         )
 
     def to_input_req(self):
@@ -197,6 +209,179 @@ def to_input_req(self):
         }
 
 
+class HyperbandStrategyConfig(object):
+    """The configuration for Hyperband, a multi-fidelity based hyperparameter tuning strategy.
+
+    Hyperband uses the final and intermediate results of a training job to dynamically allocate
+    resources to hyperparameter configurations being evaluated while automatically stopping
+    under-performing configurations. This parameter should be provided only if Hyperband is
+    selected as the Strategy under the HyperParameterTuningJobConfig.
+
+    Examples:
+        >>> hyperband_strategy_config = HyperbandStrategyConfig(
+        >>>     max_resource=10, min_resource=1)
+        >>> hyperband_strategy_config.max_resource
+        10
+        >>> hyperband_strategy_config.min_resource
+        1
+    """
+
+    def __init__(self, max_resource: int, min_resource: int):
+        """Creates a ``HyperbandStrategyConfig`` with provided ``min_resource`` and ``max_resource``.
+
+        Args:
+            max_resource (int): The maximum number of resources (such as epochs) that can be used
+                by a training job launched by a hyperparameter tuning job.
+                Once a job reaches the MaxResource value, it is stopped.
+                If a value for MaxResource is not provided, and Hyperband is selected as the
+                hyperparameter tuning strategy, Hyperband attempts to infer MaxResource
+                from the following keys (if present) in StaticHyperParameters:
+                    epochs
+                    numepochs
+                    n-epochs
+                    n_epochs
+                    num_epochs
+                If HyperbandStrategyConfig is unable to infer a value for MaxResource, it generates
+                a validation error.
+                The maximum value is 20,000 epochs. All metrics that correspond to an objective
+                metric are used to derive early stopping decisions.
+                For distributed training jobs, ensure that duplicate metrics are not printed in the
+                logs across the individual nodes in a training job.
+                If multiple nodes are publishing duplicate or incorrect metrics, the Hyperband
+                optimization algorithm may make an incorrect stopping decision and stop the job
+                prematurely.
+            min_resource (int): The minimum number of resources (such as epochs)
+                that can be used by a training job launched by a hyperparameter tuning job.
+                If the value for MinResource has not been reached, the training job will not be
+                stopped by Hyperband.
+        """
+        self.min_resource = min_resource
+        self.max_resource = max_resource
+
+    @classmethod
+    def from_job_desc(cls, hyperband_strategy_config):
+        """Creates a ``HyperbandStrategyConfig`` from a Hyperband strategy configuration response.
+
+        This is the Hyperband strategy configuration from the DescribeHyperParameterTuningJob
+        response.
+
+        Examples:
+            >>> hyperband_strategy_config =
+            >>>     HyperbandStrategyConfig.from_job_desc(hyperband_strategy_config={
+            >>>         "MaxResource": 10,
+            >>>         "MinResource": 1
+            >>>     })
+            >>> hyperband_strategy_config.max_resource
+            10
+            >>> hyperband_strategy_config.min_resource
+            1
+
+        Args:
+            hyperband_strategy_config (dict): The expected format of the
+                ``hyperband_strategy_config`` contains two first-class fields:
+                ``MaxResource`` and ``MinResource``.
+
+        Returns:
+            sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of
+                ``HyperbandStrategyConfig`` containing the ``max_resource`` and ``min_resource``
+                provided as part of ``hyperband_strategy_config``.
+        """
+        return cls(
+            min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE],
+            max_resource=hyperband_strategy_config[HYPERBAND_MAX_RESOURCE],
+        )
+
+    def to_input_req(self):
+        """Converts the ``self`` instance to the desired input request format.
+
+        Examples:
+            >>> hyperband_strategy_config = HyperbandStrategyConfig(
+            >>>     max_resource=10,
+            >>>     min_resource=1
+            >>> )
+            >>> hyperband_strategy_config.to_input_req()
+            {
+                "MinResource": 1,
+                "MaxResource": 10
+            }
+
+        Returns:
+            dict: Containing ``MinResource`` and ``MaxResource`` as the first-class fields.
+        """
+        return {
+            HYPERBAND_MIN_RESOURCE: self.min_resource,
+            HYPERBAND_MAX_RESOURCE: self.max_resource,
+        }
+
+
+class StrategyConfig(object):
+    """The configuration for the optimization strategy used by a hyperparameter tuning job.
+
+    Choose Bayesian for Bayesian optimization, and Random for random search optimization.
+    For more advanced use cases, use Hyperband, which evaluates objective metrics for training jobs
+    after every epoch.
+    """
+
+    def __init__(
+        self,
+        hyperband_strategy_config: HyperbandStrategyConfig,
+    ):
+        """Creates a ``StrategyConfig`` with provided ``HyperbandStrategyConfig``.
+
+        Args:
+            hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration
+                object that specifies the Hyperband strategy.
+                This parameter is only supported for the Hyperband selection for Strategy within
+                the HyperParameterTuningJobConfig.
+        """
+
+        self.hyperband_strategy_config = hyperband_strategy_config
+
+    @classmethod
+    def from_job_desc(cls, strategy_config):
+        """Creates a ``StrategyConfig`` from a strategy configuration response.
+
+        This is the strategy configuration from the DescribeHyperParameterTuningJob response.
+
+        Args:
+            strategy_config (dict): The expected format of the
+                ``strategy_config`` contains one first-class field: ``HyperbandStrategyConfig``.
+
+        Returns:
+            sagemaker.tuner.StrategyConfig: De-serialized instance of
+                ``StrategyConfig`` containing the strategy configuration.
+        """
+        return cls(
+            hyperband_strategy_config=HyperbandStrategyConfig.from_job_desc(
+                strategy_config[HYPERBAND_STRATEGY_CONFIG]
+            )
+        )
+
+    def to_input_req(self):
+        """Converts the ``self`` instance to the desired input request format.
+
+        Examples:
+            >>> strategy_config = StrategyConfig(
+            >>>     HyperbandStrategyConfig(
+            >>>         max_resource=10,
+            >>>         min_resource=1
+            >>>     )
+            >>> )
+            >>> strategy_config.to_input_req()
+            {
+                "HyperbandStrategyConfig": {
+                    "MinResource": 1,
+                    "MaxResource": 10
+                }
+            }
+
+        Returns:
+            dict: Containing the strategy configurations.
+        """
+        return {
+            HYPERBAND_STRATEGY_CONFIG: self.hyperband_strategy_config.to_input_req(),
+        }
+
+
 class HyperparameterTuner(object):
     """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
 
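A minimal usage sketch of the two classes added above (illustrative only, not part of the diff; it assumes both classes are imported from sagemaker.tuner):

    from sagemaker.tuner import HyperbandStrategyConfig, StrategyConfig

    # Allow each training job at most 10 epochs and guarantee at least 1 epoch
    # before Hyperband may stop it early.
    hyperband_config = HyperbandStrategyConfig(max_resource=10, min_resource=1)
    strategy_config = StrategyConfig(hyperband_strategy_config=hyperband_config)

    # Serialize to the request format used by the tuning job configuration.
    print(strategy_config.to_input_req())
    # {'HyperbandStrategyConfig': {'MinResource': 1, 'MaxResource': 10}}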
@@ -224,6 +409,7 @@ def __init__(
         tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
         base_tuning_job_name: Optional[str] = None,
         warm_start_config: Optional[WarmStartConfig] = None,
+        strategy_config: Optional[StrategyConfig] = None,
         early_stopping_type: Union[str, PipelineVariable] = "Off",
         estimator_name: Optional[str] = None,
     ):
@@ -272,6 +458,8 @@ def __init__(
             warm_start_config (sagemaker.tuner.WarmStartConfig): A
                 ``WarmStartConfig`` object that has been initialized with the
                 configuration defining the nature of warm start tuning job.
+            strategy_config (sagemaker.tuner.StrategyConfig): A configuration for the
+                hyperparameter tuning job optimization strategy.
             early_stopping_type (str or PipelineVariable): Specifies whether early stopping is
                 enabled for the job. Can be either 'Auto' or 'Off' (default:
                 'Off'). If set to 'Off', early stopping will not be attempted.
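A sketch of how the new strategy_config parameter might be passed to HyperparameterTuner together with the Hyperband strategy (the estimator, objective metric, and ranges below are placeholder assumptions, not taken from this commit):

    from sagemaker.tuner import (
        ContinuousParameter,
        HyperbandStrategyConfig,
        HyperparameterTuner,
        StrategyConfig,
    )

    tuner = HyperparameterTuner(
        estimator=my_estimator,  # any previously configured SageMaker estimator
        objective_metric_name="validation:accuracy",
        hyperparameter_ranges={"learning_rate": ContinuousParameter(0.001, 0.1)},
        metric_definitions=[{"Name": "validation:accuracy", "Regex": "accuracy = ([0-9\\.]+)"}],
        strategy="Hyperband",
        strategy_config=StrategyConfig(
            hyperband_strategy_config=HyperbandStrategyConfig(max_resource=10, min_resource=1)
        ),
        max_jobs=20,
        max_parallel_jobs=4,
    )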
@@ -310,6 +498,7 @@ def __init__(
         self._validate_parameter_ranges(estimator, hyperparameter_ranges)
 
         self.strategy = strategy
+        self.strategy_config = strategy_config
         self.objective_type = objective_type
         self.max_jobs = max_jobs
         self.max_parallel_jobs = max_parallel_jobs
@@ -373,7 +562,8 @@ def _prepare_job_name_for_tuning(self, job_name=None):
                     self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
                 )
                 base_name = base_name_from_image(
-                    estimator.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
+                    estimator.training_image_uri(),
+                    default_base_name=EstimatorBase.JOB_CLASS_NAME,
                 )
 
                 jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
@@ -434,7 +624,15 @@ def _prepare_static_hyperparameters(
     def fit(
         self,
         inputs: Optional[
-            Union[str, Dict, List, TrainingInput, FileSystemInput, RecordSet, FileSystemRecordSet]
+            Union[
+                str,
+                Dict,
+                List,
+                TrainingInput,
+                FileSystemInput,
+                RecordSet,
+                FileSystemRecordSet,
+            ]
         ] = None,
         job_name: Optional[str] = None,
         include_cls_metadata: Union[bool, Dict[str, bool]] = False,
@@ -524,7 +722,9 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estim
             allowed_keys=estimator_names,
         )
         self._validate_dict_argument(
-            name="estimator_kwargs", value=estimator_kwargs, allowed_keys=estimator_names
+            name="estimator_kwargs",
+            value=estimator_kwargs,
+            allowed_keys=estimator_names,
         )
 
         for (estimator_name, estimator) in self.estimator_dict.items():
@@ -546,7 +746,13 @@ def _prepare_estimator_for_tuning(cls, estimator, inputs, job_name, **kwargs):
         estimator._prepare_for_training(job_name)
 
     @classmethod
-    def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estimator_cls=None):
+    def attach(
+        cls,
+        tuning_job_name,
+        sagemaker_session=None,
+        job_details=None,
+        estimator_cls=None,
+    ):
         """Attach to an existing hyperparameter tuning job.
 
         Create a HyperparameterTuner bound to an existing hyperparameter
@@ -959,7 +1165,8 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details):
 
         # Default to the BYO estimator
         return getattr(
-            importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE), cls.DEFAULT_ESTIMATOR_CLS_NAME
+            importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE),
+            cls.DEFAULT_ESTIMATOR_CLS_NAME,
         )
 
     @classmethod
@@ -1151,7 +1358,10 @@ def _validate_parameter_ranges(self, estimator, hyperparameter_ranges):
 
     def _validate_parameter_range(self, value_hp, parameter_range):
         """Placeholder docstring"""
-        for (parameter_range_key, parameter_range_value) in parameter_range.__dict__.items():
+        for (
+            parameter_range_key,
+            parameter_range_value,
+        ) in parameter_range.__dict__.items():
             if parameter_range_key == "scaling_type":
                 continue
 
@@ -1258,6 +1468,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             objective_metric_name=self.objective_metric_name,
             hyperparameter_ranges=self._hyperparameter_ranges,
             strategy=self.strategy,
+            strategy_config=self.strategy_config,
             objective_type=self.objective_type,
             max_jobs=self.max_jobs,
             max_parallel_jobs=self.max_parallel_jobs,
@@ -1284,6 +1495,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             hyperparameter_ranges_dict=self._hyperparameter_ranges_dict,
             metric_definitions_dict=self.metric_definitions_dict,
             strategy=self.strategy,
+            strategy_config=self.strategy_config,
             objective_type=self.objective_type,
             max_jobs=self.max_jobs,
             max_parallel_jobs=self.max_parallel_jobs,
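With strategy_config now forwarded in both branches of _create_warm_start_tuner, a warm-start tuner inherits the parent's Hyperband configuration. A hedged sketch, assuming a tuner configured as in the earlier example and using the module's existing transfer_learning_tuner helper:

    warm_start_tuner = tuner.transfer_learning_tuner(
        additional_parents={"parent-tuning-job-name"}
    )
    # The same StrategyConfig object is carried over to the new tuner.
    assert warm_start_tuner.strategy_config is tuner.strategy_config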
@@ -1300,6 +1512,7 @@ def create(
         metric_definitions_dict=None,
         base_tuning_job_name=None,
         strategy="Bayesian",
+        strategy_config=None,
         objective_type="Maximize",
         max_jobs=1,
         max_parallel_jobs=1,
@@ -1343,11 +1556,13 @@ def create(
                 metric from the logs. This should be defined only for hyperparameter tuning jobs
                 that don't use an Amazon algorithm.
             base_tuning_job_name (str): Prefix for the hyperparameter tuning job name when the
-                :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches. If not specified,
-                a default job name is generated, based on the training image name and current
-                timestamp.
+                :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches.
+                If not specified, a default job name is generated,
+                based on the training image name and current timestamp.
             strategy (str): Strategy to be used for hyperparameter estimations
                 (default: 'Bayesian').
+            strategy_config (sagemaker.tuner.StrategyConfig): The configuration for the
+                optimization strategy used by the hyperparameter tuning job.
             objective_type (str): The type of the objective metric for evaluating training jobs.
                 This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
@@ -1394,6 +1609,7 @@ def create(
13941609
hyperparameter_ranges=hyperparameter_ranges_dict[first_estimator_name],
13951610
metric_definitions=metric_definitions,
13961611
strategy=strategy,
1612+
strategy_config=strategy_config,
13971613
objective_type=objective_type,
13981614
max_jobs=max_jobs,
13991615
max_parallel_jobs=max_parallel_jobs,
@@ -1551,6 +1767,9 @@ def _get_tuner_args(cls, tuner, inputs):
             "early_stopping_type": tuner.early_stopping_type,
         }
 
+        if tuner.strategy_config is not None:
+            tuning_config["strategy_config"] = tuner.strategy_config
+
         if tuner.objective_metric_name is not None:
             tuning_config["objective_type"] = tuner.objective_type
             tuning_config["objective_metric_name"] = tuner.objective_metric_name
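When the tuning job is created, this strategy_config ultimately maps onto the HyperParameterTuningJobConfig of the CreateHyperParameterTuningJob request. A hedged sketch of the relevant fragment for the Hyperband example above (field names follow the AWS API; the exact serialization happens in the session layer, outside this diff):

    {
        "Strategy": "Hyperband",
        "StrategyConfig": {
            "HyperbandStrategyConfig": {"MinResource": 1, "MaxResource": 10},
        },
        # ResourceLimits, ParameterRanges, HyperParameterTuningJobObjective, etc. omitted
    }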
