Skip to content

Commit 2d35567

Browse files
repushkoAnton Repushko
authored andcommitted
fix: the Hyperband support fix for the HPO (aws#3516)
Co-authored-by: Anton Repushko <[email protected]>
1 parent ceafb18 commit 2d35567

File tree

3 files changed

+64
-7
lines changed

3 files changed

+64
-7
lines changed

src/sagemaker/session.py

+9
Original file line numberDiff line numberDiff line change
@@ -2121,6 +2121,7 @@ def tune( # noqa: C901
21212121
stop_condition,
21222122
tags,
21232123
warm_start_config,
2124+
strategy_config=None,
21242125
enable_network_isolation=False,
21252126
image_uri=None,
21262127
algorithm_arn=None,
@@ -2136,6 +2137,8 @@ def tune( # noqa: C901
21362137
Args:
21372138
job_name (str): Name of the tuning job being created.
21382139
strategy (str): Strategy to be used for hyperparameter estimations.
2140+
strategy_config (dict): A configuration for the hyperparameter tuning
2141+
job optimisation strategy.
21392142
objective_type (str): The type of the objective metric for evaluating training jobs.
21402143
This value can be either 'Minimize' or 'Maximize'.
21412144
objective_metric_name (str): Name of the metric for evaluating training jobs.
@@ -2220,6 +2223,7 @@ def tune( # noqa: C901
22202223
objective_metric_name=objective_metric_name,
22212224
parameter_ranges=parameter_ranges,
22222225
early_stopping_type=early_stopping_type,
2226+
strategy_config=strategy_config,
22232227
),
22242228
"TrainingJobDefinition": self._map_training_config(
22252229
static_hyperparameters=static_hyperparameters,
@@ -2375,6 +2379,7 @@ def _map_tuning_config(
23752379
objective_type=None,
23762380
objective_metric_name=None,
23772381
parameter_ranges=None,
2382+
strategy_config=None,
23782383
):
23792384
"""Construct tuning job configuration dictionary.
23802385
@@ -2392,6 +2397,8 @@ def _map_tuning_config(
23922397
objective_metric_name (str): Name of the metric for evaluating training jobs.
23932398
parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
23942399
be one of three types: Continuous, Integer, or Categorical.
2400+
strategy_config (dict): A configuration for the hyperparameter tuning job optimisation
2401+
strategy.
23952402
23962403
Returns:
23972404
A dictionary of tuning job configuration. For format details, please refer to
@@ -2415,6 +2422,8 @@ def _map_tuning_config(
24152422
if parameter_ranges is not None:
24162423
tuning_config["ParameterRanges"] = parameter_ranges
24172424

2425+
if strategy_config is not None:
2426+
tuning_config["StrategyConfig"] = strategy_config
24182427
return tuning_config
24192428

24202429
@classmethod

src/sagemaker/tuner.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ def from_job_desc(cls, hyperband_strategy_config):
282282
283283
Returns:
284284
sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of
285-
HyperbandStrategyConfig containing the max_resource and min_resource provided as part of
286-
``hyperband_strategy_config``.
285+
``HyperbandStrategyConfig`` containing the max_resource
286+
and min_resource provided as part of ``hyperband_strategy_config``.
287287
"""
288288
return cls(
289289
min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE],
@@ -306,7 +306,7 @@ def to_input_req(self):
306306
307307
Returns:
308308
dict: Containing the "MaxResource" and
309-
"MinResource" as the first class fields.
309+
"MinResource" as the first class fields.
310310
"""
311311
return {
312312
HYPERBAND_MIN_RESOURCE: self.min_resource,
@@ -330,7 +330,7 @@ def __init__(
330330
331331
Args:
332332
hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration
333-
for the object that specifies the Hyperband strategy.
333+
for the object that specifies the Hyperband strategy.
334334
This parameter is only supported for the Hyperband selection for Strategy within
335335
the HyperParameterTuningJobConfig.
336336
"""
@@ -461,7 +461,7 @@ def __init__(
461461
``WarmStartConfig`` object that has been initialized with the
462462
configuration defining the nature of warm start tuning job.
463463
strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter"
464-
tuning job optimisation strategy.
464+
tuning job optimisation strategy.
465465
early_stopping_type (str or PipelineVariable): Specifies whether early stopping is
466466
enabled for the job. Can be either 'Auto' or 'Off' (default:
467467
'Off'). If set to 'Off', early stopping will not be attempted.
@@ -1569,7 +1569,7 @@ def create(
15691569
strategy (str): Strategy to be used for hyperparameter estimations
15701570
(default: 'Bayesian').
15711571
strategy_config (dict): The configuration for a training job launched by a
1572-
hyperparameter tuning job.
1572+
hyperparameter tuning job.
15731573
objective_type (str): The type of the objective metric for evaluating training jobs.
15741574
This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
15751575
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
@@ -1776,7 +1776,7 @@ def _get_tuner_args(cls, tuner, inputs):
17761776
}
17771777

17781778
if tuner.strategy_config is not None:
1779-
tuning_config["strategy_config"] = tuner.strategy_config
1779+
tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()
17801780

17811781
if tuner.objective_metric_name is not None:
17821782
tuning_config["objective_type"] = tuner.objective_type

tests/unit/test_session.py

+48
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,13 @@ def test_train_pack_to_request(sagemaker_session):
941941
],
942942
}
943943

944+
SAMPLE_HYPERBAND_STRATEGY_CONFIG = {
945+
"HyperbandStrategyConfig": {
946+
"MinResource": 1,
947+
"MaxResource": 10,
948+
}
949+
}
950+
944951

945952
@pytest.mark.parametrize(
946953
"warm_start_type, parents",
@@ -1167,6 +1174,47 @@ def assert_create_tuning_job_request(**kwrags):
11671174
)
11681175

11691176

1177+
def test_tune_with_strategy_config(sagemaker_session):
1178+
def assert_create_tuning_job_request(**kwrags):
1179+
assert (
1180+
kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][
1181+
"MinResource"
1182+
]
1183+
== SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MinResource"]
1184+
)
1185+
assert (
1186+
kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][
1187+
"MaxResource"
1188+
]
1189+
== SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MaxResource"]
1190+
)
1191+
1192+
sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = (
1193+
assert_create_tuning_job_request
1194+
)
1195+
sagemaker_session.tune(
1196+
job_name="dummy-tuning-1",
1197+
strategy="Bayesian",
1198+
objective_type="Maximize",
1199+
objective_metric_name="val-score",
1200+
max_jobs=100,
1201+
max_parallel_jobs=5,
1202+
parameter_ranges=SAMPLE_PARAM_RANGES,
1203+
static_hyperparameters=STATIC_HPs,
1204+
image_uri="dummy-image-1",
1205+
input_mode="File",
1206+
metric_definitions=SAMPLE_METRIC_DEF,
1207+
role=EXPANDED_ROLE,
1208+
input_config=SAMPLE_INPUT,
1209+
output_config=SAMPLE_OUTPUT,
1210+
resource_config=RESOURCE_CONFIG,
1211+
stop_condition=SAMPLE_STOPPING_CONDITION,
1212+
tags=None,
1213+
warm_start_config=None,
1214+
strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG,
1215+
)
1216+
1217+
11701218
def test_tune_with_encryption_flag(sagemaker_session):
11711219
def assert_create_tuning_job_request(**kwrags):
11721220
assert (

0 commit comments

Comments
 (0)