@@ -383,6 +383,37 @@ def to_input_req(self):
383
383
}
384
384
385
385
386
+ class InstanceConfig :
387
+ """Instance configuration for training jobs started by hyperparameter tuning.
388
+
389
+ Contains the configuration(s) for one or more resources for processing hyperparameter jobs.
390
+ These resources include compute instances and storage volumes to use in model training jobs
391
+ launched by hyperparameter tuning jobs.
392
+ """
393
+
394
+ def __init__ (
395
+ self ,
396
+ instance_count : Union [int , PipelineVariable ] = None ,
397
+ instance_type : Union [str , PipelineVariable ] = None ,
398
+ volume_size : Union [int , PipelineVariable ] = 30 ,
399
+ ):
400
+ """Creates a ``InstanceConfig`` instance.
401
+
402
+ It takes instance configuration information for training
403
+ jobs that are created as the result of a hyperparameter tuning job.
404
+ Args:
405
+ * instance_count (str or PipelineVariable): The number of compute instances of type
406
+ InstanceType to use. For distributed training, select a value greater than 1.
407
+ * instance_type (str or PipelineVariable):
408
+ The instance type used to run hyperparameter optimization tuning jobs.
409
+ * volume_size (int or PipelineVariable): The volume size in GB of the data to be
410
+ processed for hyperparameter optimization
411
+ """
412
+ self .instance_count = instance_count
413
+ self .instance_type = instance_type
414
+ self .volume_size = volume_size
415
+
416
+
386
417
class HyperparameterTuner (object ):
387
418
"""Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
388
419
@@ -419,7 +450,6 @@ def __init__(
419
450
420
451
It takes an estimator to obtain configuration information for training
421
452
jobs that are created as the result of a hyperparameter tuning job.
422
-
423
453
Args:
424
454
estimator (sagemaker.estimator.EstimatorBase): An estimator object
425
455
that has been initialized with the desired configuration. There
@@ -489,6 +519,7 @@ def __init__(
489
519
self .metric_definitions_dict = (
490
520
{estimator_name : metric_definitions } if metric_definitions is not None else {}
491
521
)
522
+ self .instance_configs_dict = {}
492
523
self .static_hyperparameters = None
493
524
else :
494
525
self .estimator = estimator
@@ -500,6 +531,7 @@ def __init__(
500
531
self ._hyperparameter_ranges_dict = None
501
532
self .metric_definitions_dict = None
502
533
self .static_hyperparameters_dict = None
534
+ self .instance_configs_dict = None
503
535
504
536
self ._validate_parameter_ranges (estimator , hyperparameter_ranges )
505
537
@@ -521,6 +553,30 @@ def __init__(
521
553
self .warm_start_config = warm_start_config
522
554
self .early_stopping_type = early_stopping_type
523
555
self .random_seed = random_seed
556
+ self .instance_configs = None
557
+
558
+ def override_resource_config (
559
+ self , instance_configs : Union [List [InstanceConfig ], Dict [str , List [InstanceConfig ]]]
560
+ ):
561
+ """Override the instance configuration of the estimators used by the tuner.
562
+
563
+ Args:
564
+ instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]):
565
+ The InstanceConfigs to use as an override for the instance configuration
566
+ of the estimator. ``None`` will remove the override.
567
+ """
568
+ if isinstance (instance_configs , dict ):
569
+ self ._validate_dict_argument (
570
+ name = "instance_configs" ,
571
+ value = instance_configs ,
572
+ allowed_keys = list (self .estimator_dict .keys ()),
573
+ )
574
+ self .instance_configs_dict = instance_configs
575
+ else :
576
+ self .instance_configs = instance_configs
577
+ if self .estimator_dict and self .estimator_dict .keys ():
578
+ estimator_names = list (self .estimator_dict .keys ())
579
+ self .instance_configs_dict = {estimator_names [0 ]: instance_configs }
524
580
525
581
def _prepare_for_tuning (self , job_name = None , include_cls_metadata = False ):
526
582
"""Prepare the tuner instance for tuning (fit)."""
@@ -1817,6 +1873,7 @@ def _get_tuner_args(cls, tuner, inputs):
1817
1873
estimator = tuner .estimator ,
1818
1874
static_hyperparameters = tuner .static_hyperparameters ,
1819
1875
metric_definitions = tuner .metric_definitions ,
1876
+ instance_configs = tuner .instance_configs ,
1820
1877
)
1821
1878
1822
1879
if tuner .estimator_dict is not None :
@@ -1830,12 +1887,49 @@ def _get_tuner_args(cls, tuner, inputs):
1830
1887
tuner .objective_type ,
1831
1888
tuner .objective_metric_name_dict [estimator_name ],
1832
1889
tuner .hyperparameter_ranges_dict ()[estimator_name ],
1890
+ tuner .instance_configs_dict .get (estimator_name , None ),
1833
1891
)
1834
1892
for estimator_name in sorted (tuner .estimator_dict .keys ())
1835
1893
]
1836
1894
1837
1895
return tuner_args
1838
1896
1897
+ @staticmethod
1898
+ def _prepare_hp_resource_config (
1899
+ instance_configs : List [InstanceConfig ],
1900
+ instance_count : int ,
1901
+ instance_type : str ,
1902
+ volume_size : int ,
1903
+ volume_kms_key : str ,
1904
+ ):
1905
+ """Placeholder hpo resource config for one estimator of the tuner."""
1906
+ resource_config = {}
1907
+ if volume_kms_key is not None :
1908
+ resource_config ["VolumeKmsKeyId" ] = volume_kms_key
1909
+
1910
+ if instance_configs is None :
1911
+ resource_config ["InstanceCount" ] = instance_count
1912
+ resource_config ["InstanceType" ] = instance_type
1913
+ resource_config ["VolumeSizeInGB" ] = volume_size
1914
+ else :
1915
+ resource_config ["InstanceConfigs" ] = _TuningJob ._prepare_instance_configs (
1916
+ instance_configs
1917
+ )
1918
+
1919
+ return resource_config
1920
+
1921
+ @staticmethod
1922
+ def _prepare_instance_configs (instance_configs ):
1923
+ """Prepare instance config for create tuning request."""
1924
+ return [
1925
+ {
1926
+ "InstanceCount" : config .instance_count ,
1927
+ "InstanceType" : config .instance_type ,
1928
+ "VolumeSizeInGB" : config .volume_size ,
1929
+ }
1930
+ for config in instance_configs
1931
+ ]
1932
+
1839
1933
@staticmethod
1840
1934
def _prepare_training_config (
1841
1935
inputs ,
@@ -1846,10 +1940,20 @@ def _prepare_training_config(
1846
1940
objective_type = None ,
1847
1941
objective_metric_name = None ,
1848
1942
parameter_ranges = None ,
1943
+ instance_configs = None ,
1849
1944
):
1850
1945
"""Prepare training config for one estimator."""
1851
1946
training_config = _Job ._load_config (inputs , estimator )
1852
1947
1948
+ del training_config ["resource_config" ]
1949
+ training_config ["hpo_resource_config" ] = _TuningJob ._prepare_hp_resource_config (
1950
+ instance_configs ,
1951
+ estimator .instance_count ,
1952
+ estimator .instance_type ,
1953
+ estimator .volume_size ,
1954
+ estimator .volume_kms_key ,
1955
+ )
1956
+
1853
1957
training_config ["input_mode" ] = estimator .input_mode
1854
1958
training_config ["metric_definitions" ] = metric_definitions
1855
1959
0 commit comments