
Commit 0f785dd

Author: Anton Repushko
feature: add support of the intelligent stopping in the tuner

1 parent e2f3888 · commit 0f785dd

File tree: 4 files changed, +172 −1 lines


src/sagemaker/session.py

Lines changed: 20 additions & 0 deletions
@@ -2178,6 +2178,7 @@ def tune( # noqa: C901
         objective_metric_name,
         max_jobs,
         max_parallel_jobs,
+        max_runtime_in_seconds,
         parameter_ranges,
         static_hyperparameters,
         input_mode,
@@ -2190,6 +2191,7 @@ def tune( # noqa: C901
         tags,
         warm_start_config,
         strategy_config=None,
+        completion_criteria_config=None,
         enable_network_isolation=False,
         image_uri=None,
         algorithm_arn=None,
@@ -2215,6 +2217,8 @@ def tune( # noqa: C901
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
                 tuning job.
             max_parallel_jobs (int): Maximum number of parallel training jobs to start.
+            max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
             parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be
                 one of three types: Continuous, Integer, or Categorical.
             static_hyperparameters (dict): Hyperparameters for model training. These
@@ -2255,6 +2259,8 @@ def tune( # noqa: C901
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
             warm_start_config (dict): Configuration defining the type of warm start and
                 other required configurations.
+            completion_criteria_config (sagemaker.tuner.TuningJobCompletionCriteriaConfig): A
+                configuration for the completion criteria.
             early_stopping_type (str): Specifies whether early stopping is enabled for the job.
                 Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
                 attempted. If set to 'Auto', early stopping of some training jobs may happen, but
@@ -2294,12 +2300,14 @@ def tune( # noqa: C901
                 strategy=strategy,
                 max_jobs=max_jobs,
                 max_parallel_jobs=max_parallel_jobs,
+                max_runtime_in_seconds=max_runtime_in_seconds,
                 objective_type=objective_type,
                 objective_metric_name=objective_metric_name,
                 parameter_ranges=parameter_ranges,
                 early_stopping_type=early_stopping_type,
                 random_seed=random_seed,
                 strategy_config=strategy_config,
+                completion_criteria_config=completion_criteria_config,
             ),
             "TrainingJobDefinition": self._map_training_config(
                 static_hyperparameters=static_hyperparameters,
@@ -2452,12 +2460,14 @@ def _map_tuning_config(
         strategy,
         max_jobs,
         max_parallel_jobs,
+        max_runtime_in_seconds=None,
         early_stopping_type="Off",
         objective_type=None,
         objective_metric_name=None,
         parameter_ranges=None,
         random_seed=None,
         strategy_config=None,
+        completion_criteria_config=None,
     ):
         """Construct tuning job configuration dictionary.
@@ -2466,6 +2476,8 @@ def _map_tuning_config(
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
                 tuning job.
             max_parallel_jobs (int): Maximum number of parallel training jobs to start.
+            max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
             early_stopping_type (str): Specifies whether early stopping is enabled for the job.
                 Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
                 attempted. If set to 'Auto', early stopping of some training jobs may happen,
@@ -2480,6 +2492,8 @@ def _map_tuning_config(
                 produce more consistent configurations for the same tuning job.
             strategy_config (dict): A configuration for the hyperparameter tuning job optimisation
                 strategy.
+            completion_criteria_config (dict): A configuration for the completion criteria.
 
         Returns:
             A dictionary of tuning job configuration. For format details, please refer to
@@ -2496,6 +2510,9 @@ def _map_tuning_config(
             "TrainingJobEarlyStoppingType": early_stopping_type,
         }
 
+        if max_runtime_in_seconds is not None:
+            tuning_config["ResourceLimits"]["MaxRuntimeInSeconds"] = max_runtime_in_seconds
+
         if random_seed is not None:
             tuning_config["RandomSeed"] = random_seed
 
@@ -2508,6 +2525,9 @@ def _map_tuning_config(
 
         if strategy_config is not None:
            tuning_config["StrategyConfig"] = strategy_config
+
+        if completion_criteria_config is not None:
+            tuning_config["TuningJobCompletionCriteria"] = completion_criteria_config
         return tuning_config
 
     @classmethod
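Note for reviewers: taken together, the session.py changes let a tuning request carry both a per-training-job runtime cap and a completion criteria block. A minimal, hand-built sketch of the resulting (abridged) HyperParameterTuningJobConfig dict follows; all values are invented for illustration, and only the fields touched by this commit plus their ResourceLimits siblings are shown:

    # Illustrative only: abridged shape of the dict _map_tuning_config produces
    # when both new options are set. Values (10, 2, 3600, 5, 0.42) are made up.
    tuning_config = {
        "Strategy": "Bayesian",
        "ResourceLimits": {
            "MaxNumberOfTrainingJobs": 10,
            "MaxParallelTrainingJobs": 2,
            "MaxRuntimeInSeconds": 3600,  # added when max_runtime_in_seconds is not None
        },
        "TrainingJobEarlyStoppingType": "Off",
        "TuningJobCompletionCriteria": {  # added when completion_criteria_config is not None
            "BestObjectiveNotImproving": {"MaxNumberOfTrainingJobsNotImproving": 5},
            "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"},
            "TargetObjectiveMetricValue": 0.42,
        },
    }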

src/sagemaker/tuner.py

Lines changed: 137 additions & 0 deletions
@@ -72,6 +72,12 @@
 HYPERBAND_MIN_RESOURCE = "MinResource"
 HYPERBAND_MAX_RESOURCE = "MaxResource"
 GRID_SEARCH = "GridSearch"
+MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING = "MaxNumberOfTrainingJobsNotImproving"
+BEST_OBJECTIVE_NOT_IMPROVING = "BestObjectiveNotImproving"
+CONVERGENCE_DETECTED = "ConvergenceDetected"
+COMPLETE_ON_CONVERGENCE_DETECTED = "CompleteOnConvergence"
+TARGET_OBJECTIVE_METRIC_VALUE = "TargetObjectiveMetricValue"
+MAX_RUNTIME_IN_SECONDS = "MaxRuntimeInSeconds"
 
 logger = logging.getLogger(__name__)
@@ -383,6 +389,109 @@ def to_input_req(self):
         }
 
 
+class TuningJobCompletionCriteriaConfig(object):
+    """The configuration for a job completion criteria."""
+
+    def __init__(
+        self,
+        max_number_of_training_jobs_not_improving: int = None,
+        complete_on_convergence: bool = None,
+        target_objective_metric_value: float = None,
+    ):
+        """Creates a ``TuningJobCompletionCriteriaConfig`` with provided criteria.
+
+        Args:
+            max_number_of_training_jobs_not_improving (int): The number of training jobs that
+                have failed to improve model performance by 1% or greater over prior training
+                jobs as evaluated against an objective function.
+            complete_on_convergence (bool): A flag to stop your hyperparameter tuning job if
+                automatic model tuning (AMT) has detected that your model has converged as
+                evaluated against your objective function.
+            target_objective_metric_value (float): The value of the objective metric.
+        """
+        self.max_number_of_training_jobs_not_improving = max_number_of_training_jobs_not_improving
+        self.complete_on_convergence = complete_on_convergence
+        self.target_objective_metric_value = target_objective_metric_value
+
+    @classmethod
+    def from_job_desc(cls, completion_criteria_config):
+        """Creates a ``TuningJobCompletionCriteriaConfig`` from a completion criteria response.
+
+        This is the completion criteria configuration from the DescribeTuningJob response.
+
+        Args:
+            completion_criteria_config (dict): The expected format of the
+                ``completion_criteria_config`` contains three first-class fields.
+
+        Returns:
+            sagemaker.tuner.TuningJobCompletionCriteriaConfig: De-serialized instance of
+                TuningJobCompletionCriteriaConfig containing the completion criteria.
+        """
+        complete_on_convergence = None
+        if CONVERGENCE_DETECTED in completion_criteria_config:
+            if completion_criteria_config[CONVERGENCE_DETECTED].get(
+                COMPLETE_ON_CONVERGENCE_DETECTED
+            ):
+                complete_on_convergence = (
+                    completion_criteria_config[CONVERGENCE_DETECTED][
+                        COMPLETE_ON_CONVERGENCE_DETECTED
+                    ]
+                    == "Enabled"
+                )
+
+        max_number_of_training_jobs_not_improving = None
+        if BEST_OBJECTIVE_NOT_IMPROVING in completion_criteria_config:
+            if completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING].get(
+                MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING
+            ):
+                max_number_of_training_jobs_not_improving = completion_criteria_config[
+                    BEST_OBJECTIVE_NOT_IMPROVING
+                ][MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING]
+
+        target_objective_metric_value = None
+        if TARGET_OBJECTIVE_METRIC_VALUE in completion_criteria_config:
+            target_objective_metric_value = completion_criteria_config[
+                TARGET_OBJECTIVE_METRIC_VALUE
+            ]
+
+        return cls(
+            max_number_of_training_jobs_not_improving=max_number_of_training_jobs_not_improving,
+            complete_on_convergence=complete_on_convergence,
+            target_objective_metric_value=target_objective_metric_value,
+        )
+
+    def to_input_req(self):
+        """Converts the ``self`` instance to the desired input request format.
+
+        Examples:
+            >>> completion_criteria_config = TuningJobCompletionCriteriaConfig(
+            >>>     max_number_of_training_jobs_not_improving=5,
+            >>>     complete_on_convergence=True,
+            >>>     target_objective_metric_value=0.42,
+            >>> )
+            >>> completion_criteria_config.to_input_req()
+            {
+                "BestObjectiveNotImproving": {
+                    "MaxNumberOfTrainingJobsNotImproving": 5
+                },
+                "ConvergenceDetected": {
+                    "CompleteOnConvergence": "Enabled"
+                },
+                "TargetObjectiveMetricValue": 0.42
+            }
+
+        Returns:
+            dict: Containing the completion criteria configurations.
+        """
+        completion_criteria_config = {}
+        if self.max_number_of_training_jobs_not_improving is not None:
+            completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING] = {
+                MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING: (
+                    self.max_number_of_training_jobs_not_improving
+                )
+            }
+
+        if self.target_objective_metric_value is not None:
+            completion_criteria_config[TARGET_OBJECTIVE_METRIC_VALUE] = (
+                self.target_objective_metric_value
+            )
+
+        if self.complete_on_convergence is not None:
+            completion_criteria_config[CONVERGENCE_DETECTED] = {
+                COMPLETE_ON_CONVERGENCE_DETECTED: (
+                    "Enabled" if self.complete_on_convergence else "Disabled"
+                )
+            }
+
+        return completion_criteria_config
+
+
 class HyperparameterTuner(object):
     """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
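Review note: a quick sketch of how the new class round-trips against a describe response. The dict below is a hypothetical fragment of a DescribeHyperParameterTuningJob response (only the TuningJobCompletionCriteria portion), not output captured from the service:

    # Hypothetical TuningJobCompletionCriteria fragment from a describe call.
    desc = {
        "BestObjectiveNotImproving": {"MaxNumberOfTrainingJobsNotImproving": 5},
        "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"},
        "TargetObjectiveMetricValue": 0.42,
    }
    config = TuningJobCompletionCriteriaConfig.from_job_desc(desc)
    assert config.complete_on_convergence is True  # "Enabled" maps to True
    assert config.max_number_of_training_jobs_not_improving == 5
    assert config.to_input_req() == desc  # serialization is symmetric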
@@ -407,10 +516,12 @@ def __init__(
         objective_type: Union[str, PipelineVariable] = "Maximize",
         max_jobs: Union[int, PipelineVariable] = None,
         max_parallel_jobs: Union[int, PipelineVariable] = 1,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
         tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
         base_tuning_job_name: Optional[str] = None,
         warm_start_config: Optional[WarmStartConfig] = None,
         strategy_config: Optional[StrategyConfig] = None,
+        completion_criteria_config: Optional[TuningJobCompletionCriteriaConfig] = None,
         early_stopping_type: Union[str, PipelineVariable] = "Off",
         estimator_name: Optional[str] = None,
         random_seed: Optional[int] = None,
@@ -450,6 +561,8 @@ def __init__(
                 strategy and the default value is 1 for all other strategies (default: None).
             max_parallel_jobs (int or PipelineVariable): Maximum number of parallel training jobs to
                 start (default: 1).
+            max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
             tags (list[dict[str, str]] or list[dict[str, PipelineVariable]]): List of tags for
                 labeling the tuning job (default: None). For more, see
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
@@ -463,6 +576,8 @@ def __init__(
                 configuration defining the nature of warm start tuning job.
             strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter"
                 tuning job optimisation strategy.
+            completion_criteria_config (sagemaker.tuner.TuningJobCompletionCriteriaConfig): A
+                configuration for the completion criteria.
             early_stopping_type (str or PipelineVariable): Specifies whether early stopping is
                 enabled for the job. Can be either 'Auto' or 'Off' (default:
                 'Off'). If set to 'Off', early stopping will not be attempted.
@@ -505,6 +620,7 @@ def __init__(
 
         self.strategy = strategy
         self.strategy_config = strategy_config
+        self.completion_criteria_config = completion_criteria_config
         self.objective_type = objective_type
         # For the GridSearch strategy we expect max_jobs to be None and recalculate it later.
         # For all other strategies for the backward compatibility we keep
@@ -513,6 +629,7 @@ def __init__(
         if max_jobs is None and strategy is not GRID_SEARCH:
             self.max_jobs = 1
         self.max_parallel_jobs = max_parallel_jobs
+        self.max_runtime_in_seconds = max_runtime_in_seconds
 
         self.tags = tags
         self.base_tuning_job_name = base_tuning_job_name
@@ -1227,6 +1344,9 @@ def _prepare_init_params_from_job_description(cls, job_details):
             "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]),
         }
 
+        if MAX_RUNTIME_IN_SECONDS in tuning_config["ResourceLimits"]:
+            params["max_runtime_in_seconds"] = tuning_config["ResourceLimits"][MAX_RUNTIME_IN_SECONDS]
+
         if "RandomSeed" in tuning_config:
             params["random_seed"] = tuning_config["RandomSeed"]
@@ -1484,9 +1604,11 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             hyperparameter_ranges=self._hyperparameter_ranges,
             strategy=self.strategy,
             strategy_config=self.strategy_config,
+            completion_criteria_config=self.completion_criteria_config,
             objective_type=self.objective_type,
             max_jobs=self.max_jobs,
             max_parallel_jobs=self.max_parallel_jobs,
+            max_runtime_in_seconds=self.max_runtime_in_seconds,
             warm_start_config=WarmStartConfig(
                 warm_start_type=warm_start_type, parents=all_parents
             ),
@@ -1512,9 +1634,11 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             metric_definitions_dict=self.metric_definitions_dict,
             strategy=self.strategy,
             strategy_config=self.strategy_config,
+            completion_criteria_config=self.completion_criteria_config,
             objective_type=self.objective_type,
             max_jobs=self.max_jobs,
             max_parallel_jobs=self.max_parallel_jobs,
+            max_runtime_in_seconds=self.max_runtime_in_seconds,
             warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, parents=all_parents),
             early_stopping_type=self.early_stopping_type,
             random_seed=self.random_seed,
@@ -1530,9 +1654,11 @@ def create(
         base_tuning_job_name=None,
         strategy="Bayesian",
         strategy_config=None,
+        completion_criteria_config=None,
         objective_type="Maximize",
         max_jobs=None,
         max_parallel_jobs=1,
+        max_runtime_in_seconds=None,
         tags=None,
         warm_start_config=None,
         early_stopping_type="Off",
@@ -1581,13 +1707,16 @@ def create(
                 (default: 'Bayesian').
             strategy_config (dict): The configuration for a training job launched by a
                 hyperparameter tuning job.
+            completion_criteria_config (dict): The configuration for tuning job completion criteria.
             objective_type (str): The type of the objective metric for evaluating training jobs.
                 This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
                 tuning job. The default value is unspecified for the GridSearch strategy
                 and the value is 1 for all other strategies (default: None).
             max_parallel_jobs (int): Maximum number of parallel training jobs to start
                 (default: 1).
+            max_runtime_in_seconds (int): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
             tags (list[dict]): List of tags for labeling the tuning job (default: None). For more,
                 see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
             warm_start_config (sagemaker.tuner.WarmStartConfig): A ``WarmStartConfig`` object that
@@ -1632,9 +1761,11 @@ def create(
             metric_definitions=metric_definitions,
             strategy=strategy,
             strategy_config=strategy_config,
+            completion_criteria_config=completion_criteria_config,
             objective_type=objective_type,
             max_jobs=max_jobs,
             max_parallel_jobs=max_parallel_jobs,
+            max_runtime_in_seconds=max_runtime_in_seconds,
             tags=tags,
             warm_start_config=warm_start_config,
             early_stopping_type=early_stopping_type,
@@ -1790,6 +1921,9 @@ def _get_tuner_args(cls, tuner, inputs):
             "early_stopping_type": tuner.early_stopping_type,
         }
 
+        if tuner.max_runtime_in_seconds is not None:
+            tuning_config["max_runtime_in_seconds"] = tuner.max_runtime_in_seconds
+
        if tuner.random_seed is not None:
            tuning_config["random_seed"] = tuner.random_seed
 
@@ -1804,6 +1938,9 @@ def _get_tuner_args(cls, tuner, inputs):
         if parameter_ranges is not None:
             tuning_config["parameter_ranges"] = parameter_ranges
 
+        if tuner.completion_criteria_config is not None:
+            tuning_config["completion_criteria_config"] = tuner.completion_criteria_config.to_input_req()
+
         tuner_args = {
             "job_name": tuner._current_job_name,
             "tuning_config": tuning_config,

tests/unit/test_tuner.py

Lines changed: 5 additions & 0 deletions
@@ -543,12 +543,17 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session
     assert tuner.objective_metric_name == OBJECTIVE_METRIC_NAME
     assert tuner.max_jobs == 1
     assert tuner.max_parallel_jobs == 1
+    assert tuner.max_runtime_in_seconds == 1
     assert tuner.metric_definitions == METRIC_DEFINITIONS
     assert tuner.strategy == "Bayesian"
     assert tuner.objective_type == "Minimize"
     assert tuner.early_stopping_type == "Off"
     assert tuner.random_seed == 0
 
+    assert tuner.completion_criteria_config.complete_on_convergence == True
+    assert tuner.completion_criteria_config.target_objective_metric_value == 0.42
+    assert tuner.completion_criteria_config.max_number_of_training_jobs_not_improving == 5
+
     assert isinstance(tuner.estimator, PCA)
     assert tuner.estimator.role == ROLE
     assert tuner.estimator.instance_count == 1