
Commit 13d388b

Author: Anton Repushko (committed)
feature: support for flexible instance types in the HPO
1 parent e2f3888 commit 13d388b

File tree: 5 files changed (+279 additions, −7 deletions)


src/sagemaker/session.py

Lines changed: 25 additions & 2 deletions
@@ -2201,6 +2201,7 @@ def tune( # noqa: C901
         checkpoint_local_path=None,
         random_seed=None,
         environment=None,
+        hpo_resource_config=None,
     ):
         """Create an Amazon SageMaker hyperparameter tuning job.
@@ -2286,6 +2287,22 @@ def tune( # noqa: C901
                 produce more consistent configurations for the same tuning job. (default: ``None``).
             environment (dict[str, str]) : Environment variables to be set for
                 use during training jobs (default: ``None``)
+            hpo_resource_config (dict): The configuration for the hyperparameter tuning resources,
+                including the compute instances and storage volumes, used for training jobs launched
+                by the tuning job, where you must specify either
+                instance_configs or instance_count + instance_type + volume_size:
+                * instance_count (int): Number of EC2 instances to use for training.
+                    The key in resource_config is 'InstanceCount'.
+                * instance_type (str): Type of EC2 instance to use for training, for example,
+                    'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.
+                * volume_size (int or PipelineVariable): The volume size in GB of the data to be
+                    processed for hyperparameter optimization.
+                * instance_configs (List[InstanceConfig]): A list containing the configuration(s)
+                    for one or more resources for processing hyperparameter jobs. These resources
+                    include compute instances and storage volumes to use in model training jobs.
+                * volume_kms_key_id: The AWS Key Management Service (AWS KMS) key
+                    that Amazon SageMaker uses to encrypt data on the storage
+                    volume attached to the ML compute instance(s) that run the training job.
         """
 
         tune_request = {
@@ -2311,6 +2328,7 @@ def tune( # noqa: C901
             input_config=input_config,
             output_config=output_config,
             resource_config=resource_config,
+            hpo_resource_config=hpo_resource_config,
             vpc_config=vpc_config,
             stop_condition=stop_condition,
             enable_network_isolation=enable_network_isolation,
@@ -2545,9 +2563,10 @@ def _map_training_config(
         input_mode,
         role,
         output_config,
-        resource_config,
         stop_condition,
         input_config=None,
+        resource_config=None,
+        hpo_resource_config=None,
         metric_definitions=None,
         image_uri=None,
         algorithm_arn=None,
@@ -2625,13 +2644,17 @@ def _map_training_config(
             TrainingJobDefinition as described in
             https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
         """
+        if hpo_resource_config is not None:
+            resource_config_map = {"HyperParameterTuningResourceConfig": hpo_resource_config}
+        else:
+            resource_config_map = {"ResourceConfig": resource_config}
 
         training_job_definition = {
             "StaticHyperParameters": static_hyperparameters,
             "RoleArn": role,
             "OutputDataConfig": output_config,
-            "ResourceConfig": resource_config,
             "StoppingCondition": stop_condition,
+            **resource_config_map,
         }
 
         algorithm_spec = {"TrainingInputMode": input_mode}
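For orientation, the two mutually exclusive shapes the new hpo_resource_config argument accepts, per the docstring above and the keys consumed by _TuningJob._prepare_hp_resource_config in tuner.py; the instance types and sizes below are illustrative placeholders, not values from the commit:

# Sketch only: shape 1, a single fixed instance type for every training job
# launched by the tuning job.
hpo_resource_config = {
    "InstanceCount": 1,
    "InstanceType": "ml.c4.xlarge",
    "VolumeSizeInGB": 30,
    # Optional encryption key for the attached storage volumes:
    # "VolumeKmsKeyId": "<kms-key-id>",
}

# Sketch only: shape 2, flexible instance types via InstanceConfigs, used
# instead of the three fixed keys above.
hpo_resource_config = {
    "InstanceConfigs": [
        {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 30},
        {"InstanceCount": 1, "InstanceType": "ml.m4.2xlarge", "VolumeSizeInGB": 30},
    ],
}

Passed to session.tune(..., hpo_resource_config=...), the dict ends up under the HyperParameterTuningResourceConfig key of each training job definition rather than ResourceConfig, as the _map_training_config branch above shows.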

src/sagemaker/tuner.py

Lines changed: 107 additions & 1 deletion
@@ -383,6 +383,38 @@ def to_input_req(self):
         }
 
 
+class InstanceConfig:
+    """Instance configuration for training jobs started by hyperparameter tuning.
+
+    Contains the configuration(s) for one or more resources for processing hyperparameter jobs.
+    These resources include compute instances and storage volumes to use in model training jobs
+    launched by hyperparameter tuning jobs.
+    """
+
+    def __init__(
+        self,
+        instance_count: Union[int, PipelineVariable] = None,
+        instance_type: Union[str, PipelineVariable] = None,
+        volume_size: Union[int, PipelineVariable] = 30,
+    ):
+        """Creates an ``InstanceConfig`` instance.
+
+        It takes instance configuration information for training
+        jobs that are created as the result of a hyperparameter tuning job.
+
+        Args:
+            instance_count (int or PipelineVariable): The number of compute instances of type
+                InstanceType to use. For distributed training, select a value greater than 1.
+            instance_type (str or PipelineVariable):
+                The instance type used to run hyperparameter optimization tuning jobs.
+            volume_size (int or PipelineVariable): The volume size in GB of the data to be
+                processed for hyperparameter optimization.
+        """
+        self.instance_count = instance_count
+        self.instance_type = instance_type
+        self.volume_size = volume_size
+
+
 class HyperparameterTuner(object):
     """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
@@ -489,7 +521,6 @@ def __init__(
             self.metric_definitions_dict = (
                 {estimator_name: metric_definitions} if metric_definitions is not None else {}
             )
-            self.static_hyperparameters = None
         else:
             self.estimator = estimator
             self.objective_metric_name = objective_metric_name
@@ -521,6 +552,31 @@ def __init__(
         self.warm_start_config = warm_start_config
         self.early_stopping_type = early_stopping_type
         self.random_seed = random_seed
+        self.instance_configs_dict = None
+        self.instance_configs = None
+
+    def override_resource_config(
+        self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]
+    ):
+        """Override the instance configuration of the estimators used by the tuner.
+
+        Args:
+            instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]]):
+                The InstanceConfigs to use as an override for the instance configuration
+                of the estimator. ``None`` will remove the override.
+        """
+        if isinstance(instance_configs, dict):
+            self._validate_dict_argument(
+                name="instance_configs",
+                value=instance_configs,
+                allowed_keys=list(self.estimator_dict.keys()),
+            )
+            self.instance_configs_dict = instance_configs
+        else:
+            self.instance_configs = instance_configs
+            if self.estimator_dict is not None and self.estimator_dict.keys():
+                estimator_names = list(self.estimator_dict.keys())
+                self.instance_configs_dict = {estimator_names[0]: instance_configs}
 
     def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
         """Prepare the tuner instance for tuning (fit)."""
@@ -1817,6 +1873,7 @@ def _get_tuner_args(cls, tuner, inputs):
             estimator=tuner.estimator,
             static_hyperparameters=tuner.static_hyperparameters,
             metric_definitions=tuner.metric_definitions,
+            instance_configs=tuner.instance_configs,
         )
 
         if tuner.estimator_dict is not None:
@@ -1830,12 +1887,51 @@ def _get_tuner_args(cls, tuner, inputs):
                     tuner.objective_type,
                     tuner.objective_metric_name_dict[estimator_name],
                     tuner.hyperparameter_ranges_dict()[estimator_name],
+                    tuner.instance_configs_dict.get(estimator_name, None)
+                    if tuner.instance_configs_dict is not None
+                    else None,
                 )
                 for estimator_name in sorted(tuner.estimator_dict.keys())
             ]
 
         return tuner_args
 
+    @staticmethod
+    def _prepare_hp_resource_config(
+        instance_configs: List[InstanceConfig],
+        instance_count: int,
+        instance_type: str,
+        volume_size: int,
+        volume_kms_key: str,
+    ):
+        """Prepare the hpo resource config for one estimator of the tuner."""
+        resource_config = {}
+        if volume_kms_key is not None:
+            resource_config["VolumeKmsKeyId"] = volume_kms_key
+
+        if instance_configs is None:
+            resource_config["InstanceCount"] = instance_count
+            resource_config["InstanceType"] = instance_type
+            resource_config["VolumeSizeInGB"] = volume_size
+        else:
+            resource_config["InstanceConfigs"] = _TuningJob._prepare_instance_configs(
+                instance_configs
+            )
+
+        return resource_config
+
+    @staticmethod
+    def _prepare_instance_configs(instance_configs: List[InstanceConfig]):
+        """Prepare instance config for create tuning request."""
+        return [
+            InstanceConfig(
+                instance_count=config.instance_count,
+                instance_type=config.instance_type,
+                volume_size=config.volume_size,
+            )
+            for config in instance_configs
+        ]
+
     @staticmethod
     def _prepare_training_config(
         inputs,
@@ -1846,10 +1942,20 @@ def _prepare_training_config(
         objective_type=None,
         objective_metric_name=None,
         parameter_ranges=None,
+        instance_configs=None,
     ):
         """Prepare training config for one estimator."""
         training_config = _Job._load_config(inputs, estimator)
 
+        del training_config["resource_config"]
+        training_config["hpo_resource_config"] = _TuningJob._prepare_hp_resource_config(
+            instance_configs,
+            estimator.instance_count,
+            estimator.instance_type,
+            estimator.volume_size,
+            estimator.volume_kms_key,
+        )
+
         training_config["input_mode"] = estimator.input_mode
         training_config["metric_definitions"] = metric_definitions
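Taken together, _prepare_training_config now always removes the plain resource_config from the training config and emits an hpo_resource_config instead, falling back to the estimator's own instance settings when no InstanceConfig override was given. A standalone restatement of that branch, for illustration only (hypothetical function name, mirroring the static method above):

# Mirrors _TuningJob._prepare_hp_resource_config; not the SDK's actual API.
def sketch_hp_resource_config(instance_configs, instance_count, instance_type,
                              volume_size, volume_kms_key=None):
    resource_config = {}
    if volume_kms_key is not None:
        resource_config["VolumeKmsKeyId"] = volume_kms_key
    if instance_configs is None:
        # Fallback: one fixed instance type taken from the estimator.
        resource_config["InstanceCount"] = instance_count
        resource_config["InstanceType"] = instance_type
        resource_config["VolumeSizeInGB"] = volume_size
    else:
        # Override path: flexible instance types.
        resource_config["InstanceConfigs"] = instance_configs
    return resource_config

print(sketch_hp_resource_config(None, 1, "ml.m4.xlarge", 30))
# {'InstanceCount': 1, 'InstanceType': 'ml.m4.xlarge', 'VolumeSizeInGB': 30}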
tests/integ/test_tuner.py

Lines changed: 24 additions & 1 deletion
@@ -35,6 +35,7 @@
     ContinuousParameter,
     CategoricalParameter,
     HyperparameterTuner,
+    InstanceConfig,
     WarmStartConfig,
     WarmStartTypes,
     create_transfer_learning_tuner,
@@ -97,6 +98,7 @@ def _tune_and_deploy(
     job_name=None,
     warm_start_config=None,
     early_stopping_type="Off",
+    instance_configs=None,
 ):
     tuner = _tune(
         kmeans_estimator,
@@ -105,6 +107,7 @@ def _tune_and_deploy(
         warm_start_config=warm_start_config,
         job_name=job_name,
         early_stopping_type=early_stopping_type,
+        instance_configs=instance_configs,
     )
     _deploy(kmeans_train_set, sagemaker_session, tuner, early_stopping_type, cpu_instance_type)
 
@@ -134,6 +137,7 @@ def _tune(
     max_jobs=2,
     max_parallel_jobs=2,
     early_stopping_type="Off",
+    instance_configs=None,
 ):
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
 
@@ -148,7 +152,7 @@ def _tune(
             warm_start_config=warm_start_config,
             early_stopping_type=early_stopping_type,
         )
-
+        tuner.override_resource_config(instance_configs=instance_configs)
         records = kmeans_estimator.record_set(kmeans_train_set[0][:100])
         test_record_set = kmeans_estimator.record_set(kmeans_train_set[0][:100], channel="test")
 
@@ -173,6 +177,25 @@ def test_tuning_kmeans(
     )
 
 
+@pytest.mark.release
+def test_tuning_kmeans_with_instance_configs(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges, cpu_instance_type
+):
+    job_name = unique_name_from_base("tst-fit")
+    _tune_and_deploy(
+        kmeans_estimator,
+        kmeans_train_set,
+        sagemaker_session,
+        cpu_instance_type,
+        hyperparameter_ranges=hyperparameter_ranges,
+        job_name=job_name,
+        instance_configs=[
+            InstanceConfig(instance_count=1, instance_type="ml.m4.2xlarge", volume_size=30),
+            InstanceConfig(instance_count=1, instance_type="ml.m4.xlarge", volume_size=30),
+        ],
+    )
+
+
 def test_tuning_kmeans_identical_dataset_algorithm_tuner_raw(
     sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
 ):

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 3 additions & 3 deletions
@@ -1133,7 +1133,7 @@ def test_single_algo_tuning_step(sagemaker_session):
             },
             "RoleArn": "DummyRole",
             "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-            "ResourceConfig": {
+            "HyperParameterTuningResourceConfig": {
                 "InstanceCount": 1,
                 "InstanceType": "ml.c5.4xlarge",
                 "VolumeSizeInGB": 30,
@@ -1285,7 +1285,7 @@ def test_multi_algo_tuning_step(sagemaker_session):
             },
             "RoleArn": "DummyRole",
             "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-            "ResourceConfig": {
+            "HyperParameterTuningResourceConfig": {
                 "InstanceCount": {"Get": "Parameters.InstanceCount"},
                 "InstanceType": "ml.c5.4xlarge",
                 "VolumeSizeInGB": 30,
@@ -1352,7 +1352,7 @@ def test_multi_algo_tuning_step(sagemaker_session):
             },
             "RoleArn": "DummyRole",
             "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-            "ResourceConfig": {
+            "HyperParameterTuningResourceConfig": {
                 "InstanceCount": {"Get": "Parameters.InstanceCount"},
                 "InstanceType": "ml.c5.4xlarge",
                 "VolumeSizeInGB": 30,
