@@ -116,6 +116,7 @@ def __init__(
116
116
role : str ,
117
117
instance_count : Optional [Union [int , PipelineVariable ]] = None ,
118
118
instance_type : Optional [Union [str , PipelineVariable ]] = None ,
119
+ keep_alive_period_in_seconds : Optional [Union [int , PipelineVariable ]] = None ,
119
120
volume_size : Union [int , PipelineVariable ] = 30 ,
120
121
volume_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
121
122
max_run : Union [int , PipelineVariable ] = 24 * 60 * 60 ,
@@ -167,6 +168,9 @@ def __init__(
167
168
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
168
169
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
169
170
not set.
171
+ keep_alive_period_in_seconds (int or PipelineVariable): The duration of time in seconds
172
+ to retain configured resources in a warm pool for subsequent
173
+ training jobs (default: None).
170
174
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
171
175
storing input and output data during training (default: 30).
172
176
@@ -510,6 +514,7 @@ def __init__(
510
514
self .role = role
511
515
self .instance_count = instance_count
512
516
self .instance_type = instance_type
517
+ self .keep_alive_period_in_seconds = keep_alive_period_in_seconds
513
518
self .instance_groups = instance_groups
514
519
self .volume_size = volume_size
515
520
self .volume_kms_key = volume_kms_key
@@ -1578,6 +1583,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
1578
1583
if "EnableNetworkIsolation" in job_details :
1579
1584
init_params ["enable_network_isolation" ] = job_details ["EnableNetworkIsolation" ]
1580
1585
1586
+ if "KeepAlivePeriodInSeconds" in job_details ["ResourceConfig" ]:
1587
+ init_params ["keep_alive_period_in_seconds" ] = job_details ["ResourceConfig" ][
1588
+ "KeepAlivePeriodInSeconds"
1589
+ ]
1590
+
1581
1591
has_hps = "HyperParameters" in job_details
1582
1592
init_params ["hyperparameters" ] = job_details ["HyperParameters" ] if has_hps else {}
1583
1593
@@ -2126,7 +2136,9 @@ def _is_local_channel(cls, input_uri):
2126
2136
return isinstance (input_uri , string_types ) and input_uri .startswith ("file://" )
2127
2137
2128
2138
@classmethod
2129
- def update (cls , estimator , profiler_rule_configs = None , profiler_config = None ):
2139
+ def update (
2140
+ cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
2141
+ ):
2130
2142
"""Update a running Amazon SageMaker training job.
2131
2143
2132
2144
Args:
@@ -2135,18 +2147,23 @@ def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
2135
2147
updated in the training job. (default: None).
2136
2148
profiler_config (dict): Configuration for how profiling information is emitted with
2137
2149
SageMaker Debugger. (default: None).
2150
+ resource_config (dict): Configuration of the resources for the training job. You can
2151
+ update the keep-alive period if the warm pool status is `Available`. No other fields
2152
+ can be updated. (default: None).
2138
2153
2139
2154
Returns:
2140
2155
sagemaker.estimator._TrainingJob: Constructed object that captures
2141
2156
all information about the updated training job.
2142
2157
"""
2143
- update_args = cls ._get_update_args (estimator , profiler_rule_configs , profiler_config )
2158
+ update_args = cls ._get_update_args (
2159
+ estimator , profiler_rule_configs , profiler_config , resource_config
2160
+ )
2144
2161
estimator .sagemaker_session .update_training_job (** update_args )
2145
2162
2146
2163
return estimator .latest_training_job
2147
2164
2148
2165
@classmethod
2149
- def _get_update_args (cls , estimator , profiler_rule_configs , profiler_config ):
2166
+ def _get_update_args (cls , estimator , profiler_rule_configs , profiler_config , resource_config ):
2150
2167
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.
2151
2168
2152
2169
Args:
@@ -2156,13 +2173,17 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
2156
2173
updated in the training job. (default: None).
2157
2174
profiler_config (dict): Configuration for how profiling information is emitted with
2158
2175
SageMaker Debugger. (default: None).
2176
+ resource_config (dict): Configuration of the resources for the training job. You can
2177
+ update the keep-alive period if the warm pool status is `Available`. No other fields
2178
+ can be updated. (default: None).
2159
2179
2160
2180
Returns:
2161
2181
Dict: dict for `sagemaker.session.Session.update_training_job` method
2162
2182
"""
2163
2183
update_args = {"job_name" : estimator .latest_training_job .name }
2164
2184
update_args .update (build_dict ("profiler_rule_configs" , profiler_rule_configs ))
2165
2185
update_args .update (build_dict ("profiler_config" , profiler_config ))
2186
+ update_args .update (build_dict ("resource_config" , resource_config ))
2166
2187
2167
2188
return update_args
2168
2189
@@ -2218,6 +2239,7 @@ def __init__(
2218
2239
role : str ,
2219
2240
instance_count : Optional [Union [int , PipelineVariable ]] = None ,
2220
2241
instance_type : Optional [Union [str , PipelineVariable ]] = None ,
2242
+ keep_alive_period_in_seconds : Optional [Union [int , PipelineVariable ]] = None ,
2221
2243
volume_size : Union [int , PipelineVariable ] = 30 ,
2222
2244
volume_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
2223
2245
max_run : Union [int , PipelineVariable ] = 24 * 60 * 60 ,
@@ -2270,6 +2292,9 @@ def __init__(
2270
2292
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
2271
2293
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
2272
2294
not set.
2295
+ keep_alive_period_in_seconds (int or PipelineVariable): The duration of time in seconds
2296
+ to retain configured resources in a warm pool for subsequent
2297
+ training jobs (default: None).
2273
2298
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
2274
2299
storing input and output data during training (default: 30).
2275
2300
@@ -2591,6 +2616,7 @@ def __init__(
2591
2616
role ,
2592
2617
instance_count ,
2593
2618
instance_type ,
2619
+ keep_alive_period_in_seconds ,
2594
2620
volume_size ,
2595
2621
volume_kms_key ,
2596
2622
max_run ,
0 commit comments