@@ -116,6 +116,7 @@ def __init__(
116
116
role : str ,
117
117
instance_count : Optional [Union [int , PipelineVariable ]] = None ,
118
118
instance_type : Optional [Union [str , PipelineVariable ]] = None ,
119
+ keep_alive_period_in_seconds : Optional [Union [int , PipelineVariable ]] = None ,
119
120
volume_size : Union [int , PipelineVariable ] = 30 ,
120
121
volume_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
121
122
max_run : Union [int , PipelineVariable ] = 24 * 60 * 60 ,
@@ -167,6 +168,10 @@ def __init__(
167
168
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
168
169
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
169
170
not set.
171
+ keep_alive_period_in_seconds (int or PipelineVariable): The duration in seconds
172
+ to retain the resources (instances, volumes, ECR images, etc.) used
173
+ by this training job so that they can be reused by the next follow-up
174
+ training job (default: None).
170
175
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
171
176
storing input and output data during training (default: 30).
172
177
@@ -510,6 +515,7 @@ def __init__(
510
515
self .role = role
511
516
self .instance_count = instance_count
512
517
self .instance_type = instance_type
518
+ self .keep_alive_period_in_seconds = keep_alive_period_in_seconds
513
519
self .instance_groups = instance_groups
514
520
self .volume_size = volume_size
515
521
self .volume_kms_key = volume_kms_key
@@ -1578,6 +1584,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
1578
1584
if "EnableNetworkIsolation" in job_details :
1579
1585
init_params ["enable_network_isolation" ] = job_details ["EnableNetworkIsolation" ]
1580
1586
1587
+ if "KeepAlivePeriodInSeconds" in job_details ["ResourceConfig" ]:
1588
+ init_params ["keep_alive_period_in_seconds" ] = job_details ["ResourceConfig" ][
1589
+ "keepAlivePeriodInSeconds"
1590
+ ]
1591
+
1581
1592
has_hps = "HyperParameters" in job_details
1582
1593
init_params ["hyperparameters" ] = job_details ["HyperParameters" ] if has_hps else {}
1583
1594
@@ -2126,7 +2137,9 @@ def _is_local_channel(cls, input_uri):
2126
2137
return isinstance (input_uri , string_types ) and input_uri .startswith ("file://" )
2127
2138
2128
2139
@classmethod
2129
- def update (cls , estimator , profiler_rule_configs = None , profiler_config = None ):
2140
+ def update (
2141
+ cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
2142
+ ):
2130
2143
"""Update a running Amazon SageMaker training job.
2131
2144
2132
2145
Args:
@@ -2135,18 +2148,21 @@ def update(cls, estimator, profiler_rule_configs=None, profiler_config=None):
2135
2148
updated in the training job. (default: None).
2136
2149
profiler_config (dict): Configuration for how profiling information is emitted with
2137
2150
SageMaker Debugger. (default: None).
2151
+ resource_config (dict): Resource configuration for the training job. (default: None).
2138
2152
2139
2153
Returns:
2140
2154
sagemaker.estimator._TrainingJob: Constructed object that captures
2141
2155
all information about the updated training job.
2142
2156
"""
2143
- update_args = cls ._get_update_args (estimator , profiler_rule_configs , profiler_config )
2157
+ update_args = cls ._get_update_args (
2158
+ estimator , profiler_rule_configs , profiler_config , resource_config
2159
+ )
2144
2160
estimator .sagemaker_session .update_training_job (** update_args )
2145
2161
2146
2162
return estimator .latest_training_job
2147
2163
2148
2164
@classmethod
2149
- def _get_update_args (cls , estimator , profiler_rule_configs , profiler_config ):
2165
+ def _get_update_args (cls , estimator , profiler_rule_configs , profiler_config , resource_config ):
2150
2166
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.
2151
2167
2152
2168
Args:
@@ -2156,13 +2172,15 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config):
2156
2172
updated in the training job. (default: None).
2157
2173
profiler_config (dict): Configuration for how profiling information is emitted with
2158
2174
SageMaker Debugger. (default: None).
2175
+ resource_config (dict): Resource configuration for the training job. (default: None).
2159
2176
2160
2177
Returns:
2161
2178
Dict: dict for `sagemaker.session.Session.update_training_job` method
2162
2179
"""
2163
2180
update_args = {"job_name" : estimator .latest_training_job .name }
2164
2181
update_args .update (build_dict ("profiler_rule_configs" , profiler_rule_configs ))
2165
2182
update_args .update (build_dict ("profiler_config" , profiler_config ))
2183
+ update_args .update (build_dict ("resource_config" , resource_config ))
2166
2184
2167
2185
return update_args
2168
2186
@@ -2218,6 +2236,7 @@ def __init__(
2218
2236
role : str ,
2219
2237
instance_count : Optional [Union [int , PipelineVariable ]] = None ,
2220
2238
instance_type : Optional [Union [str , PipelineVariable ]] = None ,
2239
+ keep_alive_period_in_seconds : Optional [Union [int , PipelineVariable ]] = None ,
2221
2240
volume_size : Union [int , PipelineVariable ] = 30 ,
2222
2241
volume_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
2223
2242
max_run : Union [int , PipelineVariable ] = 24 * 60 * 60 ,
@@ -2270,6 +2289,10 @@ def __init__(
2270
2289
instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
2271
2290
for example, ``'ml.c4.xlarge'``. Required if instance_groups is
2272
2291
not set.
2292
+ keep_alive_period_in_seconds (int or PipelineVariable): The duration in seconds
2293
+ to retain the resources (instances, volumes, ECR images, etc.) used
2294
+ by this training job so that they can be reused by the next follow-up
2295
+ training job (default: None).
2273
2296
volume_size (int or PipelineVariable): Size in GB of the storage volume to use for
2274
2297
storing input and output data during training (default: 30).
2275
2298
@@ -2591,6 +2614,7 @@ def __init__(
2591
2614
role ,
2592
2615
instance_count ,
2593
2616
instance_type ,
2617
+ keep_alive_period_in_seconds ,
2594
2618
volume_size ,
2595
2619
volume_kms_key ,
2596
2620
max_run ,
0 commit comments