Skip to content

feature: support of the intelligent stopping in the tuner #3652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2189,7 +2189,9 @@ def tune( # noqa: C901
stop_condition,
tags,
warm_start_config,
max_runtime_in_seconds=None,
strategy_config=None,
completion_criteria_config=None,
enable_network_isolation=False,
image_uri=None,
algorithm_arn=None,
Expand Down Expand Up @@ -2256,6 +2258,10 @@ def tune( # noqa: C901
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
warm_start_config (dict): Configuration defining the type of warm start and
other required configurations.
max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
that the hyperparameter tuning job as a whole can run.
completion_criteria_config (dict): A configuration for the completion criteria,
already in request format (it is placed into the tuning job config unchanged).
early_stopping_type (str): Specifies whether early stopping is enabled for the job.
Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
attempted. If set to 'Auto', early stopping of some training jobs may happen, but
Expand Down Expand Up @@ -2311,12 +2317,14 @@ def tune( # noqa: C901
strategy=strategy,
max_jobs=max_jobs,
max_parallel_jobs=max_parallel_jobs,
max_runtime_in_seconds=max_runtime_in_seconds,
objective_type=objective_type,
objective_metric_name=objective_metric_name,
parameter_ranges=parameter_ranges,
early_stopping_type=early_stopping_type,
random_seed=random_seed,
strategy_config=strategy_config,
completion_criteria_config=completion_criteria_config,
),
"TrainingJobDefinition": self._map_training_config(
static_hyperparameters=static_hyperparameters,
Expand Down Expand Up @@ -2470,12 +2478,14 @@ def _map_tuning_config(
strategy,
max_jobs,
max_parallel_jobs,
max_runtime_in_seconds=None,
early_stopping_type="Off",
objective_type=None,
objective_metric_name=None,
parameter_ranges=None,
random_seed=None,
strategy_config=None,
completion_criteria_config=None,
):
"""Construct tuning job configuration dictionary.

Expand All @@ -2484,6 +2494,8 @@ def _map_tuning_config(
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
tuning job.
max_parallel_jobs (int): Maximum number of parallel training jobs to start.
max_runtime_in_seconds (int or PipelineVariable): The maximum time in seconds
that the hyperparameter tuning job as a whole can run (set in ``ResourceLimits``).
early_stopping_type (str): Specifies whether early stopping is enabled for the job.
Can be either 'Auto' or 'Off'. If set to 'Off', early stopping will not be
attempted. If set to 'Auto', early stopping of some training jobs may happen,
Expand All @@ -2498,6 +2510,8 @@ def _map_tuning_config(
produce more consistent configurations for the same tuning job.
strategy_config (dict): A configuration for the hyperparameter tuning job optimisation
strategy.
completion_criteria_config (dict): A configuration
for the completion criteria.

Returns:
A dictionary of tuning job configuration. For format details, please refer to
Expand All @@ -2514,6 +2528,9 @@ def _map_tuning_config(
"TrainingJobEarlyStoppingType": early_stopping_type,
}

if max_runtime_in_seconds is not None:
tuning_config["ResourceLimits"]["MaxRuntimeInSeconds"] = max_runtime_in_seconds

if random_seed is not None:
tuning_config["RandomSeed"] = random_seed

Expand All @@ -2526,6 +2543,9 @@ def _map_tuning_config(

if strategy_config is not None:
tuning_config["StrategyConfig"] = strategy_config

if completion_criteria_config is not None:
tuning_config["TuningJobCompletionCriteria"] = completion_criteria_config
return tuning_config

@classmethod
Expand Down
Loading