@@ -413,6 +413,7 @@ def __init__(
         strategy_config: Optional[StrategyConfig] = None,
         early_stopping_type: Union[str, PipelineVariable] = "Off",
         estimator_name: Optional[str] = None,
+        random_seed: Optional[int] = None,
     ):
         """Creates a ``HyperparameterTuner`` instance.

@@ -470,6 +471,9 @@ def __init__(
             estimator_name (str): A unique name to identify an estimator within the
                 hyperparameter tuning job, when more than one estimator is used with
                 the same tuning job (default: None).
+            random_seed (int): An initial value used to initialize a pseudo-random number
+                generator. Setting a random seed allows the hyperparameter tuning search
+                strategies to produce more consistent configurations for the same tuning job.
         """
         if hyperparameter_ranges is None or len(hyperparameter_ranges) == 0:
             raise ValueError("Need to specify hyperparameter ranges")
@@ -516,6 +520,7 @@ def __init__(
         self.latest_tuning_job = None
         self.warm_start_config = warm_start_config
         self.early_stopping_type = early_stopping_type
+        self.random_seed = random_seed

     def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
         """Prepare the tuner instance for tuning (fit)."""
@@ -1222,6 +1227,9 @@ def _prepare_init_params_from_job_description(cls, job_details):
             "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]),
         }

+        if "RandomSeed" in tuning_config:
+            params["random_seed"] = tuning_config["RandomSeed"]
+
         if "HyperParameterTuningJobObjective" in tuning_config:
             params["objective_metric_name"] = tuning_config["HyperParameterTuningJobObjective"][
                 "MetricName"
@@ -1483,6 +1491,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
                     warm_start_type=warm_start_type, parents=all_parents
                 ),
                 early_stopping_type=self.early_stopping_type,
+                random_seed=self.random_seed,
             )

         if len(self.estimator_dict) > 1:
@@ -1508,6 +1517,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             max_parallel_jobs=self.max_parallel_jobs,
             warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, parents=all_parents),
             early_stopping_type=self.early_stopping_type,
+            random_seed=self.random_seed,
         )

     @classmethod
@@ -1526,6 +1536,7 @@ def create(
         tags=None,
         warm_start_config=None,
         early_stopping_type="Off",
+        random_seed=None,
     ):
         """Factory method to create a ``HyperparameterTuner`` instance.

@@ -1586,6 +1597,9 @@ def create(
                 Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping
                 will not be attempted. If set to 'Auto', early stopping of some training jobs may
                 happen, but is not guaranteed to.
+            random_seed (int): An initial value used to initialize a pseudo-random number
+                generator. Setting a random seed allows the hyperparameter tuning search
+                strategies to produce more consistent configurations for the same tuning job.

         Returns:
             sagemaker.tuner.HyperparameterTuner: a new ``HyperparameterTuner`` object that can
@@ -1624,6 +1638,7 @@ def create(
             tags=tags,
             warm_start_config=warm_start_config,
             early_stopping_type=early_stopping_type,
+            random_seed=random_seed,
         )

         for estimator_name in estimator_names[1:]:
@@ -1775,6 +1790,9 @@ def _get_tuner_args(cls, tuner, inputs):
             "early_stopping_type": tuner.early_stopping_type,
         }

+        if tuner.random_seed is not None:
+            tuning_config["random_seed"] = tuner.random_seed
+
         if tuner.strategy_config is not None:
             tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()

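Continuing the first sketch: launching the job forwards the seed, and per the guard above ``"random_seed"`` enters the tuning configuration only when it was explicitly set, so requests from existing callers are unchanged. The channel name and S3 URI are placeholders:

    # Hypothetical dataset location; any valid channel/S3 input works here.
    tuner.fit({"train": "s3://<bucket>/<prefix>/train"})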