@@ -460,6 +460,116 @@ def to_input_req(self):
460
460
}
461
461
462
462
463
class TuningJobCompletionCriteriaConfig(object):
    """The configuration for a hyperparameter tuning job completion criteria."""

    def __init__(
        self,
        max_number_of_training_jobs_not_improving: int = None,
        complete_on_convergence: bool = None,
        target_objective_metric_value: float = None,
    ):
        """Creates a ``TuningJobCompletionCriteriaConfig`` with provided criteria.

        Args:
            max_number_of_training_jobs_not_improving (int): The number of training jobs that do not
                improve the best objective after which tuning job will stop.
            complete_on_convergence (bool): A flag to stop your hyperparameter tuning job if
                automatic model tuning (AMT) has detected that your model has converged as evaluated
                against your objective function.
            target_objective_metric_value (float): The value of the objective metric.
        """
        self.max_number_of_training_jobs_not_improving = max_number_of_training_jobs_not_improving
        self.complete_on_convergence = complete_on_convergence
        self.target_objective_metric_value = target_objective_metric_value

    @classmethod
    def from_job_desc(cls, completion_criteria_config):
        """Creates a ``TuningJobCompletionCriteriaConfig`` from a configuration response.

        This is the completion criteria configuration from the DescribeTuningJob response.

        Args:
            completion_criteria_config (dict): The expected format of the
                ``completion_criteria_config`` contains three first-class fields

        Returns:
            sagemaker.tuner.TuningJobCompletionCriteriaConfig: De-serialized instance of
            TuningJobCompletionCriteriaConfig containing the completion criteria.
        """
        complete_on_convergence = None
        if CONVERGENCE_DETECTED in completion_criteria_config:
            # Hoist the nested lookup; only a non-empty value is interpreted.
            convergence_value = completion_criteria_config[CONVERGENCE_DETECTED][
                COMPLETE_ON_CONVERGENCE_DETECTED
            ]
            if convergence_value:
                # The service reports the flag as the string "Enabled"/"Disabled".
                complete_on_convergence = convergence_value == "Enabled"

        max_number_of_training_jobs_not_improving = None
        if BEST_OBJECTIVE_NOT_IMPROVING in completion_criteria_config:
            if completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING][
                MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING
            ]:
                max_number_of_training_jobs_not_improving = completion_criteria_config[
                    BEST_OBJECTIVE_NOT_IMPROVING
                ][MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING]

        target_objective_metric_value = None
        if TARGET_OBJECTIVE_METRIC_VALUE in completion_criteria_config:
            target_objective_metric_value = completion_criteria_config[
                TARGET_OBJECTIVE_METRIC_VALUE
            ]

        return cls(
            max_number_of_training_jobs_not_improving=max_number_of_training_jobs_not_improving,
            complete_on_convergence=complete_on_convergence,
            target_objective_metric_value=target_objective_metric_value,
        )

    def to_input_req(self):
        """Converts the ``self`` instance to the desired input request format.

        Examples:
            >>> completion_criteria_config = TuningJobCompletionCriteriaConfig(
                    max_number_of_training_jobs_not_improving=5,
                    complete_on_convergence=True,
                    target_objective_metric_value=0.42,
                )
            >>> completion_criteria_config.to_input_req()
            {
                "BestObjectiveNotImproving": {
                    "MaxNumberOfTrainingJobsNotImproving": 5
                },
                "ConvergenceDetected": {
                    "CompleteOnConvergence": "Enabled",
                },
                "TargetObjectiveMetricValue": 0.42
            }

        Returns:
            dict: Containing the completion criteria configurations.
        """
        completion_criteria_config = {}
        if self.max_number_of_training_jobs_not_improving is not None:
            # BUGFIX: build the nested dict in one step; the previous code indexed
            # into a key that was never created, raising KeyError.
            completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING] = {
                MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING: (
                    self.max_number_of_training_jobs_not_improving
                )
            }

        if self.target_objective_metric_value is not None:
            completion_criteria_config[
                TARGET_OBJECTIVE_METRIC_VALUE
            ] = self.target_objective_metric_value

        if self.complete_on_convergence is not None:
            # BUGFIX: same missing-inner-dict issue as above; the service expects
            # the flag serialized as "Enabled"/"Disabled".
            completion_criteria_config[CONVERGENCE_DETECTED] = {
                COMPLETE_ON_CONVERGENCE_DETECTED: (
                    "Enabled" if self.complete_on_convergence else "Disabled"
                )
            }

        return completion_criteria_config
