@@ -383,6 +383,83 @@ def to_input_req(self):
383
383
}
384
384
385
385
386
+ class InstanceConfig :
387
+ """Instance configuration for training jobs started by hyperparameter tuning.
388
+
389
+ Contains the configuration(s) for one or more resources for processing hyperparameter jobs.
390
+ These resources include compute instances and storage volumes to use in model training jobs
391
+ launched by hyperparameter tuning jobs.
392
+ """
393
+
394
+ def __init__ (
395
+ self ,
396
+ instance_count : Union [int , PipelineVariable ] = None ,
397
+ instance_type : Union [str , PipelineVariable ] = None ,
398
+ volume_size : Union [int , PipelineVariable ] = 30 ,
399
+ ):
400
+ """Creates a ``InstanceConfig`` instance.
401
+
402
+ It takes instance configuration information for training
403
+ jobs that are created as the result of a hyperparameter tuning job.
404
+
405
+ Args:
406
+ * instance_count (str or PipelineVariable): The number of compute instances of type
407
+ InstanceType to use. For distributed training, select a value greater than 1.
408
+ * instance_type (str or PipelineVariable):
409
+ The instance type used to run hyperparameter optimization tuning jobs.
410
+ * volume_size (int or PipelineVariable): The volume size in GB of the data to be
411
+ processed for hyperparameter optimization
412
+ """
413
+ self .instance_count = instance_count
414
+ self .instance_type = instance_type
415
+ self .volume_size = volume_size
416
+
417
+ @classmethod
418
+ def from_job_desc (cls , instance_config ):
419
+ """Creates a ``InstanceConfig`` from an instance configuration response.
420
+
421
+ This is the instance configuration from the DescribeTuningJob response.
422
+
423
+ Args:
424
+ instance_config (dict): The expected format of the
425
+ ``instance_config`` contains one first-class field
426
+
427
+ Returns:
428
+ sagemaker.tuner.InstanceConfig: De-serialized instance of
429
+ InstanceConfig containing the strategy configuration.
430
+ """
431
+ return cls (
432
+ instance_count = instance_config ["InstanceCount" ],
433
+ instance_type = instance_config [" InstanceType " ],
434
+ volume_size = instance_config ["VolumeSizeInGB" ],
435
+ )
436
+
437
+ def to_input_req (self ):
438
+ """Converts the ``self`` instance to the desired input request format.
439
+
440
+ Examples:
441
+ >>> strategy_config = InstanceConfig(
442
+ instance_count=1,
443
+ instance_type='ml.m4.xlarge',
444
+ volume_size=30
445
+ )
446
+ >>> strategy_config.to_input_req()
447
+ {
448
+ "InstanceCount":1,
449
+ "InstanceType":"ml.m4.xlarge",
450
+ "VolumeSizeInGB":30
451
+ }
452
+
453
+ Returns:
454
+ dict: Containing the instance configurations.
455
+ """
456
+ return {
457
+ "InstanceCount" : self .instance_count ,
458
+ "InstanceType" : self .instance_type ,
459
+ "VolumeSizeInGB" : self .volume_size ,
460
+ }
461
+
462
+
386
463
class HyperparameterTuner (object ):
387
464
"""Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
388
465
@@ -482,14 +559,14 @@ def __init__(
482
559
self .estimator = None
483
560
self .objective_metric_name = None
484
561
self ._hyperparameter_ranges = None
562
+ self .static_hyperparameters = None
485
563
self .metric_definitions = None
486
564
self .estimator_dict = {estimator_name : estimator }
487
565
self .objective_metric_name_dict = {estimator_name : objective_metric_name }
488
566
self ._hyperparameter_ranges_dict = {estimator_name : hyperparameter_ranges }
489
567
self .metric_definitions_dict = (
490
568
{estimator_name : metric_definitions } if metric_definitions is not None else {}
491
569
)
492
- self .static_hyperparameters = None
493
570
else :
494
571
self .estimator = estimator
495
572
self .objective_metric_name = objective_metric_name
@@ -521,6 +598,31 @@ def __init__(
521
598
self .warm_start_config = warm_start_config
522
599
self .early_stopping_type = early_stopping_type
523
600
self .random_seed = random_seed
601
+ self .instance_configs_dict = None
602
+ self .instance_configs = None
603
+
604
+ def override_resource_config (
605
+ self , instance_configs : Union [List [InstanceConfig ], Dict [str , List [InstanceConfig ]]]
606
+ ):
607
+ """Override the instance configuration of the estimators used by the tuner.
608
+
609
+ Args:
610
+ instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]):
611
+ The InstanceConfigs to use as an override for the instance configuration
612
+ of the estimator. ``None`` will remove the override.
613
+ """
614
+ if isinstance (instance_configs , dict ):
615
+ self ._validate_dict_argument (
616
+ name = "instance_configs" ,
617
+ value = instance_configs ,
618
+ allowed_keys = list (self .estimator_dict .keys ()),
619
+ )
620
+ self .instance_configs_dict = instance_configs
621
+ else :
622
+ self .instance_configs = instance_configs
623
+ if self .estimator_dict is not None and self .estimator_dict .keys ():
624
+ estimator_names = list (self .estimator_dict .keys ())
625
+ self .instance_configs_dict = {estimator_names [0 ]: instance_configs }
524
626
525
627
def _prepare_for_tuning (self , job_name = None , include_cls_metadata = False ):
526
628
"""Prepare the tuner instance for tuning (fit)."""
@@ -589,7 +691,6 @@ def _prepare_job_name_for_tuning(self, job_name=None):
589
691
590
692
def _prepare_static_hyperparameters_for_tuning (self , include_cls_metadata = False ):
591
693
"""Prepare static hyperparameters for all estimators before tuning."""
592
- self .static_hyperparameters = None
593
694
if self .estimator is not None :
594
695
self .static_hyperparameters = self ._prepare_static_hyperparameters (
595
696
self .estimator , self ._hyperparameter_ranges , include_cls_metadata
@@ -1817,6 +1918,7 @@ def _get_tuner_args(cls, tuner, inputs):
1817
1918
estimator = tuner .estimator ,
1818
1919
static_hyperparameters = tuner .static_hyperparameters ,
1819
1920
metric_definitions = tuner .metric_definitions ,
1921
+ instance_configs = tuner .instance_configs ,
1820
1922
)
1821
1923
1822
1924
if tuner .estimator_dict is not None :
@@ -1830,12 +1932,44 @@ def _get_tuner_args(cls, tuner, inputs):
1830
1932
tuner .objective_type ,
1831
1933
tuner .objective_metric_name_dict [estimator_name ],
1832
1934
tuner .hyperparameter_ranges_dict ()[estimator_name ],
1935
+ tuner .instance_configs_dict .get (estimator_name , None )
1936
+ if tuner .instance_configs_dict is not None
1937
+ else None ,
1833
1938
)
1834
1939
for estimator_name in sorted (tuner .estimator_dict .keys ())
1835
1940
]
1836
1941
1837
1942
return tuner_args
1838
1943
1944
+ @staticmethod
1945
+ def _prepare_hp_resource_config (
1946
+ instance_configs : List [InstanceConfig ],
1947
+ instance_count : int ,
1948
+ instance_type : str ,
1949
+ volume_size : int ,
1950
+ volume_kms_key : str ,
1951
+ ):
1952
+ """Placeholder hpo resource config for one estimator of the tuner."""
1953
+ resource_config = {}
1954
+ if volume_kms_key is not None :
1955
+ resource_config ["VolumeKmsKeyId" ] = volume_kms_key
1956
+
1957
+ if instance_configs is None :
1958
+ resource_config ["InstanceCount" ] = instance_count
1959
+ resource_config ["InstanceType" ] = instance_type
1960
+ resource_config ["VolumeSizeInGB" ] = volume_size
1961
+ else :
1962
+ resource_config ["InstanceConfigs" ] = _TuningJob ._prepare_instance_configs (
1963
+ instance_configs
1964
+ )
1965
+
1966
+ return resource_config
1967
+
1968
+ @staticmethod
1969
+ def _prepare_instance_configs (instance_configs : List [InstanceConfig ]):
1970
+ """Prepare instance config for create tuning request."""
1971
+ return [config .to_input_req () for config in instance_configs ]
1972
+
1839
1973
@staticmethod
1840
1974
def _prepare_training_config (
1841
1975
inputs ,
@@ -1846,10 +1980,20 @@ def _prepare_training_config(
1846
1980
objective_type = None ,
1847
1981
objective_metric_name = None ,
1848
1982
parameter_ranges = None ,
1983
+ instance_configs = None ,
1849
1984
):
1850
1985
"""Prepare training config for one estimator."""
1851
1986
training_config = _Job ._load_config (inputs , estimator )
1852
1987
1988
+ del training_config ["resource_config" ]
1989
+ training_config ["hpo_resource_config" ] = _TuningJob ._prepare_hp_resource_config (
1990
+ instance_configs ,
1991
+ estimator .instance_count ,
1992
+ estimator .instance_type ,
1993
+ estimator .volume_size ,
1994
+ estimator .volume_kms_key ,
1995
+ )
1996
+
1853
1997
training_config ["input_mode" ] = estimator .input_mode
1854
1998
training_config ["metric_definitions" ] = metric_definitions
1855
1999
0 commit comments