
Commit 01fa581

Anton Repushko authored and trajanikant committed

feature: support for flexible instance types in the HPO
1 parent 4d95b05 commit 01fa581

File tree

5 files changed (+280 / -6 lines)

src/sagemaker/session.py

Lines changed: 28 additions & 1 deletion
@@ -2150,6 +2150,7 @@ def tune(  # noqa: C901
        checkpoint_s3_uri=None,
        checkpoint_local_path=None,
        random_seed=None,
+       hpo_resource_config=None,
    ):
        """Create an Amazon SageMaker hyperparameter tuning job.
@@ -2233,6 +2234,25 @@ def tune(  # noqa: C901
            random_seed (int): An initial value used to initialize a pseudo-random number generator.
                Setting a random seed will make the hyperparameter tuning search strategies
                produce more consistent configurations for the same tuning job. (default: ``None``).
+           hpo_resource_config (dict): The configuration for the hyperparameter tuning resources,
+               including the compute instances and storage volumes, used for training jobs launched
+               by the tuning job. You must specify either
+               instance_configs or instance_count + instance_type + volume_size.
+               * instance_count (int): Number of EC2 instances to use for training.
+                 The key in resource_config is 'InstanceCount'.
+               * instance_type (str): Type of EC2 instance to use for training, for example,
+                 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.
+               * volume_size (int or PipelineVariable): The volume size in GB of the data to be
+                 processed for hyperparameter optimization.
+               * instance_configs (List[InstanceConfig]): A list containing the configuration(s)
+                 for one or more resources for processing hyperparameter jobs. These resources
+                 include compute instances and storage volumes to use in model training jobs.
+               * volume_kms_key_id: A key used by AWS Key Management Service to encrypt data on
+                 the storage volume attached to the compute instances used to run the training job.
+                 You can use either of the following formats to specify a key:
+                 * KMS Key ID: ``1234abcd-12ab-34cd-56ef-1234567890ab``
+                 * Amazon Resource Name (ARN) of a KMS key:
+                   ``arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab``
        """

        tune_request = {
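
A minimal sketch of what the new argument can look like when calling ``tune`` (the values and the session variable below are illustrative assumptions, not part of the commit):

    # Flat form: one instance spec shared by every training job.
    hpo_resource_config = {
        "InstanceCount": 1,
        "InstanceType": "ml.c4.xlarge",  # example type from the docstring
        "VolumeSizeInGB": 30,
        "VolumeKmsKeyId": "1234abcd-12ab-34cd-56ef-1234567890ab",  # optional
    }
    # Passed through to tune() alongside the usual required arguments:
    # sagemaker_session.tune(..., hpo_resource_config=hpo_resource_config)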
@@ -2258,6 +2278,7 @@ def tune(  # noqa: C901
            input_config=input_config,
            output_config=output_config,
            resource_config=resource_config,
+           hpo_resource_config=hpo_resource_config,
            vpc_config=vpc_config,
            stop_condition=stop_condition,
            enable_network_isolation=enable_network_isolation,
@@ -2491,9 +2512,10 @@ def _map_training_config(
    input_mode,
    role,
    output_config,
-   resource_config,
    stop_condition,
    input_config=None,
+   resource_config=None,
+   hpo_resource_config=None,
    metric_definitions=None,
    image_uri=None,
    algorithm_arn=None,
@@ -2568,13 +2590,18 @@ def _map_training_config(
        TrainingJobDefinition as described in
        https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
    """
+   if hpo_resource_config:
+       resource_config_map = {"HyperParameterTuningResourceConfig": hpo_resource_config}
+   else:
+       resource_config_map = {"ResourceConfig": resource_config}

    training_job_definition = {
        "StaticHyperParameters": static_hyperparameters,
        "RoleArn": role,
        "OutputDataConfig": output_config,
        "ResourceConfig": resource_config,
        "StoppingCondition": stop_condition,
+       **resource_config_map,
    }

    algorithm_spec = {"TrainingInputMode": input_mode}
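
The branch above decides which key lands in the ``TrainingJobDefinition``. A standalone sketch of the same selection (``select_resource_config`` is a hypothetical helper, written here only to show the behavior):

    # Hypothetical helper mirroring the if/else in _map_training_config.
    def select_resource_config(resource_config=None, hpo_resource_config=None):
        # The new HyperParameterTuningResourceConfig takes precedence when set;
        # otherwise the classic ResourceConfig is kept.
        if hpo_resource_config:
            return {"HyperParameterTuningResourceConfig": hpo_resource_config}
        return {"ResourceConfig": resource_config}

    assert select_resource_config(hpo_resource_config={"InstanceCount": 1}) == {
        "HyperParameterTuningResourceConfig": {"InstanceCount": 1}
    }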

src/sagemaker/tuner.py

Lines changed: 105 additions & 1 deletion
@@ -383,6 +383,37 @@ def to_input_req(self):
        }


+class InstanceConfig:
+    """Instance configuration for training jobs started by hyperparameter tuning.
+
+    Contains the configuration(s) for one or more resources for processing hyperparameter jobs.
+    These resources include compute instances and storage volumes to use in model training jobs
+    launched by hyperparameter tuning jobs.
+    """
+
+    def __init__(
+        self,
+        instance_count: Union[int, PipelineVariable] = None,
+        instance_type: Union[str, PipelineVariable] = None,
+        volume_size: Union[int, PipelineVariable] = 30,
+    ):
+        """Creates an ``InstanceConfig`` instance.
+
+        It takes instance configuration information for training
+        jobs that are created as the result of a hyperparameter tuning job.
+        Args:
+            instance_count (int or PipelineVariable): The number of compute instances of type
+                InstanceType to use. For distributed training, select a value greater than 1.
+            instance_type (str or PipelineVariable):
+                The instance type used to run hyperparameter optimization tuning jobs.
+            volume_size (int or PipelineVariable): The volume size in GB of the data to be
+                processed for hyperparameter optimization.
+        """
+        self.instance_count = instance_count
+        self.instance_type = instance_type
+        self.volume_size = volume_size
+
+
 class HyperparameterTuner(object):
    """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
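
A short usage sketch for the new class (instance types are illustrative; the import path follows the integration test further below):

    from sagemaker.tuner import InstanceConfig

    # Two candidate configurations for training jobs launched by the tuner.
    small = InstanceConfig(instance_count=1, instance_type="ml.m4.xlarge", volume_size=30)
    large = InstanceConfig(instance_count=1, instance_type="ml.m4.2xlarge", volume_size=30)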
@@ -419,7 +450,6 @@ def __init__(

        It takes an estimator to obtain configuration information for training
        jobs that are created as the result of a hyperparameter tuning job.
-
        Args:
            estimator (sagemaker.estimator.EstimatorBase): An estimator object
                that has been initialized with the desired configuration. There
@@ -489,6 +519,7 @@ def __init__(
            self.metric_definitions_dict = (
                {estimator_name: metric_definitions} if metric_definitions is not None else {}
            )
+           self.instance_configs_dict = {}
            self.static_hyperparameters = None
        else:
            self.estimator = estimator
@@ -500,6 +531,7 @@ def __init__(
            self._hyperparameter_ranges_dict = None
            self.metric_definitions_dict = None
            self.static_hyperparameters_dict = None
+           self.instance_configs_dict = None

        self._validate_parameter_ranges(estimator, hyperparameter_ranges)

@@ -521,6 +553,30 @@ def __init__(
        self.warm_start_config = warm_start_config
        self.early_stopping_type = early_stopping_type
        self.random_seed = random_seed
+       self.instance_configs = None
+
+   def override_resource_config(
+       self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]
+   ):
+       """Override the instance configuration of the estimators used by the tuner.
+
+       Args:
+           instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]]):
+               The InstanceConfigs to use as an override for the instance configuration
+               of the estimator. ``None`` will remove the override.
+       """
+       if isinstance(instance_configs, dict):
+           self._validate_dict_argument(
+               name="instance_configs",
+               value=instance_configs,
+               allowed_keys=list(self.estimator_dict.keys()),
+           )
+           self.instance_configs_dict = instance_configs
+       else:
+           self.instance_configs = instance_configs
+           if self.estimator_dict and self.estimator_dict.keys():
+               estimator_names = list(self.estimator_dict.keys())
+               self.instance_configs_dict = {estimator_names[0]: instance_configs}

    def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
        """Prepare the tuner instance for tuning (fit)."""
@@ -1817,6 +1873,7 @@ def _get_tuner_args(cls, tuner, inputs):
                estimator=tuner.estimator,
                static_hyperparameters=tuner.static_hyperparameters,
                metric_definitions=tuner.metric_definitions,
+               instance_configs=tuner.instance_configs,
            )

        if tuner.estimator_dict is not None:
@@ -1830,12 +1887,49 @@ def _get_tuner_args(cls, tuner, inputs):
                    tuner.objective_type,
                    tuner.objective_metric_name_dict[estimator_name],
                    tuner.hyperparameter_ranges_dict()[estimator_name],
+                   tuner.instance_configs_dict.get(estimator_name, None),
                )
                for estimator_name in sorted(tuner.estimator_dict.keys())
            ]

        return tuner_args

+   @staticmethod
+   def _prepare_hp_resource_config(
+       instance_configs: List[InstanceConfig],
+       instance_count: int,
+       instance_type: str,
+       volume_size: int,
+       volume_kms_key: str,
+   ):
+       """Placeholder hpo resource config for one estimator of the tuner."""
+       resource_config = {}
+       if volume_kms_key is not None:
+           resource_config["VolumeKmsKeyId"] = volume_kms_key
+
+       if instance_configs is None:
+           resource_config["InstanceCount"] = instance_count
+           resource_config["InstanceType"] = instance_type
+           resource_config["VolumeSizeInGB"] = volume_size
+       else:
+           resource_config["InstanceConfigs"] = _TuningJob._prepare_instance_configs(
+               instance_configs
+           )
+
+       return resource_config
+
+   @staticmethod
+   def _prepare_instance_configs(instance_configs):
+       """Prepare instance config for create tuning request."""
+       return [
+           {
+               "InstanceCount": config.instance_count,
+               "InstanceType": config.instance_type,
+               "VolumeSizeInGB": config.volume_size,
+           }
+           for config in instance_configs
+       ]
+
    @staticmethod
    def _prepare_training_config(
        inputs,
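
For illustration only (these are private helpers, not public API), a sketch of the two shapes the builder above emits; the instance types are assumed example values:

    from sagemaker.tuner import InstanceConfig, _TuningJob

    # No instance_configs -> flat spec built from count/type/size.
    flat = _TuningJob._prepare_hp_resource_config(
        instance_configs=None,
        instance_count=1,
        instance_type="ml.c4.xlarge",
        volume_size=30,
        volume_kms_key=None,
    )
    # flat == {"InstanceCount": 1, "InstanceType": "ml.c4.xlarge", "VolumeSizeInGB": 30}

    # With instance_configs -> the list is serialized under "InstanceConfigs".
    flexible = _TuningJob._prepare_hp_resource_config(
        instance_configs=[InstanceConfig(1, "ml.m4.xlarge", 30)],
        instance_count=1,
        instance_type="ml.c4.xlarge",
        volume_size=30,
        volume_kms_key=None,
    )
    # flexible == {"InstanceConfigs": [
    #     {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 30}]}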
@@ -1846,10 +1940,20 @@ def _prepare_training_config(
        objective_type=None,
        objective_metric_name=None,
        parameter_ranges=None,
+       instance_configs=None,
    ):
        """Prepare training config for one estimator."""
        training_config = _Job._load_config(inputs, estimator)

+       del training_config["resource_config"]
+       training_config["hpo_resource_config"] = _TuningJob._prepare_hp_resource_config(
+           instance_configs,
+           estimator.instance_count,
+           estimator.instance_type,
+           estimator.volume_size,
+           estimator.volume_kms_key,
+       )
+
        training_config["input_mode"] = estimator.input_mode
        training_config["metric_definitions"] = metric_definitions

tests/integ/test_tuner.py

Lines changed: 24 additions & 1 deletion
@@ -35,6 +35,7 @@
    ContinuousParameter,
    CategoricalParameter,
    HyperparameterTuner,
+   InstanceConfig,
    WarmStartConfig,
    WarmStartTypes,
    create_transfer_learning_tuner,
@@ -97,6 +98,7 @@ def _tune_and_deploy(
    job_name=None,
    warm_start_config=None,
    early_stopping_type="Off",
+   instance_configs=None,
):
    tuner = _tune(
        kmeans_estimator,
@@ -105,6 +107,7 @@ def _tune_and_deploy(
        warm_start_config=warm_start_config,
        job_name=job_name,
        early_stopping_type=early_stopping_type,
+       instance_configs=instance_configs,
    )
    _deploy(kmeans_train_set, sagemaker_session, tuner, early_stopping_type, cpu_instance_type)

@@ -134,6 +137,7 @@ def _tune(
    max_jobs=2,
    max_parallel_jobs=2,
    early_stopping_type="Off",
+   instance_configs=None,
):
    with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):

@@ -148,7 +152,7 @@ def _tune(
            warm_start_config=warm_start_config,
            early_stopping_type=early_stopping_type,
        )
-
+       tuner.override_resource_config(instance_configs=instance_configs)
        records = kmeans_estimator.record_set(kmeans_train_set[0][:100])
        test_record_set = kmeans_estimator.record_set(kmeans_train_set[0][:100], channel="test")

@@ -173,6 +177,25 @@ def test_tuning_kmeans(
    )


+@pytest.mark.release
+def test_tuning_kmeans_with_instance_configs(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges, cpu_instance_type
+):
+    job_name = unique_name_from_base("tst-fit")
+    _tune_and_deploy(
+        kmeans_estimator,
+        kmeans_train_set,
+        sagemaker_session,
+        cpu_instance_type,
+        hyperparameter_ranges=hyperparameter_ranges,
+        job_name=job_name,
+        instance_configs=[
+            InstanceConfig(instance_count=1, instance_type="ml.m4.2xlarge", volume_size=30),
+            InstanceConfig(instance_count=1, instance_type="ml.m4.xlarge", volume_size=30),
+        ],
+    )
+
+
def test_tuning_kmeans_identical_dataset_algorithm_tuner_raw(
    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
):

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 3 additions & 3 deletions
@@ -1133,7 +1133,7 @@ def test_single_algo_tuning_step(sagemaker_session):
        },
        "RoleArn": "DummyRole",
        "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-       "ResourceConfig": {
+       "HyperParameterTuningResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.c5.4xlarge",
            "VolumeSizeInGB": 30,
@@ -1285,7 +1285,7 @@ def test_multi_algo_tuning_step(sagemaker_session):
        },
        "RoleArn": "DummyRole",
        "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-       "ResourceConfig": {
+       "HyperParameterTuningResourceConfig": {
            "InstanceCount": {"Get": "Parameters.InstanceCount"},
            "InstanceType": "ml.c5.4xlarge",
            "VolumeSizeInGB": 30,
@@ -1352,7 +1352,7 @@ def test_multi_algo_tuning_step(sagemaker_session):
        },
        "RoleArn": "DummyRole",
        "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-       "ResourceConfig": {
+       "HyperParameterTuningResourceConfig": {
            "InstanceCount": {"Get": "Parameters.InstanceCount"},
            "InstanceType": "ml.c5.4xlarge",
            "VolumeSizeInGB": 30,
