
Commit cca24a5

Anton Repushko committed
feature: add support of the intelligent stopping in the tuner
1 parent d2d377f commit cca24a5

4 files changed: +147 -70 lines

src/sagemaker/session.py

Lines changed: 20 additions & 0 deletions
@@ -2189,7 +2189,9 @@ def tune(  # noqa: C901
         stop_condition,
         tags,
         warm_start_config,
+        max_runtime_in_seconds=None,
         strategy_config=None,
+        completion_criteria_config=None,
         enable_network_isolation=False,
         image_uri=None,
         algorithm_arn=None,
@@ -2256,6 +2258,10 @@ def tune(  # noqa: C901
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
             warm_start_config (dict): Configuration defining the type of warm start and
                 other required configurations.
+            max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
+            completion_criteria_config (sagemaker.tuner.TuningJobCompletionCriteriaConfig): A
+                configuration for the completion criteria.
             early_stopping_type (str): Specifies whether early stopping is enabled for the job.
                 Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
                 attempted. If set to 'Auto', early stopping of some training jobs may happen, but
@@ -2311,12 +2317,14 @@ def tune(  # noqa: C901
                 strategy=strategy,
                 max_jobs=max_jobs,
                 max_parallel_jobs=max_parallel_jobs,
+                max_runtime_in_seconds=max_runtime_in_seconds,
                 objective_type=objective_type,
                 objective_metric_name=objective_metric_name,
                 parameter_ranges=parameter_ranges,
                 early_stopping_type=early_stopping_type,
                 random_seed=random_seed,
                 strategy_config=strategy_config,
+                completion_criteria_config=completion_criteria_config,
             ),
             "TrainingJobDefinition": self._map_training_config(
                 static_hyperparameters=static_hyperparameters,
@@ -2470,12 +2478,14 @@ def _map_tuning_config(
         strategy,
         max_jobs,
         max_parallel_jobs,
+        max_runtime_in_seconds=None,
         early_stopping_type="Off",
         objective_type=None,
         objective_metric_name=None,
         parameter_ranges=None,
         random_seed=None,
         strategy_config=None,
+        completion_criteria_config=None,
     ):
         """Construct tuning job configuration dictionary.

@@ -2484,6 +2494,8 @@ def _map_tuning_config(
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
                 tuning job.
             max_parallel_jobs (int): Maximum number of parallel training jobs to start.
+            max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
+                that a training job launched by a hyperparameter tuning job can run.
             early_stopping_type (str): Specifies whether early stopping is enabled for the job.
                 Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
                 attempted. If set to 'Auto', early stopping of some training jobs may happen,
@@ -2498,6 +2510,8 @@ def _map_tuning_config(
                 produce more consistent configurations for the same tuning job.
             strategy_config (dict): A configuration for the hyperparameter tuning job optimisation
                 strategy.
+            completion_criteria_config (dict): A configuration
+                for the completion criteria.

         Returns:
             A dictionary of tuning job configuration. For format details, please refer to
@@ -2514,6 +2528,9 @@ def _map_tuning_config(
             "TrainingJobEarlyStoppingType": early_stopping_type,
         }

+        if max_runtime_in_seconds is not None:
+            tuning_config["ResourceLimits"]["MaxRuntimeInSeconds"] = max_runtime_in_seconds
+
         if random_seed is not None:
             tuning_config["RandomSeed"] = random_seed

@@ -2526,6 +2543,9 @@ def _map_tuning_config(

         if strategy_config is not None:
             tuning_config["StrategyConfig"] = strategy_config
+
+        if completion_criteria_config is not None:
+            tuning_config["TuningJobCompletionCriteria"] = completion_criteria_config
         return tuning_config

     @classmethod
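
For orientation (not part of the diff): a minimal sketch of the HyperParameterTuningJobConfig fragment that _map_tuning_config emits when both new arguments are supplied. The numeric values are hypothetical; the key names follow the assignments in the hunks above.

# Hypothetical values; the shape follows _map_tuning_config above.
tuning_config = {
    "Strategy": "Bayesian",
    "ResourceLimits": {
        "MaxNumberOfTrainingJobs": 100,
        "MaxParallelTrainingJobs": 10,
        "MaxRuntimeInSeconds": 3600,  # set only when max_runtime_in_seconds is passed
    },
    "TrainingJobEarlyStoppingType": "Off",
    # set only when completion_criteria_config (already in request format) is passed
    "TuningJobCompletionCriteria": {
        "BestObjectiveNotImproving": {"MaxNumberOfTrainingJobsNotImproving": 5},
        "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"},
        "TargetObjectiveMetricValue": 0.42,
    },
}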

src/sagemaker/tuner.py

Lines changed: 112 additions & 69 deletions
@@ -460,6 +460,116 @@ def to_input_req(self):
         }


+class TuningJobCompletionCriteriaConfig(object):
+    """The configuration for a job completion criteria."""
+
+    def __init__(
+        self,
+        max_number_of_training_jobs_not_improving: int = None,
+        complete_on_convergence: bool = None,
+        target_objective_metric_value: float = None,
+    ):
+        """Creates a ``TuningJobCompletionCriteriaConfig`` with provided criteria.
+
+        Args:
+            max_number_of_training_jobs_not_improving (int): The number of training jobs that do
+                not improve the best objective after which the tuning job will stop.
+            complete_on_convergence (bool): A flag to stop your hyperparameter tuning job if
+                automatic model tuning (AMT) has detected that your model has converged as
+                evaluated against your objective function.
+            target_objective_metric_value (float): The value of the objective metric.
+        """
+
+        self.max_number_of_training_jobs_not_improving = max_number_of_training_jobs_not_improving
+        self.complete_on_convergence = complete_on_convergence
+        self.target_objective_metric_value = target_objective_metric_value
+
+    @classmethod
+    def from_job_desc(cls, completion_criteria_config):
+        """Creates a ``TuningJobCompletionCriteriaConfig`` from a configuration response.
+
+        This is the completion criteria configuration from the DescribeTuningJob response.
+
+        Args:
+            completion_criteria_config (dict): The expected format of the
+                ``completion_criteria_config`` contains three first-class fields.
+
+        Returns:
+            sagemaker.tuner.TuningJobCompletionCriteriaConfig: De-serialized instance of
+            TuningJobCompletionCriteriaConfig containing the completion criteria.
+        """
+        complete_on_convergence = None
+        if CONVERGENCE_DETECTED in completion_criteria_config:
+            if completion_criteria_config[CONVERGENCE_DETECTED][COMPLETE_ON_CONVERGENCE_DETECTED]:
+                complete_on_convergence = bool(
+                    completion_criteria_config[CONVERGENCE_DETECTED][
+                        COMPLETE_ON_CONVERGENCE_DETECTED
+                    ]
+                    == "Enabled"
+                )
+
+        max_number_of_training_jobs_not_improving = None
+        if BEST_OBJECTIVE_NOT_IMPROVING in completion_criteria_config:
+            if completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING][
+                MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING
+            ]:
+                max_number_of_training_jobs_not_improving = completion_criteria_config[
+                    BEST_OBJECTIVE_NOT_IMPROVING
+                ][MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING]
+
+        target_objective_metric_value = None
+        if TARGET_OBJECTIVE_METRIC_VALUE in completion_criteria_config:
+            target_objective_metric_value = completion_criteria_config[
+                TARGET_OBJECTIVE_METRIC_VALUE
+            ]
+
+        return cls(
+            max_number_of_training_jobs_not_improving=max_number_of_training_jobs_not_improving,
+            complete_on_convergence=complete_on_convergence,
+            target_objective_metric_value=target_objective_metric_value,
+        )
+
+    def to_input_req(self):
+        """Converts the ``self`` instance to the desired input request format.
+
+        Examples:
+            >>> completion_criteria_config = TuningJobCompletionCriteriaConfig(
+            ...     max_number_of_training_jobs_not_improving=5,
+            ...     complete_on_convergence=True,
+            ...     target_objective_metric_value=0.42,
+            ... )
+            >>> completion_criteria_config.to_input_req()
+            {
+                "BestObjectiveNotImproving": {
+                    "MaxNumberOfTrainingJobsNotImproving": 5
+                },
+                "ConvergenceDetected": {
+                    "CompleteOnConvergence": "Enabled",
+                },
+                "TargetObjectiveMetricValue": 0.42
+            }
+
+        Returns:
+            dict: Containing the completion criteria configurations.
+        """
+        completion_criteria_config = {}
+        if self.max_number_of_training_jobs_not_improving is not None:
+            completion_criteria_config.setdefault(BEST_OBJECTIVE_NOT_IMPROVING, {})[
+                MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING
+            ] = self.max_number_of_training_jobs_not_improving
+
+        if self.target_objective_metric_value is not None:
+            completion_criteria_config[
+                TARGET_OBJECTIVE_METRIC_VALUE
+            ] = self.target_objective_metric_value
+
+        if self.complete_on_convergence is not None:
+            completion_criteria_config.setdefault(CONVERGENCE_DETECTED, {})[
+                COMPLETE_ON_CONVERGENCE_DETECTED
+            ] = "Enabled" if self.complete_on_convergence else "Disabled"
+
+        return completion_criteria_config
+
+
 class HyperparameterTuner(object):
     """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.

@@ -559,14 +669,14 @@ def __init__(
             self.estimator = None
             self.objective_metric_name = None
             self._hyperparameter_ranges = None
-            self.static_hyperparameters = None
             self.metric_definitions = None
             self.estimator_dict = {estimator_name: estimator}
             self.objective_metric_name_dict = {estimator_name: objective_metric_name}
             self._hyperparameter_ranges_dict = {estimator_name: hyperparameter_ranges}
             self.metric_definitions_dict = (
                 {estimator_name: metric_definitions} if metric_definitions is not None else {}
             )
+            self.static_hyperparameters = None
         else:
             self.estimator = estimator
             self.objective_metric_name = objective_metric_name
@@ -598,31 +708,6 @@ def __init__(
         self.warm_start_config = warm_start_config
         self.early_stopping_type = early_stopping_type
         self.random_seed = random_seed
-        self.instance_configs_dict = None
-        self.instance_configs = None
-
-    def override_resource_config(
-        self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]
-    ):
-        """Override the instance configuration of the estimators used by the tuner.
-
-        Args:
-            instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]):
-                The InstanceConfigs to use as an override for the instance configuration
-                of the estimator. ``None`` will remove the override.
-        """
-        if isinstance(instance_configs, dict):
-            self._validate_dict_argument(
-                name="instance_configs",
-                value=instance_configs,
-                allowed_keys=list(self.estimator_dict.keys()),
-            )
-            self.instance_configs_dict = instance_configs
-        else:
-            self.instance_configs = instance_configs
-            if self.estimator_dict is not None and self.estimator_dict.keys():
-                estimator_names = list(self.estimator_dict.keys())
-                self.instance_configs_dict = {estimator_names[0]: instance_configs}

     def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
         """Prepare the tuner instance for tuning (fit)."""
@@ -691,6 +776,7 @@ def _prepare_job_name_for_tuning(self, job_name=None):

     def _prepare_static_hyperparameters_for_tuning(self, include_cls_metadata=False):
         """Prepare static hyperparameters for all estimators before tuning."""
+        self.static_hyperparameters = None
         if self.estimator is not None:
             self.static_hyperparameters = self._prepare_static_hyperparameters(
                 self.estimator, self._hyperparameter_ranges, include_cls_metadata
@@ -1918,7 +2004,6 @@ def _get_tuner_args(cls, tuner, inputs):
                 estimator=tuner.estimator,
                 static_hyperparameters=tuner.static_hyperparameters,
                 metric_definitions=tuner.metric_definitions,
-                instance_configs=tuner.instance_configs,
             )

         if tuner.estimator_dict is not None:
@@ -1932,44 +2017,12 @@ def _get_tuner_args(cls, tuner, inputs):
                     tuner.objective_type,
                     tuner.objective_metric_name_dict[estimator_name],
                     tuner.hyperparameter_ranges_dict()[estimator_name],
-                    tuner.instance_configs_dict.get(estimator_name, None)
-                    if tuner.instance_configs_dict is not None
-                    else None,
                 )
                 for estimator_name in sorted(tuner.estimator_dict.keys())
             ]

         return tuner_args

-    @staticmethod
-    def _prepare_hp_resource_config(
-        instance_configs: List[InstanceConfig],
-        instance_count: int,
-        instance_type: str,
-        volume_size: int,
-        volume_kms_key: str,
-    ):
-        """Placeholder hpo resource config for one estimator of the tuner."""
-        resource_config = {}
-        if volume_kms_key is not None:
-            resource_config["VolumeKmsKeyId"] = volume_kms_key
-
-        if instance_configs is None:
-            resource_config["InstanceCount"] = instance_count
-            resource_config["InstanceType"] = instance_type
-            resource_config["VolumeSizeInGB"] = volume_size
-        else:
-            resource_config["InstanceConfigs"] = _TuningJob._prepare_instance_configs(
-                instance_configs
-            )
-
-        return resource_config
-
-    @staticmethod
-    def _prepare_instance_configs(instance_configs: List[InstanceConfig]):
-        """Prepare instance config for create tuning request."""
-        return [config.to_input_req() for config in instance_configs]
-
     @staticmethod
     def _prepare_training_config(
         inputs,
@@ -1980,20 +2033,10 @@ def _prepare_training_config(
         objective_type=None,
         objective_metric_name=None,
         parameter_ranges=None,
-        instance_configs=None,
     ):
         """Prepare training config for one estimator."""
         training_config = _Job._load_config(inputs, estimator)

-        del training_config["resource_config"]
-        training_config["hpo_resource_config"] = _TuningJob._prepare_hp_resource_config(
-            instance_configs,
-            estimator.instance_count,
-            estimator.instance_type,
-            estimator.volume_size,
-            estimator.volume_kms_key,
-        )
-
         training_config["input_mode"] = estimator.input_mode
         training_config["metric_definitions"] = metric_definitions
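
As a usage sketch for the class added above (illustrative, not part of the commit): building a completion-criteria config and converting it to the request format. The uppercase key constants referenced by the class (BEST_OBJECTIVE_NOT_IMPROVING and friends) are assumed to be module-level names defined alongside it.

from sagemaker.tuner import TuningJobCompletionCriteriaConfig

# Stop tuning after 5 jobs that fail to improve the best objective, when AMT
# detects convergence, or once the objective metric reaches 0.42.
criteria = TuningJobCompletionCriteriaConfig(
    max_number_of_training_jobs_not_improving=5,
    complete_on_convergence=True,
    target_objective_metric_value=0.42,
)

# to_input_req() yields the nested fragment shown in the docstring example:
# {"BestObjectiveNotImproving": {"MaxNumberOfTrainingJobsNotImproving": 5},
#  "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"},
#  "TargetObjectiveMetricValue": 0.42}
print(criteria.to_input_req())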

tests/unit/test_tuner.py

Lines changed: 5 additions & 0 deletions
@@ -663,12 +663,17 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session
     assert tuner.objective_metric_name == OBJECTIVE_METRIC_NAME
     assert tuner.max_jobs == 1
     assert tuner.max_parallel_jobs == 1
+    assert tuner.max_runtime_in_seconds == 1
     assert tuner.metric_definitions == METRIC_DEFINITIONS
     assert tuner.strategy == "Bayesian"
     assert tuner.objective_type == "Minimize"
     assert tuner.early_stopping_type == "Off"
     assert tuner.random_seed == 0

+    assert tuner.completion_criteria_config.complete_on_convergence is True
+    assert tuner.completion_criteria_config.target_objective_metric_value == 0.42
+    assert tuner.completion_criteria_config.max_number_of_training_jobs_not_improving == 5
+
     assert isinstance(tuner.estimator, PCA)
     assert tuner.estimator.role == ROLE
     assert tuner.estimator.instance_count == 1
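
The assertions above run against a tuner rebuilt from a described tuning job. A hedged sketch of that flow; the job name is a placeholder, and exactly how attach() maps the describe response onto the new attributes is assumed here rather than shown in this diff.

from sagemaker.tuner import HyperparameterTuner

# "my-tuning-job" is a placeholder for an existing tuning job name.
tuner = HyperparameterTuner.attach("my-tuning-job")

# With this change, ResourceLimits.MaxRuntimeInSeconds and
# TuningJobCompletionCriteria from the describe response are expected to
# surface as tuner attributes, as the test asserts.
print(tuner.max_runtime_in_seconds)
print(tuner.completion_criteria_config.max_number_of_training_jobs_not_improving)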

tests/unit/tuner_test_utils.py

Lines changed: 10 additions & 1 deletion
@@ -97,7 +97,11 @@

 TUNING_JOB_DETAILS = {
     "HyperParameterTuningJobConfig": {
-        "ResourceLimits": {"MaxParallelTrainingJobs": 1, "MaxNumberOfTrainingJobs": 1},
+        "ResourceLimits": {
+            "MaxParallelTrainingJobs": 1,
+            "MaxNumberOfTrainingJobs": 1,
+            "MaxRuntimeInSeconds": 1,
+        },
         "HyperParameterTuningJobObjective": {
             "MetricName": OBJECTIVE_METRIC_NAME,
             "Type": "Minimize",
@@ -117,6 +121,11 @@
         },
         "TrainingJobEarlyStoppingType": "Off",
         "RandomSeed": 0,
+        "TuningJobCompletionCriteria": {
+            "BestObjectiveNotImproving": {"MaxNumberOfTrainingJobsNotImproving": 5},
+            "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"},
+            "TargetObjectiveMetricValue": 0.42,
+        },
     },
     "HyperParameterTuningJobName": JOB_NAME,
     "TrainingJobDefinition": {

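To connect the fixture to the new class: a small sketch (assuming the class's key constants resolve to the request keys used in the fixture) that deserializes the TuningJobCompletionCriteria block above with from_job_desc.

from sagemaker.tuner import TuningJobCompletionCriteriaConfig

# Mirrors the TuningJobCompletionCriteria block in TUNING_JOB_DETAILS above.
describe_fragment = {
    "BestObjectiveNotImproving": {"MaxNumberOfTrainingJobsNotImproving": 5},
    "ConvergenceDetected": {"CompleteOnConvergence": "Enabled"},
    "TargetObjectiveMetricValue": 0.42,
}

config = TuningJobCompletionCriteriaConfig.from_job_desc(describe_fragment)
assert config.max_number_of_training_jobs_not_improving == 5
assert config.complete_on_convergence is True
assert config.target_objective_metric_value == 0.42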