
Commit 6238636

Author: Anton Repushko (committed)
feature: support for flexible instance types in the HPO
1 parent dba1026 commit 6238636

File tree: 5 files changed, +318 −8 lines changed


src/sagemaker/session.py

Lines changed: 25 additions & 2 deletions
@@ -2201,6 +2201,7 @@ def tune(  # noqa: C901
         checkpoint_local_path=None,
         random_seed=None,
         environment=None,
+        hpo_resource_config=None,
     ):
         """Create an Amazon SageMaker hyperparameter tuning job.
 
@@ -2286,6 +2287,22 @@ def tune(  # noqa: C901
                 produce more consistent configurations for the same tuning job. (default: ``None``).
             environment (dict[str, str]) : Environment variables to be set for
                 use during training jobs (default: ``None``)
+            hpo_resource_config (dict): The configuration for the hyperparameter tuning resources,
+                including the compute instances and storage volumes, used for training jobs
+                launched by the tuning job. You must specify either instance_configs or
+                instance_count + instance_type + volume_size:
+                * instance_count (int): Number of EC2 instances to use for training.
+                    The key in resource_config is 'InstanceCount'.
+                * instance_type (str): Type of EC2 instance to use for training, for example,
+                    'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.
+                * volume_size (int or PipelineVariable): The volume size in GB of the data to be
+                    processed for hyperparameter optimization.
+                * instance_configs (List[InstanceConfig]): A list containing the configuration(s)
+                    for one or more resources for processing hyperparameter jobs. These resources
+                    include compute instances and storage volumes to use in model training jobs.
+                * volume_kms_key_id (str): The AWS Key Management Service (AWS KMS) key
+                    that Amazon SageMaker uses to encrypt data on the storage
+                    volume attached to the ML compute instance(s) that run the training job.
         """
 
         tune_request = {
@@ -2311,6 +2328,7 @@ def tune(  # noqa: C901
                 input_config=input_config,
                 output_config=output_config,
                 resource_config=resource_config,
+                hpo_resource_config=hpo_resource_config,
                 vpc_config=vpc_config,
                 stop_condition=stop_condition,
                 enable_network_isolation=enable_network_isolation,
@@ -2545,9 +2563,10 @@ def _map_training_config(
         input_mode,
         role,
         output_config,
-        resource_config,
         stop_condition,
         input_config=None,
+        resource_config=None,
+        hpo_resource_config=None,
         metric_definitions=None,
         image_uri=None,
         algorithm_arn=None,
@@ -2625,13 +2644,17 @@ def _map_training_config(
             TrainingJobDefinition as described in
             https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
         """
+        if hpo_resource_config is not None:
+            resource_config_map = {"HyperParameterTuningResourceConfig": hpo_resource_config}
+        else:
+            resource_config_map = {"ResourceConfig": resource_config}
 
         training_job_definition = {
             "StaticHyperParameters": static_hyperparameters,
             "RoleArn": role,
             "OutputDataConfig": output_config,
-            "ResourceConfig": resource_config,
             "StoppingCondition": stop_condition,
+            **resource_config_map,
         }
 
         algorithm_spec = {"TrainingInputMode": input_mode}
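The `_map_training_config` branch above means `tune()` now accepts the resource configuration in either of two shapes. A minimal sketch of the `hpo_resource_config` argument follows; only the keys come from this commit, the instance types and sizes are illustrative:

    # Uniform fleet: the same instance settings for every training job the tuner launches.
    hpo_resource_config = {
        "InstanceCount": 1,
        "InstanceType": "ml.c4.xlarge",
        "VolumeSizeInGB": 30,
    }

    # Flexible fleet: per-configuration instance types (the new capability in this commit).
    hpo_resource_config = {
        "InstanceConfigs": [
            {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 30},
            {"InstanceCount": 1, "InstanceType": "ml.m4.2xlarge", "VolumeSizeInGB": 30},
        ],
        # Optionally: "VolumeKmsKeyId": "<kms-key-id>"
    }

Either dict is passed through to the CreateHyperParameterTuningJob request as HyperParameterTuningResourceConfig; when the argument is omitted, the session falls back to the classic ResourceConfig.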

src/sagemaker/tuner.py

Lines changed: 146 additions & 2 deletions
@@ -383,6 +383,83 @@ def to_input_req(self):
         }
 
 
+class InstanceConfig:
+    """Instance configuration for training jobs started by hyperparameter tuning.
+
+    Contains the configuration(s) for one or more resources for processing hyperparameter jobs.
+    These resources include compute instances and storage volumes to use in model training jobs
+    launched by hyperparameter tuning jobs.
+    """
+
+    def __init__(
+        self,
+        instance_count: Union[int, PipelineVariable] = None,
+        instance_type: Union[str, PipelineVariable] = None,
+        volume_size: Union[int, PipelineVariable] = 30,
+    ):
+        """Creates an ``InstanceConfig`` instance.
+
+        It takes instance configuration information for training
+        jobs that are created as the result of a hyperparameter tuning job.
+
+        Args:
+            instance_count (int or PipelineVariable): The number of compute instances of type
+                InstanceType to use. For distributed training, select a value greater than 1.
+            instance_type (str or PipelineVariable):
+                The instance type used to run hyperparameter optimization tuning jobs.
+            volume_size (int or PipelineVariable): The volume size in GB of the data to be
+                processed for hyperparameter optimization.
+        """
+        self.instance_count = instance_count
+        self.instance_type = instance_type
+        self.volume_size = volume_size
+
+    @classmethod
+    def from_job_desc(cls, instance_config):
+        """Creates an ``InstanceConfig`` from an instance configuration response.
+
+        This is the instance configuration from the DescribeTuningJob response.
+
+        Args:
+            instance_config (dict): The instance configuration from the DescribeTuningJob
+                response, with the keys ``InstanceCount``, ``InstanceType``, and
+                ``VolumeSizeInGB``.
+
+        Returns:
+            sagemaker.tuner.InstanceConfig: De-serialized instance of
+            InstanceConfig containing the instance configuration.
+        """
+        return cls(
+            instance_count=instance_config["InstanceCount"],
+            instance_type=instance_config["InstanceType"],
+            volume_size=instance_config["VolumeSizeInGB"],
+        )
+
+    def to_input_req(self):
+        """Converts the ``self`` instance to the desired input request format.
+
+        Examples:
+            >>> instance_config = InstanceConfig(
+                    instance_count=1,
+                    instance_type='ml.m4.xlarge',
+                    volume_size=30,
+                )
+            >>> instance_config.to_input_req()
+            {
+                "InstanceCount": 1,
+                "InstanceType": "ml.m4.xlarge",
+                "VolumeSizeInGB": 30
+            }
+
+        Returns:
+            dict: Containing the instance configurations.
+        """
+        return {
+            "InstanceCount": self.instance_count,
+            "InstanceType": self.instance_type,
+            "VolumeSizeInGB": self.volume_size,
+        }
+
+
 class HyperparameterTuner(object):
     """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
 
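The new class round-trips between the SDK and the API wire format. A quick sketch, assuming `desc` mirrors one instance-config entry of a DescribeHyperParameterTuningJob response (values illustrative):

    from sagemaker.tuner import InstanceConfig

    desc = {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 30}
    config = InstanceConfig.from_job_desc(desc)
    assert config.to_input_req() == desc  # serializes back to the request shape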
@@ -482,14 +559,14 @@ def __init__(
             self.estimator = None
             self.objective_metric_name = None
             self._hyperparameter_ranges = None
+            self.static_hyperparameters = None
             self.metric_definitions = None
             self.estimator_dict = {estimator_name: estimator}
             self.objective_metric_name_dict = {estimator_name: objective_metric_name}
             self._hyperparameter_ranges_dict = {estimator_name: hyperparameter_ranges}
             self.metric_definitions_dict = (
                 {estimator_name: metric_definitions} if metric_definitions is not None else {}
             )
-            self.static_hyperparameters = None
         else:
             self.estimator = estimator
             self.objective_metric_name = objective_metric_name
@@ -521,6 +598,31 @@ def __init__(
         self.warm_start_config = warm_start_config
         self.early_stopping_type = early_stopping_type
         self.random_seed = random_seed
+        self.instance_configs_dict = None
+        self.instance_configs = None
+
+    def override_resource_config(
+        self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]
+    ):
+        """Override the instance configuration of the estimators used by the tuner.
+
+        Args:
+            instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]]):
+                The InstanceConfigs to use as an override for the instance configuration
+                of the estimator. ``None`` will remove the override.
+        """
+        if isinstance(instance_configs, dict):
+            self._validate_dict_argument(
+                name="instance_configs",
+                value=instance_configs,
+                allowed_keys=list(self.estimator_dict.keys()),
+            )
+            self.instance_configs_dict = instance_configs
+        else:
+            self.instance_configs = instance_configs
+            if self.estimator_dict is not None and self.estimator_dict.keys():
+                estimator_names = list(self.estimator_dict.keys())
+                self.instance_configs_dict = {estimator_names[0]: instance_configs}
 
     def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
         """Prepare the tuner instance for tuning (fit)."""
@@ -589,7 +691,6 @@ def _prepare_job_name_for_tuning(self, job_name=None):
 
     def _prepare_static_hyperparameters_for_tuning(self, include_cls_metadata=False):
         """Prepare static hyperparameters for all estimators before tuning."""
-        self.static_hyperparameters = None
         if self.estimator is not None:
             self.static_hyperparameters = self._prepare_static_hyperparameters(
                 self.estimator, self._hyperparameter_ranges, include_cls_metadata
@@ -1817,6 +1918,7 @@ def _get_tuner_args(cls, tuner, inputs):
             estimator=tuner.estimator,
             static_hyperparameters=tuner.static_hyperparameters,
             metric_definitions=tuner.metric_definitions,
+            instance_configs=tuner.instance_configs,
         )
 
         if tuner.estimator_dict is not None:
@@ -1830,12 +1932,44 @@ def _get_tuner_args(cls, tuner, inputs):
                 tuner.objective_type,
                 tuner.objective_metric_name_dict[estimator_name],
                 tuner.hyperparameter_ranges_dict()[estimator_name],
+                tuner.instance_configs_dict.get(estimator_name, None)
+                if tuner.instance_configs_dict is not None
+                else None,
             )
             for estimator_name in sorted(tuner.estimator_dict.keys())
         ]
 
         return tuner_args
 
+    @staticmethod
+    def _prepare_hp_resource_config(
+        instance_configs: List[InstanceConfig],
+        instance_count: int,
+        instance_type: str,
+        volume_size: int,
+        volume_kms_key: str,
+    ):
+        """Prepare the HPO resource config for one estimator of the tuner."""
+        resource_config = {}
+        if volume_kms_key is not None:
+            resource_config["VolumeKmsKeyId"] = volume_kms_key
+
+        if instance_configs is None:
+            resource_config["InstanceCount"] = instance_count
+            resource_config["InstanceType"] = instance_type
+            resource_config["VolumeSizeInGB"] = volume_size
+        else:
+            resource_config["InstanceConfigs"] = _TuningJob._prepare_instance_configs(
+                instance_configs
+            )
+
+        return resource_config
+
+    @staticmethod
+    def _prepare_instance_configs(instance_configs: List[InstanceConfig]):
+        """Prepare instance configs for the create tuning request."""
+        return [config.to_input_req() for config in instance_configs]
+
     @staticmethod
     def _prepare_training_config(
         inputs,
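Depending on whether overrides were supplied, `_prepare_hp_resource_config` emits one of two payload shapes; roughly (illustrative values):

    # instance_configs is None: fall back to the estimator's own settings.
    {"InstanceCount": 1, "InstanceType": "ml.c4.xlarge", "VolumeSizeInGB": 30}

    # instance_configs supplied: serialize each InstanceConfig via to_input_req().
    {"InstanceConfigs": [{"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 30}]}

In both cases "VolumeKmsKeyId" is added first when the estimator has a volume_kms_key set.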
@@ -1846,10 +1980,20 @@ def _prepare_training_config(
         objective_type=None,
         objective_metric_name=None,
         parameter_ranges=None,
+        instance_configs=None,
     ):
         """Prepare training config for one estimator."""
         training_config = _Job._load_config(inputs, estimator)
 
+        del training_config["resource_config"]
+        training_config["hpo_resource_config"] = _TuningJob._prepare_hp_resource_config(
+            instance_configs,
+            estimator.instance_count,
+            estimator.instance_type,
+            estimator.volume_size,
+            estimator.volume_kms_key,
+        )
+
         training_config["input_mode"] = estimator.input_mode
         training_config["metric_definitions"] = metric_definitions
 
tests/integ/test_tuner.py

Lines changed: 24 additions & 1 deletion
@@ -35,6 +35,7 @@
     ContinuousParameter,
     CategoricalParameter,
     HyperparameterTuner,
+    InstanceConfig,
     WarmStartConfig,
     WarmStartTypes,
     create_transfer_learning_tuner,
@@ -97,6 +98,7 @@ def _tune_and_deploy(
    job_name=None,
    warm_start_config=None,
    early_stopping_type="Off",
+    instance_configs=None,
 ):
     tuner = _tune(
         kmeans_estimator,
@@ -105,6 +107,7 @@ def _tune_and_deploy(
         warm_start_config=warm_start_config,
         job_name=job_name,
         early_stopping_type=early_stopping_type,
+        instance_configs=instance_configs,
     )
     _deploy(kmeans_train_set, sagemaker_session, tuner, early_stopping_type, cpu_instance_type)
 
@@ -134,6 +137,7 @@ def _tune(
    max_jobs=2,
    max_parallel_jobs=2,
    early_stopping_type="Off",
+    instance_configs=None,
 ):
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
 
@@ -148,7 +152,7 @@ def _tune(
            warm_start_config=warm_start_config,
            early_stopping_type=early_stopping_type,
        )
-
+        tuner.override_resource_config(instance_configs=instance_configs)
        records = kmeans_estimator.record_set(kmeans_train_set[0][:100])
        test_record_set = kmeans_estimator.record_set(kmeans_train_set[0][:100], channel="test")
 
@@ -173,6 +177,25 @@ def test_tuning_kmeans(
     )
 
 
+@pytest.mark.release
+def test_tuning_kmeans_with_instance_configs(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges, cpu_instance_type
+):
+    job_name = unique_name_from_base("tst-fit")
+    _tune_and_deploy(
+        kmeans_estimator,
+        kmeans_train_set,
+        sagemaker_session,
+        cpu_instance_type,
+        hyperparameter_ranges=hyperparameter_ranges,
+        job_name=job_name,
+        instance_configs=[
+            InstanceConfig(instance_count=1, instance_type="ml.m4.2xlarge", volume_size=30),
+            InstanceConfig(instance_count=1, instance_type="ml.m4.xlarge", volume_size=30),
+        ],
+    )
+
+
 def test_tuning_kmeans_identical_dataset_algorithm_tuner_raw(
     sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
 ):

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 3 additions & 3 deletions
@@ -1133,7 +1133,7 @@ def test_single_algo_tuning_step(sagemaker_session):
             },
             "RoleArn": "DummyRole",
             "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-            "ResourceConfig": {
+            "HyperParameterTuningResourceConfig": {
                 "InstanceCount": 1,
                 "InstanceType": "ml.c5.4xlarge",
                 "VolumeSizeInGB": 30,
@@ -1285,7 +1285,7 @@ def test_multi_algo_tuning_step(sagemaker_session):
             },
             "RoleArn": "DummyRole",
             "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-            "ResourceConfig": {
+            "HyperParameterTuningResourceConfig": {
                 "InstanceCount": {"Get": "Parameters.InstanceCount"},
                 "InstanceType": "ml.c5.4xlarge",
                 "VolumeSizeInGB": 30,
@@ -1352,7 +1352,7 @@ def test_multi_algo_tuning_step(sagemaker_session):
             },
             "RoleArn": "DummyRole",
             "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
-            "ResourceConfig": {
+            "HyperParameterTuningResourceConfig": {
                 "InstanceCount": {"Get": "Parameters.InstanceCount"},
                 "InstanceType": "ml.c5.4xlarge",
                 "VolumeSizeInGB": 30,
