From d80bad07df61bb0c883b8cb78193c80757a97343 Mon Sep 17 00:00:00 2001 From: Anton Repushko Date: Thu, 8 Dec 2022 21:02:50 +0100 Subject: [PATCH] fix: the Hyperband support fix for the HPO --- src/sagemaker/session.py | 9 +++++++ src/sagemaker/tuner.py | 14 +++++------ tests/unit/test_session.py | 48 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 00797c9ea0..3fc4fc1256 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2121,6 +2121,7 @@ def tune( # noqa: C901 stop_condition, tags, warm_start_config, + strategy_config=None, enable_network_isolation=False, image_uri=None, algorithm_arn=None, @@ -2136,6 +2137,8 @@ def tune( # noqa: C901 Args: job_name (str): Name of the tuning job being created. strategy (str): Strategy to be used for hyperparameter estimations. + strategy_config (dict): A configuration for the hyperparameter tuning + job optimisation strategy. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. @@ -2220,6 +2223,7 @@ def tune( # noqa: C901 objective_metric_name=objective_metric_name, parameter_ranges=parameter_ranges, early_stopping_type=early_stopping_type, + strategy_config=strategy_config, ), "TrainingJobDefinition": self._map_training_config( static_hyperparameters=static_hyperparameters, @@ -2375,6 +2379,7 @@ def _map_tuning_config( objective_type=None, objective_metric_name=None, parameter_ranges=None, + strategy_config=None, ): """Construct tuning job configuration dictionary. @@ -2392,6 +2397,8 @@ def _map_tuning_config( objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. + strategy_config (dict): A configuration for the hyperparameter tuning job optimisation + strategy. Returns: A dictionary of tuning job configuration. For format details, please refer to @@ -2415,6 +2422,8 @@ def _map_tuning_config( if parameter_ranges is not None: tuning_config["ParameterRanges"] = parameter_ranges + if strategy_config is not None: + tuning_config["StrategyConfig"] = strategy_config return tuning_config @classmethod diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 52b9d81d0d..9a694cbec9 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -282,8 +282,8 @@ def from_job_desc(cls, hyperband_strategy_config): Returns: sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of - HyperbandStrategyConfig containing the max_resource and min_resource provided as part of - ``hyperband_strategy_config``. + ``HyperbandStrategyConfig`` containing the max_resource + and min_resource provided as part of ``hyperband_strategy_config``. """ return cls( min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE], @@ -306,7 +306,7 @@ def to_input_req(self): Returns: dict: Containing the "MaxResource" and - "MinResource" as the first class fields. + "MinResource" as the first class fields. """ return { HYPERBAND_MIN_RESOURCE: self.min_resource, @@ -330,7 +330,7 @@ def __init__( Args: hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration - for the object that specifies the Hyperband strategy. + for the object that specifies the Hyperband strategy. This parameter is only supported for the Hyperband selection for Strategy within the HyperParameterTuningJobConfig. """ @@ -461,7 +461,7 @@ def __init__( ``WarmStartConfig`` object that has been initialized with the configuration defining the nature of warm start tuning job. strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter" - tuning job optimisation strategy. + tuning job optimisation strategy. early_stopping_type (str or PipelineVariable): Specifies whether early stopping is enabled for the job. Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted. @@ -1569,7 +1569,7 @@ def create( strategy (str): Strategy to be used for hyperparameter estimations (default: 'Bayesian'). strategy_config (dict): The configuration for a training job launched by a - hyperparameter tuning job. + hyperparameter tuning job. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize' (default: 'Maximize'). max_jobs (int): Maximum total number of training jobs to start for the hyperparameter @@ -1776,7 +1776,7 @@ def _get_tuner_args(cls, tuner, inputs): } if tuner.strategy_config is not None: - tuning_config["strategy_config"] = tuner.strategy_config + tuning_config["strategy_config"] = tuner.strategy_config.to_input_req() if tuner.objective_metric_name is not None: tuning_config["objective_type"] = tuner.objective_type diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8958210092..bf81283177 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -941,6 +941,13 @@ def test_train_pack_to_request(sagemaker_session): ], } +SAMPLE_HYPERBAND_STRATEGY_CONFIG = { + "HyperbandStrategyConfig": { + "MinResource": 1, + "MaxResource": 10, + } +} + @pytest.mark.parametrize( "warm_start_type, parents", @@ -1167,6 +1174,47 @@ def assert_create_tuning_job_request(**kwrags): ) +def test_tune_with_strategy_config(sagemaker_session): + def assert_create_tuning_job_request(**kwrags): + assert ( + kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][ + "MinResource" + ] + == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MinResource"] + ) + assert ( + kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][ + "MaxResource" + ] + == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MaxResource"] + ) + + sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = ( + assert_create_tuning_job_request + ) + sagemaker_session.tune( + job_name="dummy-tuning-1", + strategy="Bayesian", + objective_type="Maximize", + objective_metric_name="val-score", + max_jobs=100, + max_parallel_jobs=5, + parameter_ranges=SAMPLE_PARAM_RANGES, + static_hyperparameters=STATIC_HPs, + image_uri="dummy-image-1", + input_mode="File", + metric_definitions=SAMPLE_METRIC_DEF, + role=EXPANDED_ROLE, + input_config=SAMPLE_INPUT, + output_config=SAMPLE_OUTPUT, + resource_config=RESOURCE_CONFIG, + stop_condition=SAMPLE_STOPPING_CONDITION, + tags=None, + warm_start_config=None, + strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG, + ) + + def test_tune_with_encryption_flag(sagemaker_session): def assert_create_tuning_job_request(**kwrags): assert (