
Commit 6aa85f4

Anton Repushko (repushko) authored and Namrata Madan committed

feature: support the GridSearch strategy for hyperparameter optimization (aws#3439)

Co-authored-by: Anton Repushko <[email protected]>

1 parent c9d3cf4 · commit 6aa85f4
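With this change, callers can select grid search when constructing a tuner. A minimal sketch of the new usage follows; the estimator, metric name, ranges, and S3 path are illustrative placeholders, not part of this commit:

from sagemaker.tuner import HyperparameterTuner, CategoricalParameter

# Sketch: assumes `estimator` is an already-configured SageMaker Estimator.
tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges={"optimizer": CategoricalParameter(["sgd", "adam"])},
    strategy="GridSearch",  # new strategy; max_jobs may now stay None
    max_parallel_jobs=2,
)
tuner.fit({"training": "s3://my-bucket/train"})  # hypothetical input channel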

File tree

2 files changed: +70 -13 lines changed

src/sagemaker/tuner.py

Lines changed: 51 additions & 13 deletions
@@ -33,7 +33,10 @@
 from sagemaker.estimator import Framework, EstimatorBase
 from sagemaker.inputs import TrainingInput, FileSystemInput
 from sagemaker.job import _Job
-from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
+from sagemaker.jumpstart.utils import (
+    add_jumpstart_tags,
+    get_jumpstart_base_name_if_jumpstart_model,
+)
 from sagemaker.parameter import (
     CategoricalParameter,
     ContinuousParameter,
@@ -44,7 +47,12 @@
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline

 from sagemaker.session import Session
-from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, to_string
+from sagemaker.utils import (
+    base_from_name,
+    base_name_from_image,
+    name_from_base,
+    to_string,
+)

 AMAZON_ESTIMATOR_MODULE = "sagemaker"
 AMAZON_ESTIMATOR_CLS_NAMES = {
@@ -60,6 +68,7 @@
 HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName"
 PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs"
 WARM_START_TYPE = "WarmStartType"
+GRID_SEARCH = "GridSearch"

 logger = logging.getLogger(__name__)
@@ -165,7 +174,8 @@ def from_job_desc(cls, warm_start_config):
             parents.append(parent[HYPERPARAMETER_TUNING_JOB_NAME])

         return cls(
-            warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]), parents=parents
+            warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]),
+            parents=parents,
         )

     def to_input_req(self):
@@ -219,7 +229,7 @@ def __init__(
         metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
         strategy: Union[str, PipelineVariable] = "Bayesian",
         objective_type: Union[str, PipelineVariable] = "Maximize",
-        max_jobs: Union[int, PipelineVariable] = 1,
+        max_jobs: Union[int, PipelineVariable] = None,
         max_parallel_jobs: Union[int, PipelineVariable] = 1,
         tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
         base_tuning_job_name: Optional[str] = None,
@@ -258,7 +268,8 @@ def __init__(
                 evaluating training jobs. This value can be either 'Minimize' or
                 'Maximize' (default: 'Maximize').
             max_jobs (int or PipelineVariable): Maximum total number of training jobs to start for
-                the hyperparameter tuning job (default: 1).
+                the hyperparameter tuning job. The default value is unspecified for the GridSearch
+                strategy and is 1 for all other strategies (default: None).
             max_parallel_jobs (int or PipelineVariable): Maximum number of parallel training jobs to
                 start (default: 1).
             tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
@@ -311,7 +322,12 @@ def __init__(

         self.strategy = strategy
         self.objective_type = objective_type
+        # For the GridSearch strategy we expect max_jobs to be None and recalculate it later.
+        # For all other strategies we keep the default value of 1
+        # (the previous default) for backward compatibility.
         self.max_jobs = max_jobs
+        if max_jobs is None and strategy != GRID_SEARCH:
+            self.max_jobs = 1
         self.max_parallel_jobs = max_parallel_jobs

         self.tags = tags
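Standalone, the new default handling behaves as sketched below (resolve_max_jobs is a hypothetical helper written for illustration, not part of the commit):

def resolve_max_jobs(max_jobs, strategy):
    # Keep None for GridSearch so the job count can be derived later;
    # fall back to the old default of 1 for every other strategy.
    if max_jobs is None and strategy != "GridSearch":
        return 1
    return max_jobs

assert resolve_max_jobs(None, "GridSearch") is None
assert resolve_max_jobs(None, "Bayesian") == 1
assert resolve_max_jobs(5, "GridSearch") == 5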
@@ -373,7 +389,8 @@ def _prepare_job_name_for_tuning(self, job_name=None):
                 self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
             )
             base_name = base_name_from_image(
-                estimator.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
+                estimator.training_image_uri(),
+                default_base_name=EstimatorBase.JOB_CLASS_NAME,
             )

             jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
@@ -434,7 +451,15 @@ def _prepare_static_hyperparameters(
     def fit(
         self,
         inputs: Optional[
-            Union[str, Dict, List, TrainingInput, FileSystemInput, RecordSet, FileSystemRecordSet]
+            Union[
+                str,
+                Dict,
+                List,
+                TrainingInput,
+                FileSystemInput,
+                RecordSet,
+                FileSystemRecordSet,
+            ]
         ] = None,
         job_name: Optional[str] = None,
         include_cls_metadata: Union[bool, Dict[str, bool]] = False,
@@ -524,7 +549,9 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estim
             allowed_keys=estimator_names,
         )
         self._validate_dict_argument(
-            name="estimator_kwargs", value=estimator_kwargs, allowed_keys=estimator_names
+            name="estimator_kwargs",
+            value=estimator_kwargs,
+            allowed_keys=estimator_names,
         )

         for (estimator_name, estimator) in self.estimator_dict.items():
@@ -546,7 +573,13 @@ def _prepare_estimator_for_tuning(cls, estimator, inputs, job_name, **kwargs):
         estimator._prepare_for_training(job_name)

     @classmethod
-    def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estimator_cls=None):
+    def attach(
+        cls,
+        tuning_job_name,
+        sagemaker_session=None,
+        job_details=None,
+        estimator_cls=None,
+    ):
         """Attach to an existing hyperparameter tuning job.

         Create a HyperparameterTuner bound to an existing hyperparameter
@@ -959,7 +992,8 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details):

         # Default to the BYO estimator
         return getattr(
-            importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE), cls.DEFAULT_ESTIMATOR_CLS_NAME
+            importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE),
+            cls.DEFAULT_ESTIMATOR_CLS_NAME,
         )

     @classmethod
@@ -1151,7 +1185,10 @@ def _validate_parameter_ranges(self, estimator, hyperparameter_ranges):

     def _validate_parameter_range(self, value_hp, parameter_range):
         """Placeholder docstring"""
-        for (parameter_range_key, parameter_range_value) in parameter_range.__dict__.items():
+        for (
+            parameter_range_key,
+            parameter_range_value,
+        ) in parameter_range.__dict__.items():
             if parameter_range_key == "scaling_type":
                 continue
@@ -1301,7 +1338,7 @@ def create(
         base_tuning_job_name=None,
         strategy="Bayesian",
         objective_type="Maximize",
-        max_jobs=1,
+        max_jobs=None,
         max_parallel_jobs=1,
         tags=None,
         warm_start_config=None,
@@ -1351,7 +1388,8 @@ def create(
             objective_type (str): The type of the objective metric for evaluating training jobs.
                 This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
-                tuning job (default: 1).
+                tuning job. The default value is unspecified for the GridSearch strategy
+                and is 1 for all other strategies (default: None).
             max_parallel_jobs (int): Maximum number of parallel training jobs to start
                 (default: 1).
             tags (list[dict]): List of tags for labeling the tuning job (default: None). For more,

tests/unit/test_tuner.py

Lines changed: 19 additions & 0 deletions
@@ -1774,3 +1774,22 @@ def test_no_tags_prefixes_non_jumpstart_models(
     assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == []

     assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == []
+
+
+def test_create_tuner_with_grid_search_strategy():
+    tuner = HyperparameterTuner.create(
+        base_tuning_job_name=BASE_JOB_NAME,
+        estimator_dict={ESTIMATOR_NAME: ESTIMATOR},
+        objective_metric_name_dict={ESTIMATOR_NAME: OBJECTIVE_METRIC_NAME},
+        hyperparameter_ranges_dict={ESTIMATOR_NAME: HYPERPARAMETER_RANGES},
+        metric_definitions_dict={ESTIMATOR_NAME: METRIC_DEFINITIONS},
+        strategy="GridSearch",
+        objective_type="Minimize",
+        max_parallel_jobs=1,
+        tags=TAGS,
+        warm_start_config=WARM_START_CONFIG,
+        early_stopping_type="Auto",
+    )
+
+    assert tuner is not None
+    assert tuner.max_jobs is None
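The final assertion reflects the intent of the change: with GridSearch, max_jobs stays None so the number of training jobs can be derived from the search space rather than capped up front. A hedged illustration (grid search in SageMaker operates over categorical parameters; the names below are made up):

from sagemaker.parameter import CategoricalParameter

# Illustration: grid search enumerates every combination of categorical values,
# so the effective number of jobs is the product of the value counts.
hyperparameter_ranges = {
    "optimizer": CategoricalParameter(["sgd", "adam"]),        # 2 values
    "batch_size": CategoricalParameter(["64", "128", "256"]),  # 3 values
}
# 2 x 3 = 6 combinations, which is why a fixed default of max_jobs=1 would not fit here.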
