@@ -178,6 +178,7 @@ def __init__(
178
178
container_entry_point : Optional [List [str ]] = None ,
179
179
container_arguments : Optional [List [str ]] = None ,
180
180
disable_output_compression : bool = False ,
181
+ enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
181
182
** kwargs ,
182
183
):
183
184
"""Initialize an ``EstimatorBase`` instance.
@@ -540,6 +541,8 @@ def __init__(
540
541
to Amazon S3 without compression after training finishes.
541
542
enable_infra_check (bool or PipelineVariable): Optional.
542
543
Specifies whether it is running Sagemaker built-in infra check jobs.
544
+ enable_remote_debug (bool or PipelineVariable): Optional.
545
+ Specifies whether RemoteDebug is enabled for the training job
543
546
"""
544
547
instance_count = renamed_kwargs (
545
548
"train_instance_count" , "instance_count" , instance_count , kwargs
@@ -777,6 +780,8 @@ def __init__(
777
780
778
781
self .tensorboard_app = TensorBoardApp (region = self .sagemaker_session .boto_region_name )
779
782
783
+ self ._enable_remote_debug = enable_remote_debug
784
+
780
785
@abstractmethod
781
786
def training_image_uri (self ):
782
787
"""Return the Docker image to use for training.
@@ -1958,6 +1963,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
1958
1963
max_wait = job_details .get ("StoppingCondition" , {}).get ("MaxWaitTimeInSeconds" )
1959
1964
if max_wait :
1960
1965
init_params ["max_wait" ] = max_wait
1966
+
1967
+ if "RemoteDebugConfig" in job_details :
1968
+ init_params ["enable_remote_debug" ] = job_details ["RemoteDebugConfig" ].get (
1969
+ "EnableRemoteDebug"
1970
+ )
1961
1971
return init_params
1962
1972
1963
1973
def _get_instance_type (self ):
@@ -2292,6 +2302,32 @@ def update_profiler(
2292
2302
2293
2303
_TrainingJob .update (self , profiler_rule_configs , profiler_config_request_dict )
2294
2304
2305
+ def get_remote_debug_config (self ):
2306
+ """dict: Return the configuration of RemoteDebug"""
2307
+ return (
2308
+ None
2309
+ if self ._enable_remote_debug is None
2310
+ else {"EnableRemoteDebug" : self ._enable_remote_debug }
2311
+ )
2312
+
2313
+ def enable_remote_debug (self ):
2314
+ """Enable remote debug for a training job."""
2315
+ self ._update_remote_debug (True )
2316
+
2317
+ def disable_remote_debug (self ):
2318
+ """Disable remote debug for a training job."""
2319
+ self ._update_remote_debug (False )
2320
+
2321
+ def _update_remote_debug (self , enable_remote_debug : bool ):
2322
+ """Update to enable or disable remote debug for a training job.
2323
+
2324
+ This method updates the ``_enable_remote_debug`` parameter
2325
+ and enables or disables remote debug for a training job
2326
+ """
2327
+ self ._ensure_latest_training_job ()
2328
+ _TrainingJob .update (self , remote_debug_config = {"EnableRemoteDebug" : enable_remote_debug })
2329
+ self ._enable_remote_debug = enable_remote_debug
2330
+
2295
2331
def get_app_url (
2296
2332
self ,
2297
2333
app_type ,
@@ -2520,6 +2556,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
2520
2556
if estimator .profiler_config :
2521
2557
train_args ["profiler_config" ] = estimator .profiler_config ._to_request_dict ()
2522
2558
2559
+ if estimator .get_remote_debug_config () is not None :
2560
+ train_args ["remote_debug_config" ] = estimator .get_remote_debug_config ()
2561
+
2523
2562
return train_args
2524
2563
2525
2564
@classmethod
@@ -2549,7 +2588,12 @@ def _is_local_channel(cls, input_uri):
2549
2588
2550
2589
@classmethod
2551
2590
def update (
2552
- cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
2591
+ cls ,
2592
+ estimator ,
2593
+ profiler_rule_configs = None ,
2594
+ profiler_config = None ,
2595
+ resource_config = None ,
2596
+ remote_debug_config = None ,
2553
2597
):
2554
2598
"""Update a running Amazon SageMaker training job.
2555
2599
@@ -2562,20 +2606,31 @@ def update(
2562
2606
resource_config (dict): Configuration of the resources for the training job. You can
2563
2607
update the keep-alive period if the warm pool status is `Available`. No other fields
2564
2608
can be updated. (default: None).
2609
+ remote_debug_config (dict): Configuration for RemoteDebug. (default: ``None``)
2610
+ The dict can contain 'EnableRemoteDebug'(bool).
2611
+ For example,
2612
+
2613
+ .. code:: python
2614
+
2615
+ remote_debug_config = {
2616
+ "EnableRemoteDebug": True,
2617
+ } (default: None).
2565
2618
2566
2619
Returns:
2567
2620
sagemaker.estimator._TrainingJob: Constructed object that captures
2568
2621
all information about the updated training job.
2569
2622
"""
2570
2623
update_args = cls ._get_update_args (
2571
- estimator , profiler_rule_configs , profiler_config , resource_config
2624
+ estimator , profiler_rule_configs , profiler_config , resource_config , remote_debug_config
2572
2625
)
2573
2626
estimator .sagemaker_session .update_training_job (** update_args )
2574
2627
2575
2628
return estimator .latest_training_job
2576
2629
2577
2630
@classmethod
2578
- def _get_update_args (cls , estimator , profiler_rule_configs , profiler_config , resource_config ):
2631
+ def _get_update_args (
2632
+ cls , estimator , profiler_rule_configs , profiler_config , resource_config , remote_debug_config
2633
+ ):
2579
2634
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.
2580
2635
2581
2636
Args:
@@ -2596,6 +2651,7 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, res
2596
2651
update_args .update (build_dict ("profiler_rule_configs" , profiler_rule_configs ))
2597
2652
update_args .update (build_dict ("profiler_config" , profiler_config ))
2598
2653
update_args .update (build_dict ("resource_config" , resource_config ))
2654
+ update_args .update (build_dict ("remote_debug_config" , remote_debug_config ))
2599
2655
2600
2656
return update_args
2601
2657
@@ -2694,6 +2750,7 @@ def __init__(
2694
2750
container_arguments : Optional [List [str ]] = None ,
2695
2751
disable_output_compression : bool = False ,
2696
2752
enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
2753
+ enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
2697
2754
** kwargs ,
2698
2755
):
2699
2756
"""Initialize an ``Estimator`` instance.
@@ -3055,6 +3112,8 @@ def __init__(
3055
3112
to Amazon S3 without compression after training finishes.
3056
3113
enable_infra_check (bool or PipelineVariable): Optional.
3057
3114
Specifies whether it is running Sagemaker built-in infra check jobs.
3115
+ enable_remote_debug (bool or PipelineVariable): Optional.
3116
+ Specifies whether RemoteDebug is enabled for the training job
3058
3117
"""
3059
3118
self .image_uri = image_uri
3060
3119
self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -3106,6 +3165,7 @@ def __init__(
3106
3165
container_entry_point = container_entry_point ,
3107
3166
container_arguments = container_arguments ,
3108
3167
disable_output_compression = disable_output_compression ,
3168
+ enable_remote_debug = enable_remote_debug ,
3109
3169
** kwargs ,
3110
3170
)
3111
3171
0 commit comments