
Commit a57ad8e

Anton Repushko authored and Namrata Madan committed
feature: support the Hyperband strategy with the StrategyConfig (aws#3440)
Co-authored-by: Anton Repushko <[email protected]>
1 parent 7ffa970 commit a57ad8e

File tree

2 files changed: +261 -13 lines changed


src/sagemaker/tuner.py

Lines changed: 192 additions & 3 deletions
@@ -68,6 +68,9 @@
 HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName"
 PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs"
 WARM_START_TYPE = "WarmStartType"
+HYPERBAND_STRATEGY_CONFIG = "HyperbandStrategyConfig"
+HYPERBAND_MIN_RESOURCE = "MinResource"
+HYPERBAND_MAX_RESOURCE = "MaxResource"
 GRID_SEARCH = "GridSearch"
 
 logger = logging.getLogger(__name__)
@@ -207,6 +210,179 @@ def to_input_req(self):
         }
 
 
+class HyperbandStrategyConfig(object):
+    """The configuration for Hyperband, a multi-fidelity based hyperparameter tuning strategy.
+
+    Hyperband uses the final and intermediate results of a training job to dynamically allocate
+    resources to hyperparameter configurations being evaluated while automatically stopping
+    under-performing configurations. This parameter should be provided only if Hyperband is
+    selected as the Strategy under the HyperParameterTuningJobConfig.
+
+    Examples:
+        >>> hyperband_strategy_config = HyperbandStrategyConfig(
+        >>>     max_resource=10, min_resource=1)
+        >>> hyperband_strategy_config.max_resource
+        10
+        >>> hyperband_strategy_config.min_resource
+        1
+    """
+
+    def __init__(self, max_resource: int, min_resource: int):
+        """Creates a ``HyperbandStrategyConfig`` with provided ``min_resource`` and ``max_resource``.
+
+        Args:
+            max_resource (int): The maximum number of resources (such as epochs) that can be used
+                by a training job launched by a hyperparameter tuning job.
+                Once a job reaches the MaxResource value, it is stopped.
+                If a value for MaxResource is not provided, and Hyperband is selected as the
+                hyperparameter tuning strategy, Hyperband attempts to infer MaxResource
+                from the following keys (if present) in StaticHyperParameters:
+                    epochs
+                    numepochs
+                    n-epochs
+                    n_epochs
+                    num_epochs
+                If HyperbandStrategyConfig is unable to infer a value for MaxResource, it
+                generates a validation error.
+                The maximum value is 20,000 epochs. All metrics that correspond to an objective
+                metric are used to derive early stopping decisions.
+                For distributed training jobs, ensure that duplicate metrics are not printed in
+                the logs across the individual nodes in a training job.
+                If multiple nodes are publishing duplicate or incorrect metrics, the Hyperband
+                optimization algorithm may make an incorrect stopping decision and stop the job
+                prematurely.
+            min_resource (int): The minimum number of resources (such as epochs)
+                that can be used by a training job launched by a hyperparameter tuning job.
+                If the value for MinResource has not been reached, the training job will not be
+                stopped by Hyperband.
+        """
+        self.min_resource = min_resource
+        self.max_resource = max_resource
+
+    @classmethod
+    def from_job_desc(cls, hyperband_strategy_config):
+        """Creates a ``HyperbandStrategyConfig`` from a hyperband strategy configuration response.
+
+        This is the Hyperband strategy configuration from the DescribeTuningJob response.
+
+        Examples:
+            >>> hyperband_strategy_config =
+            >>>     HyperbandStrategyConfig.from_job_desc(hyperband_strategy_config={
+            >>>         "MaxResource": 10,
+            >>>         "MinResource": 1
+            >>>     })
+            >>> hyperband_strategy_config.max_resource
+            10
+            >>> hyperband_strategy_config.min_resource
+            1
+
+        Args:
+            hyperband_strategy_config (dict): The expected format of the
+                ``hyperband_strategy_config`` contains two first-class fields:
+                "MaxResource" and "MinResource".
+
+        Returns:
+            sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of
+                ``HyperbandStrategyConfig`` containing the ``max_resource`` and
+                ``min_resource`` provided as part of ``hyperband_strategy_config``.
+        """
+        return cls(
+            min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE],
+            max_resource=hyperband_strategy_config[HYPERBAND_MAX_RESOURCE],
+        )
+
+    def to_input_req(self):
+        """Converts the ``self`` instance to the desired input request format.
+
+        Examples:
+            >>> hyperband_strategy_config = HyperbandStrategyConfig(
+            >>>     max_resource=10,
+            >>>     min_resource=1
+            >>> )
+            >>> hyperband_strategy_config.to_input_req()
+            {
+                "MaxResource": 10,
+                "MinResource": 1
+            }
+
+        Returns:
+            dict: Containing "MaxResource" and "MinResource" as first-class fields.
+        """
+        return {
+            HYPERBAND_MIN_RESOURCE: self.min_resource,
+            HYPERBAND_MAX_RESOURCE: self.max_resource,
+        }
+
+
+class StrategyConfig(object):
+    """The configuration for a training job launched by a hyperparameter tuning job.
+
+    Choose Bayesian for Bayesian optimization, and Random for random search optimization.
+    For more advanced use cases, use Hyperband, which evaluates objective metrics for training
+    jobs after every epoch.
+    """
+
+    def __init__(
+        self,
+        hyperband_strategy_config: HyperbandStrategyConfig,
+    ):
+        """Creates a ``StrategyConfig`` with provided ``HyperbandStrategyConfig``.
+
+        Args:
+            hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The
+                configuration for the object that specifies the Hyperband strategy.
+                This parameter is only supported for the Hyperband selection for Strategy
+                within the HyperParameterTuningJobConfig.
+        """
+
+        self.hyperband_strategy_config = hyperband_strategy_config
+
+    @classmethod
+    def from_job_desc(cls, strategy_config):
+        """Creates a ``StrategyConfig`` from a strategy configuration response.
+
+        This is the Hyperband strategy configuration from the DescribeTuningJob response.
+
+        Args:
+            strategy_config (dict): The expected format of the ``strategy_config`` contains
+                one first-class field: "HyperbandStrategyConfig".
+
+        Returns:
+            sagemaker.tuner.StrategyConfig: De-serialized instance of ``StrategyConfig``
+                containing the strategy configuration.
+        """
+        return cls(
+            hyperband_strategy_config=HyperbandStrategyConfig.from_job_desc(
+                strategy_config[HYPERBAND_STRATEGY_CONFIG]
+            )
+        )
+
+    def to_input_req(self):
+        """Converts the ``self`` instance to the desired input request format.
+
+        Examples:
+            >>> strategy_config = StrategyConfig(
+            >>>     HyperbandStrategyConfig(
+            >>>         max_resource=10,
+            >>>         min_resource=1
+            >>>     )
+            >>> )
+            >>> strategy_config.to_input_req()
+            {
+                "HyperbandStrategyConfig": {
+                    "MaxResource": 10,
+                    "MinResource": 1
+                }
+            }
+
+        Returns:
+            dict: Containing the strategy configurations.
+        """
+        return {
+            HYPERBAND_STRATEGY_CONFIG: self.hyperband_strategy_config.to_input_req(),
+        }
+
+
 class HyperparameterTuner(object):
     """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
 
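Taken together, the two new classes nest one inside the other. A minimal sketch of the round trip, mirroring the docstring examples above (not part of the diff itself):

from sagemaker.tuner import HyperbandStrategyConfig, StrategyConfig

# Cap each training job at 10 epochs and let Hyperband stop it any time
# after the first epoch.
hyperband_config = HyperbandStrategyConfig(max_resource=10, min_resource=1)

# Wrap the Hyperband settings in the top-level strategy configuration.
strategy_config = StrategyConfig(hyperband_strategy_config=hyperband_config)

# Serialize to the request shape sent to the tuning API.
print(strategy_config.to_input_req())
# {'HyperbandStrategyConfig': {'MinResource': 1, 'MaxResource': 10}}

The reverse direction, StrategyConfig.from_job_desc, de-serializes the same dictionary shape out of a describe-tuning-job response.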
@@ -234,6 +410,7 @@ def __init__(
         tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
         base_tuning_job_name: Optional[str] = None,
         warm_start_config: Optional[WarmStartConfig] = None,
+        strategy_config: Optional[StrategyConfig] = None,
         early_stopping_type: Union[str, PipelineVariable] = "Off",
         estimator_name: Optional[str] = None,
     ):
@@ -283,6 +460,8 @@
             warm_start_config (sagemaker.tuner.WarmStartConfig): A
                 ``WarmStartConfig`` object that has been initialized with the
                 configuration defining the nature of warm start tuning job.
+            strategy_config (sagemaker.tuner.StrategyConfig): A configuration for the
+                hyperparameter tuning job optimization strategy.
             early_stopping_type (str or PipelineVariable): Specifies whether early stopping is
                 enabled for the job. Can be either 'Auto' or 'Off' (default:
                 'Off'). If set to 'Off', early stopping will not be attempted.
@@ -321,6 +500,7 @@ def __init__(
         self._validate_parameter_ranges(estimator, hyperparameter_ranges)
 
         self.strategy = strategy
+        self.strategy_config = strategy_config
         self.objective_type = objective_type
         # For the GridSearch strategy we expect the max_jobs equals None and recalculate it later.
         # For all other strategies for the backward compatibility we keep
@@ -1295,6 +1475,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             objective_metric_name=self.objective_metric_name,
             hyperparameter_ranges=self._hyperparameter_ranges,
             strategy=self.strategy,
+            strategy_config=self.strategy_config,
             objective_type=self.objective_type,
             max_jobs=self.max_jobs,
             max_parallel_jobs=self.max_parallel_jobs,
@@ -1321,6 +1502,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             hyperparameter_ranges_dict=self._hyperparameter_ranges_dict,
             metric_definitions_dict=self.metric_definitions_dict,
             strategy=self.strategy,
+            strategy_config=self.strategy_config,
             objective_type=self.objective_type,
             max_jobs=self.max_jobs,
             max_parallel_jobs=self.max_parallel_jobs,
@@ -1337,6 +1519,7 @@ def create(
         metric_definitions_dict=None,
         base_tuning_job_name=None,
         strategy="Bayesian",
+        strategy_config=None,
         objective_type="Maximize",
         max_jobs=None,
         max_parallel_jobs=1,
@@ -1380,11 +1563,13 @@ def create(
                 metric from the logs. This should be defined only for hyperparameter tuning jobs
                 that don't use an Amazon algorithm.
             base_tuning_job_name (str): Prefix for the hyperparameter tuning job name when the
-                :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches. If not specified,
-                a default job name is generated, based on the training image name and current
-                timestamp.
+                :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches.
+                If not specified, a default job name is generated,
+                based on the training image name and current timestamp.
             strategy (str): Strategy to be used for hyperparameter estimations
                 (default: 'Bayesian').
+            strategy_config (dict): The configuration for a training job launched by a
+                hyperparameter tuning job.
             objective_type (str): The type of the objective metric for evaluating training jobs.
                 This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
@@ -1432,6 +1617,7 @@ def create(
             hyperparameter_ranges=hyperparameter_ranges_dict[first_estimator_name],
             metric_definitions=metric_definitions,
             strategy=strategy,
+            strategy_config=strategy_config,
             objective_type=objective_type,
             max_jobs=max_jobs,
             max_parallel_jobs=max_parallel_jobs,
@@ -1589,6 +1775,9 @@ def _get_tuner_args(cls, tuner, inputs):
             "early_stopping_type": tuner.early_stopping_type,
         }
 
+        if tuner.strategy_config is not None:
+            tuning_config["strategy_config"] = tuner.strategy_config
+
         if tuner.objective_metric_name is not None:
             tuning_config["objective_type"] = tuner.objective_type
             tuning_config["objective_metric_name"] = tuner.objective_metric_name
