Skip to content

fix: Hyperband strategy support for HPO #3516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,6 +2121,7 @@ def tune( # noqa: C901
stop_condition,
tags,
warm_start_config,
strategy_config=None,
enable_network_isolation=False,
image_uri=None,
algorithm_arn=None,
Expand All @@ -2136,6 +2137,8 @@ def tune( # noqa: C901
Args:
job_name (str): Name of the tuning job being created.
strategy (str): Strategy to be used for hyperparameter estimations.
strategy_config (dict): A configuration for the hyperparameter tuning
job optimisation strategy.
objective_type (str): The type of the objective metric for evaluating training jobs.
This value can be either 'Minimize' or 'Maximize'.
objective_metric_name (str): Name of the metric for evaluating training jobs.
Expand Down Expand Up @@ -2220,6 +2223,7 @@ def tune( # noqa: C901
objective_metric_name=objective_metric_name,
parameter_ranges=parameter_ranges,
early_stopping_type=early_stopping_type,
strategy_config=strategy_config,
),
"TrainingJobDefinition": self._map_training_config(
static_hyperparameters=static_hyperparameters,
Expand Down Expand Up @@ -2375,6 +2379,7 @@ def _map_tuning_config(
objective_type=None,
objective_metric_name=None,
parameter_ranges=None,
strategy_config=None,
):
"""Construct tuning job configuration dictionary.

Expand All @@ -2392,6 +2397,8 @@ def _map_tuning_config(
objective_metric_name (str): Name of the metric for evaluating training jobs.
parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
be one of three types: Continuous, Integer, or Categorical.
strategy_config (dict): A configuration for the hyperparameter tuning job optimisation
strategy.

Returns:
A dictionary of tuning job configuration. For format details, please refer to
Expand All @@ -2415,6 +2422,8 @@ def _map_tuning_config(
if parameter_ranges is not None:
tuning_config["ParameterRanges"] = parameter_ranges

if strategy_config is not None:
tuning_config["StrategyConfig"] = strategy_config
return tuning_config

@classmethod
Expand Down
14 changes: 7 additions & 7 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ def from_job_desc(cls, hyperband_strategy_config):

Returns:
sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of
HyperbandStrategyConfig containing the max_resource and min_resource provided as part of
``hyperband_strategy_config``.
``HyperbandStrategyConfig`` containing the max_resource
and min_resource provided as part of ``hyperband_strategy_config``.
"""
return cls(
min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE],
Expand All @@ -306,7 +306,7 @@ def to_input_req(self):

Returns:
dict: Containing the "MaxResource" and
"MinResource" as the first class fields.
"MinResource" as the first class fields.
"""
return {
HYPERBAND_MIN_RESOURCE: self.min_resource,
Expand All @@ -330,7 +330,7 @@ def __init__(

Args:
hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration
for the object that specifies the Hyperband strategy.
for the object that specifies the Hyperband strategy.
This parameter is only supported for the Hyperband selection for Strategy within
the HyperParameterTuningJobConfig.
"""
Expand Down Expand Up @@ -461,7 +461,7 @@ def __init__(
``WarmStartConfig`` object that has been initialized with the
configuration defining the nature of warm start tuning job.
strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter"
tuning job optimisation strategy.
tuning job optimisation strategy.
early_stopping_type (str or PipelineVariable): Specifies whether early stopping is
enabled for the job. Can be either 'Auto' or 'Off' (default:
'Off'). If set to 'Off', early stopping will not be attempted.
Expand Down Expand Up @@ -1569,7 +1569,7 @@ def create(
strategy (str): Strategy to be used for hyperparameter estimations
(default: 'Bayesian').
strategy_config (dict): The configuration for a training job launched by a
hyperparameter tuning job.
hyperparameter tuning job.
objective_type (str): The type of the objective metric for evaluating training jobs.
This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
Expand Down Expand Up @@ -1776,7 +1776,7 @@ def _get_tuner_args(cls, tuner, inputs):
}

if tuner.strategy_config is not None:
tuning_config["strategy_config"] = tuner.strategy_config
tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()

if tuner.objective_metric_name is not None:
tuning_config["objective_type"] = tuner.objective_type
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,13 @@ def test_train_pack_to_request(sagemaker_session):
],
}

# Request-shaped StrategyConfig payload for the Hyperband tuning strategy, as
# passed to Session.tune(strategy_config=...). "MinResource"/"MaxResource"
# bound the resource (e.g. epochs) a training job may consume before Hyperband
# can stop it early.
SAMPLE_HYPERBAND_STRATEGY_CONFIG = {
    "HyperbandStrategyConfig": {
        "MinResource": 1,
        "MaxResource": 10,
    }
}


@pytest.mark.parametrize(
"warm_start_type, parents",
Expand Down Expand Up @@ -1167,6 +1174,47 @@ def assert_create_tuning_job_request(**kwrags):
)


def test_tune_with_strategy_config(sagemaker_session):
    """Verify Session.tune() forwards strategy_config into the CreateHyperParameterTuningJob
    request as HyperParameterTuningJobConfig.StrategyConfig.

    The mocked SageMaker client's create_hyper_parameter_tuning_job is given a
    side effect that inspects the request kwargs, so the assertions run at the
    moment the request is issued.
    """

    def assert_create_tuning_job_request(**kwargs):
        # Fixed typo: parameter was previously misspelled "kwrags".
        hyperband_config = kwargs["HyperParameterTuningJobConfig"]["StrategyConfig"][
            "HyperbandStrategyConfig"
        ]
        expected = SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]
        assert hyperband_config["MinResource"] == expected["MinResource"]
        assert hyperband_config["MaxResource"] == expected["MaxResource"]

    sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = (
        assert_create_tuning_job_request
    )
    sagemaker_session.tune(
        job_name="dummy-tuning-1",
        strategy="Bayesian",
        objective_type="Maximize",
        objective_metric_name="val-score",
        max_jobs=100,
        max_parallel_jobs=5,
        parameter_ranges=SAMPLE_PARAM_RANGES,
        static_hyperparameters=STATIC_HPs,
        image_uri="dummy-image-1",
        input_mode="File",
        metric_definitions=SAMPLE_METRIC_DEF,
        role=EXPANDED_ROLE,
        input_config=SAMPLE_INPUT,
        output_config=SAMPLE_OUTPUT,
        resource_config=RESOURCE_CONFIG,
        stop_condition=SAMPLE_STOPPING_CONDITION,
        tags=None,
        warm_start_config=None,
        strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG,
    )


def test_tune_with_encryption_flag(sagemaker_session):
def assert_create_tuning_job_request(**kwrags):
assert (
Expand Down