571
+
572
+
463
573
class HyperparameterTuner (object ):
464
574
"""Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
465
575
@@ -559,14 +669,14 @@ def __init__(
559
669
self .estimator = None
560
670
self .objective_metric_name = None
561
671
self ._hyperparameter_ranges = None
562
- self .static_hyperparameters = None
563
672
self .metric_definitions = None
564
673
self .estimator_dict = {estimator_name : estimator }
565
674
self .objective_metric_name_dict = {estimator_name : objective_metric_name }
566
675
self ._hyperparameter_ranges_dict = {estimator_name : hyperparameter_ranges }
567
676
self .metric_definitions_dict = (
568
677
{estimator_name : metric_definitions } if metric_definitions is not None else {}
569
678
)
679
+ self .static_hyperparameters = None
570
680
else :
571
681
self .estimator = estimator
572
682
self .objective_metric_name = objective_metric_name
@@ -598,31 +708,6 @@ def __init__(
598
708
self .warm_start_config = warm_start_config
599
709
self .early_stopping_type = early_stopping_type
600
710
self .random_seed = random_seed
601
- self .instance_configs_dict = None
602
- self .instance_configs = None
603
-
604
- def override_resource_config (
605
- self , instance_configs : Union [List [InstanceConfig ], Dict [str , List [InstanceConfig ]]]
606
- ):
607
- """Override the instance configuration of the estimators used by the tuner.
608
-
609
- Args:
610
- instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]):
611
- The InstanceConfigs to use as an override for the instance configuration
612
- of the estimator. ``None`` will remove the override.
613
- """
614
- if isinstance (instance_configs , dict ):
615
- self ._validate_dict_argument (
616
- name = "instance_configs" ,
617
- value = instance_configs ,
618
- allowed_keys = list (self .estimator_dict .keys ()),
619
- )
620
- self .instance_configs_dict = instance_configs
621
- else :
622
- self .instance_configs = instance_configs
623
- if self .estimator_dict is not None and self .estimator_dict .keys ():
624
- estimator_names = list (self .estimator_dict .keys ())
625
- self .instance_configs_dict = {estimator_names [0 ]: instance_configs }
626
711
627
712
def _prepare_for_tuning (self , job_name = None , include_cls_metadata = False ):
628
713
"""Prepare the tuner instance for tuning (fit)."""
@@ -691,6 +776,7 @@ def _prepare_job_name_for_tuning(self, job_name=None):
691
776
692
777
def _prepare_static_hyperparameters_for_tuning (self , include_cls_metadata = False ):
693
778
"""Prepare static hyperparameters for all estimators before tuning."""
779
+ self .static_hyperparameters = None
694
780
if self .estimator is not None :
695
781
self .static_hyperparameters = self ._prepare_static_hyperparameters (
696
782
self .estimator , self ._hyperparameter_ranges , include_cls_metadata
@@ -1918,7 +2004,6 @@ def _get_tuner_args(cls, tuner, inputs):
1918
2004
estimator = tuner .estimator ,
1919
2005
static_hyperparameters = tuner .static_hyperparameters ,
1920
2006
metric_definitions = tuner .metric_definitions ,
1921
- instance_configs = tuner .instance_configs ,
1922
2007
)
1923
2008
1924
2009
if tuner .estimator_dict is not None :
@@ -1932,44 +2017,12 @@ def _get_tuner_args(cls, tuner, inputs):
1932
2017
tuner .objective_type ,
1933
2018
tuner .objective_metric_name_dict [estimator_name ],
1934
2019
tuner .hyperparameter_ranges_dict ()[estimator_name ],
1935
- tuner .instance_configs_dict .get (estimator_name , None )
1936
- if tuner .instance_configs_dict is not None
1937
- else None ,
1938
2020
)
1939
2021
for estimator_name in sorted (tuner .estimator_dict .keys ())
1940
2022
]
1941
2023
1942
2024
return tuner_args
1943
2025
1944
- @staticmethod
1945
- def _prepare_hp_resource_config (
1946
- instance_configs : List [InstanceConfig ],
1947
- instance_count : int ,
1948
- instance_type : str ,
1949
- volume_size : int ,
1950
- volume_kms_key : str ,
1951
- ):
1952
- """Placeholder hpo resource config for one estimator of the tuner."""
1953
- resource_config = {}
1954
- if volume_kms_key is not None :
1955
- resource_config ["VolumeKmsKeyId" ] = volume_kms_key
1956
-
1957
- if instance_configs is None :
1958
- resource_config ["InstanceCount" ] = instance_count
1959
- resource_config ["InstanceType" ] = instance_type
1960
- resource_config ["VolumeSizeInGB" ] = volume_size
1961
- else :
1962
- resource_config ["InstanceConfigs" ] = _TuningJob ._prepare_instance_configs (
1963
- instance_configs
1964
- )
1965
-
1966
- return resource_config
1967
-
1968
- @staticmethod
1969
- def _prepare_instance_configs (instance_configs : List [InstanceConfig ]):
1970
- """Prepare instance config for create tuning request."""
1971
- return [config .to_input_req () for config in instance_configs ]
1972
-
1973
2026
@staticmethod
1974
2027
def _prepare_training_config (
1975
2028
inputs ,
@@ -1980,20 +2033,10 @@ def _prepare_training_config(
1980
2033
objective_type = None ,
1981
2034
objective_metric_name = None ,
1982
2035
parameter_ranges = None ,
1983
- instance_configs = None ,
1984
2036
):
1985
2037
"""Prepare training config for one estimator."""
1986
2038
training_config = _Job ._load_config (inputs , estimator )
1987
2039
1988
- del training_config ["resource_config" ]
1989
- training_config ["hpo_resource_config" ] = _TuningJob ._prepare_hp_resource_config (
1990
- instance_configs ,
1991
- estimator .instance_count ,
1992
- estimator .instance_type ,
1993
- estimator .volume_size ,
1994
- estimator .volume_kms_key ,
1995
- )
1996
-
1997
2040
training_config ["input_mode" ] = estimator .input_mode
1998
2041
training_config ["metric_definitions" ] = metric_definitions
1999
2042
0 commit comments