Skip to content

Commit bb85268

Browse files
repushkoAnton Repushko
authored andcommitted
feature: support of the intelligent stopping in the tuner (aws#3652)
Co-authored-by: Anton Repushko <[email protected]>
1 parent fb59669 commit bb85268

File tree

4 files changed

+190
-4
lines changed

4 files changed

+190
-4
lines changed

src/sagemaker/session.py

+20
Original file line numberDiff line numberDiff line change
@@ -2189,7 +2189,9 @@ def tune( # noqa: C901
21892189
stop_condition,
21902190
tags,
21912191
warm_start_config,
2192+
max_runtime_in_seconds=None,
21922193
strategy_config=None,
2194+
completion_criteria_config=None,
21932195
enable_network_isolation=False,
21942196
image_uri=None,
21952197
algorithm_arn=None,
@@ -2256,6 +2258,10 @@ def tune( # noqa: C901
22562258
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
22572259
warm_start_config (dict): Configuration defining the type of warm start and
22582260
other required configurations.
2261+
max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
2262+
that a training job launched by a hyperparameter tuning job can run.
2263+
completion_criteria_config (sagemaker.tuner.TuningJobCompletionCriteriaConfig): A
2264+
configuration for the completion criteria.
22592265
early_stopping_type (str): Specifies whether early stopping is enabled for the job.
22602266
Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
22612267
attempted. If set to 'Auto', early stopping of some training jobs may happen, but
@@ -2311,12 +2317,14 @@ def tune( # noqa: C901
23112317
strategy=strategy,
23122318
max_jobs=max_jobs,
23132319
max_parallel_jobs=max_parallel_jobs,
2320+
max_runtime_in_seconds=max_runtime_in_seconds,
23142321
objective_type=objective_type,
23152322
objective_metric_name=objective_metric_name,
23162323
parameter_ranges=parameter_ranges,
23172324
early_stopping_type=early_stopping_type,
23182325
random_seed=random_seed,
23192326
strategy_config=strategy_config,
2327+
completion_criteria_config=completion_criteria_config,
23202328
),
23212329
"TrainingJobDefinition": self._map_training_config(
23222330
static_hyperparameters=static_hyperparameters,
@@ -2470,12 +2478,14 @@ def _map_tuning_config(
24702478
strategy,
24712479
max_jobs,
24722480
max_parallel_jobs,
2481+
max_runtime_in_seconds=None,
24732482
early_stopping_type="Off",
24742483
objective_type=None,
24752484
objective_metric_name=None,
24762485
parameter_ranges=None,
24772486
random_seed=None,
24782487
strategy_config=None,
2488+
completion_criteria_config=None,
24792489
):
24802490
"""Construct tuning job configuration dictionary.
24812491
@@ -2484,6 +2494,8 @@ def _map_tuning_config(
24842494
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
24852495
tuning job.
24862496
max_parallel_jobs (int): Maximum number of parallel training jobs to start.
2497+
max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
2498+
that a training job launched by a hyperparameter tuning job can run.
24872499
early_stopping_type (str): Specifies whether early stopping is enabled for the job.
24882500
Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
24892501
attempted. If set to 'Auto', early stopping of some training jobs may happen,
@@ -2498,6 +2510,8 @@ def _map_tuning_config(
24982510
produce more consistent configurations for the same tuning job.
24992511
strategy_config (dict): A configuration for the hyperparameter tuning job optimisation
25002512
strategy.
2513+
completion_criteria_config (dict): A configuration
2514+
for the completion criteria.
25012515
25022516
Returns:
25032517
A dictionary of tuning job configuration. For format details, please refer to
@@ -2514,6 +2528,9 @@ def _map_tuning_config(
25142528
"TrainingJobEarlyStoppingType": early_stopping_type,
25152529
}
25162530

2531+
if max_runtime_in_seconds is not None:
2532+
tuning_config["ResourceLimits"]["MaxRuntimeInSeconds"] = max_runtime_in_seconds
2533+
25172534
if random_seed is not None:
25182535
tuning_config["RandomSeed"] = random_seed
25192536

@@ -2526,6 +2543,9 @@ def _map_tuning_config(
25262543

25272544
if strategy_config is not None:
25282545
tuning_config["StrategyConfig"] = strategy_config
2546+
2547+
if completion_criteria_config is not None:
2548+
tuning_config["TuningJobCompletionCriteria"] = completion_criteria_config
25292549
return tuning_config
25302550

25312551
@classmethod

0 commit comments

Comments
 (0)