Skip to content

Commit be76bb1

Browse files
author
Anton Repushko
committed
feature: support the GridSearch strategy for hyperparameter optimization
1 parent ecb4ac2 commit be76bb1

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

src/sagemaker/tuner.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName"
6161
PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs"
6262
WARM_START_TYPE = "WarmStartType"
63+
GRID_SEARCH = "GridSearch"
6364

6465
logger = logging.getLogger(__name__)
6566

@@ -219,7 +220,7 @@ def __init__(
219220
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
220221
strategy: Union[str, PipelineVariable] = "Bayesian",
221222
objective_type: Union[str, PipelineVariable] = "Maximize",
222-
max_jobs: Union[int, PipelineVariable] = 1,
223+
max_jobs: Union[int, PipelineVariable] = None,
223224
max_parallel_jobs: Union[int, PipelineVariable] = 1,
224225
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
225226
base_tuning_job_name: Optional[str] = None,
@@ -258,7 +259,8 @@ def __init__(
258259
evaluating training jobs. This value can be either 'Minimize' or
259260
'Maximize' (default: 'Maximize').
260261
max_jobs (int or PipelineVariable): Maximum total number of training jobs to start for
261-
the hyperparameter tuning job (default: 1).
262+
the hyperparameter tuning job. The default value is unspecified for the GridSearch strategy
263+
and the default value is 1 for all other strategies (default: None).
262264
max_parallel_jobs (int or PipelineVariable): Maximum number of parallel training jobs to
263265
start (default: 1).
264266
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
@@ -311,7 +313,11 @@ def __init__(
311313

312314
self.strategy = strategy
313315
self.objective_type = objective_type
316+
# For the GridSearch strategy, max_jobs is expected to be None and is recalculated later.
317+
# For all other strategies, keep the previous default value of 1 for backward compatibility.
314318
self.max_jobs = max_jobs
319+
if max_jobs is None and strategy is not GRID_SEARCH:
320+
self.max_jobs = 1
315321
self.max_parallel_jobs = max_parallel_jobs
316322

317323
self.tags = tags
@@ -1301,7 +1307,7 @@ def create(
13011307
base_tuning_job_name=None,
13021308
strategy="Bayesian",
13031309
objective_type="Maximize",
1304-
max_jobs=1,
1310+
max_jobs=None,
13051311
max_parallel_jobs=1,
13061312
tags=None,
13071313
warm_start_config=None,
@@ -1351,7 +1357,8 @@ def create(
13511357
objective_type (str): The type of the objective metric for evaluating training jobs.
13521358
This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
13531359
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
1354-
tuning job (default: 1).
1360+
tuning job. The default value is unspecified for the GridSearch strategy
1361+
and the default value is 1 for all other strategies (default: None).
13551362
max_parallel_jobs (int): Maximum number of parallel training jobs to start
13561363
(default: 1).
13571364
tags (list[dict]): List of tags for labeling the tuning job (default: None). For more,

tests/unit/test_tuner.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,3 +1774,21 @@ def test_no_tags_prefixes_non_jumpstart_models(
17741774
assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == []
17751775

17761776
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == []
1777+
1778+
def test_create_tuner_with_grid_search_strategy():
1779+
tuner = HyperparameterTuner.create(
1780+
base_tuning_job_name=BASE_JOB_NAME,
1781+
estimator_dict={ESTIMATOR_NAME: ESTIMATOR},
1782+
objective_metric_name_dict={ESTIMATOR_NAME: OBJECTIVE_METRIC_NAME},
1783+
hyperparameter_ranges_dict={ESTIMATOR_NAME: HYPERPARAMETER_RANGES},
1784+
metric_definitions_dict={ESTIMATOR_NAME: METRIC_DEFINITIONS},
1785+
strategy="GridSearch",
1786+
objective_type="Minimize",
1787+
max_parallel_jobs=1,
1788+
tags=TAGS,
1789+
warm_start_config=WARM_START_CONFIG,
1790+
early_stopping_type="Auto",
1791+
)
1792+
1793+
assert tuner is not None
1794+
assert tuner.max_jobs is None

0 commit comments

Comments
 (0)