@@ -383,6 +383,38 @@ def to_input_req(self):
383
383
}
384
384
385
385
386
+ class InstanceConfig :
387
+ """Instance configuration for training jobs started by hyperparameter tuning.
388
+
389
+ Contains the configuration(s) for one or more resources for processing hyperparameter jobs.
390
+ These resources include compute instances and storage volumes to use in model training jobs
391
+ launched by hyperparameter tuning jobs.
392
+ """
393
+
394
+ def __init__ (
395
+ self ,
396
+ instance_count : Union [int , PipelineVariable ] = None ,
397
+ instance_type : Union [str , PipelineVariable ] = None ,
398
+ volume_size : Union [int , PipelineVariable ] = 30 ,
399
+ ):
400
+ """Creates a ``InstanceConfig`` instance.
401
+
402
+ It takes instance configuration information for training
403
+ jobs that are created as the result of a hyperparameter tuning job.
404
+
405
+ Args:
406
+ * instance_count (str or PipelineVariable): The number of compute instances of type
407
+ InstanceType to use. For distributed training, select a value greater than 1.
408
+ * instance_type (str or PipelineVariable):
409
+ The instance type used to run hyperparameter optimization tuning jobs.
410
+ * volume_size (int or PipelineVariable): The volume size in GB of the data to be
411
+ processed for hyperparameter optimization
412
+ """
413
+ self .instance_count = instance_count
414
+ self .instance_type = instance_type
415
+ self .volume_size = volume_size
416
+
417
+
386
418
class HyperparameterTuner (object ):
387
419
"""Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
388
420
@@ -489,7 +521,6 @@ def __init__(
489
521
self .metric_definitions_dict = (
490
522
{estimator_name : metric_definitions } if metric_definitions is not None else {}
491
523
)
492
- self .static_hyperparameters = None
493
524
else :
494
525
self .estimator = estimator
495
526
self .objective_metric_name = objective_metric_name
@@ -521,6 +552,31 @@ def __init__(
521
552
self .warm_start_config = warm_start_config
522
553
self .early_stopping_type = early_stopping_type
523
554
self .random_seed = random_seed
555
+ self .instance_configs_dict = None
556
+ self .instance_configs = None
557
+
558
+ def override_resource_config (
559
+ self , instance_configs : Union [List [InstanceConfig ], Dict [str , List [InstanceConfig ]]]
560
+ ):
561
+ """Override the instance configuration of the estimators used by the tuner.
562
+
563
+ Args:
564
+ instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]):
565
+ The InstanceConfigs to use as an override for the instance configuration
566
+ of the estimator. ``None`` will remove the override.
567
+ """
568
+ if isinstance (instance_configs , dict ):
569
+ self ._validate_dict_argument (
570
+ name = "instance_configs" ,
571
+ value = instance_configs ,
572
+ allowed_keys = list (self .estimator_dict .keys ()),
573
+ )
574
+ self .instance_configs_dict = instance_configs
575
+ else :
576
+ self .instance_configs = instance_configs
577
+ if self .estimator_dict is not None and self .estimator_dict .keys ():
578
+ estimator_names = list (self .estimator_dict .keys ())
579
+ self .instance_configs_dict = {estimator_names [0 ]: instance_configs }
524
580
525
581
def _prepare_for_tuning (self , job_name = None , include_cls_metadata = False ):
526
582
"""Prepare the tuner instance for tuning (fit)."""
@@ -1817,6 +1873,7 @@ def _get_tuner_args(cls, tuner, inputs):
1817
1873
estimator = tuner .estimator ,
1818
1874
static_hyperparameters = tuner .static_hyperparameters ,
1819
1875
metric_definitions = tuner .metric_definitions ,
1876
+ instance_configs = tuner .instance_configs ,
1820
1877
)
1821
1878
1822
1879
if tuner .estimator_dict is not None :
@@ -1830,12 +1887,51 @@ def _get_tuner_args(cls, tuner, inputs):
1830
1887
tuner .objective_type ,
1831
1888
tuner .objective_metric_name_dict [estimator_name ],
1832
1889
tuner .hyperparameter_ranges_dict ()[estimator_name ],
1890
+ tuner .instance_configs_dict .get (estimator_name , None )
1891
+ if tuner .instance_configs_dict is not None
1892
+ else None ,
1833
1893
)
1834
1894
for estimator_name in sorted (tuner .estimator_dict .keys ())
1835
1895
]
1836
1896
1837
1897
return tuner_args
1838
1898
1899
+ @staticmethod
1900
+ def _prepare_hp_resource_config (
1901
+ instance_configs : List [InstanceConfig ],
1902
+ instance_count : int ,
1903
+ instance_type : str ,
1904
+ volume_size : int ,
1905
+ volume_kms_key : str ,
1906
+ ):
1907
+ """Placeholder hpo resource config for one estimator of the tuner."""
1908
+ resource_config = {}
1909
+ if volume_kms_key is not None :
1910
+ resource_config ["VolumeKmsKeyId" ] = volume_kms_key
1911
+
1912
+ if instance_configs is None :
1913
+ resource_config ["InstanceCount" ] = instance_count
1914
+ resource_config ["InstanceType" ] = instance_type
1915
+ resource_config ["VolumeSizeInGB" ] = volume_size
1916
+ else :
1917
+ resource_config ["InstanceConfigs" ] = _TuningJob ._prepare_instance_configs (
1918
+ instance_configs
1919
+ )
1920
+
1921
+ return resource_config
1922
+
1923
+ @staticmethod
1924
+ def _prepare_instance_configs (instance_configs : List [InstanceConfig ]):
1925
+ """Prepare instance config for create tuning request."""
1926
+ return [
1927
+ InstanceConfig (
1928
+ instance_count = config .instance_count ,
1929
+ instance_type = config .instance_type ,
1930
+ volume_size = config .volume_size ,
1931
+ )
1932
+ for config in instance_configs
1933
+ ]
1934
+
1839
1935
@staticmethod
1840
1936
def _prepare_training_config (
1841
1937
inputs ,
@@ -1846,10 +1942,20 @@ def _prepare_training_config(
1846
1942
objective_type = None ,
1847
1943
objective_metric_name = None ,
1848
1944
parameter_ranges = None ,
1945
+ instance_configs = None ,
1849
1946
):
1850
1947
"""Prepare training config for one estimator."""
1851
1948
training_config = _Job ._load_config (inputs , estimator )
1852
1949
1950
+ del training_config ["resource_config" ]
1951
+ training_config ["hpo_resource_config" ] = _TuningJob ._prepare_hp_resource_config (
1952
+ instance_configs ,
1953
+ estimator .instance_count ,
1954
+ estimator .instance_type ,
1955
+ estimator .volume_size ,
1956
+ estimator .volume_kms_key ,
1957
+ )
1958
+
1853
1959
training_config ["input_mode" ] = estimator .input_mode
1854
1960
training_config ["metric_definitions" ] = metric_definitions
1855
1961
0 commit comments