feature: support the GridSearch strategy for hyperparameter optimization #3439

Merged
1 commit merged on Oct 26, 2022
64 changes: 51 additions & 13 deletions src/sagemaker/tuner.py
@@ -33,7 +33,10 @@
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.inputs import TrainingInput, FileSystemInput
from sagemaker.job import _Job
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
from sagemaker.jumpstart.utils import (
add_jumpstart_tags,
get_jumpstart_base_name_if_jumpstart_model,
)
from sagemaker.parameter import (
CategoricalParameter,
ContinuousParameter,
@@ -44,7 +47,12 @@
from sagemaker.workflow.pipeline_context import runnable_by_pipeline

from sagemaker.session import Session
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, to_string
from sagemaker.utils import (
base_from_name,
base_name_from_image,
name_from_base,
to_string,
)

AMAZON_ESTIMATOR_MODULE = "sagemaker"
AMAZON_ESTIMATOR_CLS_NAMES = {
@@ -60,6 +68,7 @@
HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName"
PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs"
WARM_START_TYPE = "WarmStartType"
GRID_SEARCH = "GridSearch"

logger = logging.getLogger(__name__)

@@ -165,7 +174,8 @@ def from_job_desc(cls, warm_start_config):
parents.append(parent[HYPERPARAMETER_TUNING_JOB_NAME])

return cls(
warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]), parents=parents
warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]),
parents=parents,
)

def to_input_req(self):
@@ -219,7 +229,7 @@ def __init__(
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
strategy: Union[str, PipelineVariable] = "Bayesian",
objective_type: Union[str, PipelineVariable] = "Maximize",
max_jobs: Union[int, PipelineVariable] = 1,
max_jobs: Union[int, PipelineVariable] = None,
max_parallel_jobs: Union[int, PipelineVariable] = 1,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
base_tuning_job_name: Optional[str] = None,
@@ -258,7 +268,8 @@ def __init__(
evaluating training jobs. This value can be either 'Minimize' or
'Maximize' (default: 'Maximize').
max_jobs (int or PipelineVariable): Maximum total number of training jobs to start for
the hyperparameter tuning job (default: 1).
the hyperparameter tuning job. The default value is unspecified for the GridSearch
strategy and 1 for all other strategies (default: None).
max_parallel_jobs (int or PipelineVariable): Maximum number of parallel training jobs to
start (default: 1).
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
@@ -311,7 +322,12 @@ def __init__(

self.strategy = strategy
self.objective_type = objective_type
# For the GridSearch strategy we expect max_jobs to be None and recalculate it later.
# For all other strategies we keep the default value of 1 (the previous default)
# for backward compatibility.
self.max_jobs = max_jobs
if max_jobs is None and strategy != GRID_SEARCH:
self.max_jobs = 1
self.max_parallel_jobs = max_parallel_jobs

self.tags = tags
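
Not part of the diff: a minimal usage sketch of the default handling above. The image URI, role, metric name, and hyperparameter range are placeholders, and the later recalculation of max_jobs for GridSearch is taken on faith from the comment in the code, not demonstrated here.

from sagemaker.estimator import Estimator
from sagemaker.parameter import CategoricalParameter
from sagemaker.tuner import HyperparameterTuner

# Placeholder estimator; any configured Estimator behaves the same way here.
estimator = Estimator(
    image_uri="<training-image-uri>",
    role="<execution-role-arn>",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

ranges = {"optimizer": CategoricalParameter(["sgd", "adam"])}

# With GridSearch, max_jobs is left as None and recalculated later.
grid_tuner = HyperparameterTuner(
    estimator,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges=ranges,
    strategy="GridSearch",
)
assert grid_tuner.max_jobs is None

# With any other strategy, omitting max_jobs falls back to the previous default of 1.
bayesian_tuner = HyperparameterTuner(
    estimator,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges=ranges,
    strategy="Bayesian",
)
assert bayesian_tuner.max_jobs == 1
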
@@ -373,7 +389,8 @@ def _prepare_job_name_for_tuning(self, job_name=None):
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
)
base_name = base_name_from_image(
estimator.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
estimator.training_image_uri(),
default_base_name=EstimatorBase.JOB_CLASS_NAME,
)

jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
@@ -434,7 +451,15 @@ def _prepare_static_hyperparameters(
def fit(
self,
inputs: Optional[
Union[str, Dict, List, TrainingInput, FileSystemInput, RecordSet, FileSystemRecordSet]
Union[
str,
Dict,
List,
TrainingInput,
FileSystemInput,
RecordSet,
FileSystemRecordSet,
]
] = None,
job_name: Optional[str] = None,
include_cls_metadata: Union[bool, Dict[str, bool]] = False,
@@ -524,7 +549,9 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estimator_kwargs):
allowed_keys=estimator_names,
)
self._validate_dict_argument(
name="estimator_kwargs", value=estimator_kwargs, allowed_keys=estimator_names
name="estimator_kwargs",
value=estimator_kwargs,
allowed_keys=estimator_names,
)

for (estimator_name, estimator) in self.estimator_dict.items():
@@ -546,7 +573,13 @@ def _prepare_estimator_for_tuning(cls, estimator, inputs, job_name, **kwargs):
estimator._prepare_for_training(job_name)

@classmethod
def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estimator_cls=None):
def attach(
cls,
tuning_job_name,
sagemaker_session=None,
job_details=None,
estimator_cls=None,
):
"""Attach to an existing hyperparameter tuning job.

Create a HyperparameterTuner bound to an existing hyperparameter
@@ -959,7 +992,8 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details):

# Default to the BYO estimator
return getattr(
importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE), cls.DEFAULT_ESTIMATOR_CLS_NAME
importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE),
cls.DEFAULT_ESTIMATOR_CLS_NAME,
)

@classmethod
@@ -1151,7 +1185,10 @@ def _validate_parameter_ranges(self, estimator, hyperparameter_ranges):

def _validate_parameter_range(self, value_hp, parameter_range):
"""Placeholder docstring"""
for (parameter_range_key, parameter_range_value) in parameter_range.__dict__.items():
for (
parameter_range_key,
parameter_range_value,
) in parameter_range.__dict__.items():
if parameter_range_key == "scaling_type":
continue

@@ -1301,7 +1338,7 @@ def create(
base_tuning_job_name=None,
strategy="Bayesian",
objective_type="Maximize",
max_jobs=1,
max_jobs=None,
max_parallel_jobs=1,
tags=None,
warm_start_config=None,
@@ -1351,7 +1388,8 @@
objective_type (str): The type of the objective metric for evaluating training jobs.
This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
tuning job (default: 1).
tuning job. The default value is unspecified for the GridSearch strategy
and 1 for all other strategies (default: None).
max_parallel_jobs (int): Maximum number of parallel training jobs to start
(default: 1).
tags (list[dict]): List of tags for labeling the tuning job (default: None). For more,
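
Not part of the diff: a hedged end-to-end sketch of how the create() path above might be used with the new strategy. The estimator name, image URI, role, metric, range, and S3 URI are placeholders, and the keyed-inputs form of fit() is assumed from the estimator_dict handling shown earlier in this file.

from sagemaker.estimator import Estimator
from sagemaker.parameter import CategoricalParameter
from sagemaker.tuner import HyperparameterTuner

# Placeholder estimator.
estimator = Estimator(
    image_uri="<training-image-uri>",
    role="<execution-role-arn>",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

tuner = HyperparameterTuner.create(
    estimator_dict={"algo-a": estimator},
    objective_metric_name_dict={"algo-a": "validation:accuracy"},
    hyperparameter_ranges_dict={
        "algo-a": {"optimizer": CategoricalParameter(["sgd", "adam"])}
    },
    strategy="GridSearch",
    objective_type="Minimize",
    max_parallel_jobs=2,
    # max_jobs is omitted: with GridSearch it stays None and is resolved later.
)

# For a tuner created from an estimator_dict, fit() takes inputs keyed by estimator name.
tuner.fit(inputs={"algo-a": "s3://<bucket>/train/"})
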
19 changes: 19 additions & 0 deletions tests/unit/test_tuner.py
@@ -1774,3 +1774,22 @@ def test_no_tags_prefixes_non_jumpstart_models(
assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == []

assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == []


def test_create_tuner_with_grid_search_strategy():
tuner = HyperparameterTuner.create(
base_tuning_job_name=BASE_JOB_NAME,
estimator_dict={ESTIMATOR_NAME: ESTIMATOR},
objective_metric_name_dict={ESTIMATOR_NAME: OBJECTIVE_METRIC_NAME},
hyperparameter_ranges_dict={ESTIMATOR_NAME: HYPERPARAMETER_RANGES},
metric_definitions_dict={ESTIMATOR_NAME: METRIC_DEFINITIONS},
strategy="GridSearch",
objective_type="Minimize",
max_parallel_jobs=1,
tags=TAGS,
warm_start_config=WARM_START_CONFIG,
early_stopping_type="Auto",
)

assert tuner is not None
assert tuner.max_jobs is None
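
A complementary test sketch, not included in this PR: it reuses the same module-level constants as the test above and asserts the backward-compatible fallback described in the tuner.py docstring (max_jobs becomes 1 when it is omitted for a non-GridSearch strategy).

def test_create_tuner_without_max_jobs_defaults_to_one():
    # Omitting max_jobs with a non-GridSearch strategy should fall back to 1.
    tuner = HyperparameterTuner.create(
        base_tuning_job_name=BASE_JOB_NAME,
        estimator_dict={ESTIMATOR_NAME: ESTIMATOR},
        objective_metric_name_dict={ESTIMATOR_NAME: OBJECTIVE_METRIC_NAME},
        hyperparameter_ranges_dict={ESTIMATOR_NAME: HYPERPARAMETER_RANGES},
        metric_definitions_dict={ESTIMATOR_NAME: METRIC_DEFINITIONS},
        strategy="Bayesian",
        objective_type="Minimize",
        max_parallel_jobs=1,
        tags=TAGS,
        warm_start_config=WARM_START_CONFIG,
        early_stopping_type="Auto",
    )

    assert tuner is not None
    assert tuner.max_jobs == 1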