Skip to content

Commit 8110cc4

Browse files
xxycrab, Xinyu Xie, and evakravi
authored and committed
feature: support remote debug for sagemaker training job (aws#4315)
* feature: support remote debug for sagemaker training job
* change: Replace update_remote_config with 2 helper methods for enable and disable respectively
* change: add new argument enable_remote_debug to skip set of test_jumpstart_estimator_kwargs_match_parent_class
* chore: add jumpstart support for remote debug

---------

Co-authored-by: Xinyu Xie <[email protected]>
Co-authored-by: Evan Kravitz <[email protected]>
1 parent f8629d7 commit 8110cc4

File tree

7 files changed

+208
-3
lines changed

7 files changed

+208
-3
lines changed

src/sagemaker/estimator.py

+63-3
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def __init__(
178178
container_entry_point: Optional[List[str]] = None,
179179
container_arguments: Optional[List[str]] = None,
180180
disable_output_compression: bool = False,
181+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
181182
**kwargs,
182183
):
183184
"""Initialize an ``EstimatorBase`` instance.
@@ -540,6 +541,8 @@ def __init__(
540541
to Amazon S3 without compression after training finishes.
541542
enable_infra_check (bool or PipelineVariable): Optional.
542543
Specifies whether it is running Sagemaker built-in infra check jobs.
544+
enable_remote_debug (bool or PipelineVariable): Optional.
545+
Specifies whether RemoteDebug is enabled for the training job
543546
"""
544547
instance_count = renamed_kwargs(
545548
"train_instance_count", "instance_count", instance_count, kwargs
@@ -777,6 +780,8 @@ def __init__(
777780

778781
self.tensorboard_app = TensorBoardApp(region=self.sagemaker_session.boto_region_name)
779782

783+
self._enable_remote_debug = enable_remote_debug
784+
780785
@abstractmethod
781786
def training_image_uri(self):
782787
"""Return the Docker image to use for training.
@@ -1958,6 +1963,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19581963
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
19591964
if max_wait:
19601965
init_params["max_wait"] = max_wait
1966+
1967+
if "RemoteDebugConfig" in job_details:
1968+
init_params["enable_remote_debug"] = job_details["RemoteDebugConfig"].get(
1969+
"EnableRemoteDebug"
1970+
)
19611971
return init_params
19621972

19631973
def _get_instance_type(self):
@@ -2292,6 +2302,32 @@ def update_profiler(
22922302

22932303
_TrainingJob.update(self, profiler_rule_configs, profiler_config_request_dict)
22942304

2305+
def get_remote_debug_config(self):
    """Return the RemoteDebug configuration held by this estimator.

    Returns:
        dict or None: ``{"EnableRemoteDebug": <value>}`` built from
        ``self._enable_remote_debug``, or ``None`` when that attribute is
        ``None`` (i.e. remote debug was never explicitly configured).
    """
    # ``None`` means "not specified", which is distinct from an explicit
    # ``False``; only an explicit value produces a config dict.
    if self._enable_remote_debug is None:
        return None
    return {"EnableRemoteDebug": self._enable_remote_debug}
2312+
2313+
def enable_remote_debug(self):
    """Enable remote debug for a training job.

    Delegates to ``self._update_remote_debug(True)``.
    """
    self._update_remote_debug(True)
2316+
2317+
def disable_remote_debug(self):
    """Disable remote debug for a training job.

    Delegates to ``self._update_remote_debug(False)``.
    """
    self._update_remote_debug(False)
2320+
2321+
def _update_remote_debug(self, enable_remote_debug: bool):
    """Update to enable or disable remote debug for a training job.

    This method updates the ``_enable_remote_debug`` parameter
    and enables or disables remote debug for a training job.

    Args:
        enable_remote_debug (bool): Value sent as ``EnableRemoteDebug`` in
            the ``remote_debug_config`` passed to ``_TrainingJob.update``.
    """
    self._ensure_latest_training_job()
    _TrainingJob.update(self, remote_debug_config={"EnableRemoteDebug": enable_remote_debug})
    # Record the new value only after the update call has returned, so an
    # exception raised by the update leaves the local flag untouched.
    self._enable_remote_debug = enable_remote_debug
2330+
22952331
def get_app_url(
22962332
self,
22972333
app_type,
@@ -2520,6 +2556,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25202556
if estimator.profiler_config:
25212557
train_args["profiler_config"] = estimator.profiler_config._to_request_dict()
25222558

2559+
if estimator.get_remote_debug_config() is not None:
2560+
train_args["remote_debug_config"] = estimator.get_remote_debug_config()
2561+
25232562
return train_args
25242563

25252564
@classmethod
@@ -2549,7 +2588,12 @@ def _is_local_channel(cls, input_uri):
25492588

25502589
@classmethod
25512590
def update(
2552-
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
2591+
cls,
2592+
estimator,
2593+
profiler_rule_configs=None,
2594+
profiler_config=None,
2595+
resource_config=None,
2596+
remote_debug_config=None,
25532597
):
25542598
"""Update a running Amazon SageMaker training job.
25552599
@@ -2562,20 +2606,31 @@ def update(
25622606
resource_config (dict): Configuration of the resources for the training job. You can
25632607
update the keep-alive period if the warm pool status is `Available`. No other fields
25642608
can be updated. (default: None).
2609+
remote_debug_config (dict): Configuration for RemoteDebug. (default: ``None``)
2610+
The dict can contain 'EnableRemoteDebug'(bool).
2611+
For example,
2612+
2613+
.. code:: python
2614+
2615+
remote_debug_config = {
2616+
"EnableRemoteDebug": True,
2617+
} (default: None).
25652618
25662619
Returns:
25672620
sagemaker.estimator._TrainingJob: Constructed object that captures
25682621
all information about the updated training job.
25692622
"""
25702623
update_args = cls._get_update_args(
2571-
estimator, profiler_rule_configs, profiler_config, resource_config
2624+
estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config
25722625
)
25732626
estimator.sagemaker_session.update_training_job(**update_args)
25742627

25752628
return estimator.latest_training_job
25762629

25772630
@classmethod
2578-
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, resource_config):
2631+
def _get_update_args(
2632+
cls, estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config
2633+
):
25792634
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.
25802635
25812636
Args:
@@ -2596,6 +2651,7 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, res
25962651
update_args.update(build_dict("profiler_rule_configs", profiler_rule_configs))
25972652
update_args.update(build_dict("profiler_config", profiler_config))
25982653
update_args.update(build_dict("resource_config", resource_config))
2654+
update_args.update(build_dict("remote_debug_config", remote_debug_config))
25992655

26002656
return update_args
26012657

@@ -2694,6 +2750,7 @@ def __init__(
26942750
container_arguments: Optional[List[str]] = None,
26952751
disable_output_compression: bool = False,
26962752
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
2753+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
26972754
**kwargs,
26982755
):
26992756
"""Initialize an ``Estimator`` instance.
@@ -3055,6 +3112,8 @@ def __init__(
30553112
to Amazon S3 without compression after training finishes.
30563113
enable_infra_check (bool or PipelineVariable): Optional.
30573114
Specifies whether it is running Sagemaker built-in infra check jobs.
3115+
enable_remote_debug (bool or PipelineVariable): Optional.
3116+
Specifies whether RemoteDebug is enabled for the training job
30583117
"""
30593118
self.image_uri = image_uri
30603119
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -3106,6 +3165,7 @@ def __init__(
31063165
container_entry_point=container_entry_point,
31073166
container_arguments=container_arguments,
31083167
disable_output_compression=disable_output_compression,
3168+
enable_remote_debug=enable_remote_debug,
31093169
**kwargs,
31103170
)
31113171

src/sagemaker/jumpstart/estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def __init__(
106106
container_entry_point: Optional[List[str]] = None,
107107
container_arguments: Optional[List[str]] = None,
108108
disable_output_compression: Optional[bool] = None,
109+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
109110
):
110111
"""Initializes a ``JumpStartEstimator``.
111112
@@ -495,6 +496,8 @@ def __init__(
495496
a training job.
496497
disable_output_compression (Optional[bool]): When set to true, Model is uploaded
497498
to Amazon S3 without compression after training finishes.
499+
enable_remote_debug (bool or PipelineVariable): Optional.
500+
Specifies whether RemoteDebug is enabled for the training job
498501
499502
Raises:
500503
ValueError: If the model ID is not recognized by JumpStart.
@@ -569,6 +572,7 @@ def _is_valid_model_id_hook():
569572
container_arguments=container_arguments,
570573
disable_output_compression=disable_output_compression,
571574
enable_infra_check=enable_infra_check,
575+
enable_remote_debug=enable_remote_debug,
572576
)
573577

574578
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def get_init_kwargs(
127127
container_arguments: Optional[List[str]] = None,
128128
disable_output_compression: Optional[bool] = None,
129129
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
130+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
130131
) -> JumpStartEstimatorInitKwargs:
131132
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
132133

@@ -183,6 +184,7 @@ def get_init_kwargs(
183184
container_arguments=container_arguments,
184185
disable_output_compression=disable_output_compression,
185186
enable_infra_check=enable_infra_check,
187+
enable_remote_debug=enable_remote_debug,
186188
)
187189

188190
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
12801280
"container_arguments",
12811281
"disable_output_compression",
12821282
"enable_infra_check",
1283+
"enable_remote_debug",
12831284
]
12841285

12851286
SERIALIZATION_EXCLUSION_SET = {
@@ -1344,6 +1345,7 @@ def __init__(
13441345
container_arguments: Optional[List[str]] = None,
13451346
disable_output_compression: Optional[bool] = None,
13461347
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
1348+
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
13471349
) -> None:
13481350
"""Instantiates JumpStartEstimatorInitKwargs object."""
13491351

@@ -1401,6 +1403,7 @@ def __init__(
14011403
self.container_arguments = container_arguments
14021404
self.disable_output_compression = disable_output_compression
14031405
self.enable_infra_check = enable_infra_check
1406+
self.enable_remote_debug = enable_remote_debug
14041407

14051408

14061409
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

src/sagemaker/session.py

+48
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,7 @@ def train( # noqa: C901
748748
profiler_config=None,
749749
environment: Optional[Dict[str, str]] = None,
750750
retry_strategy=None,
751+
remote_debug_config=None,
751752
):
752753
"""Create an Amazon SageMaker training job.
753754
@@ -858,6 +859,15 @@ def train( # noqa: C901
858859
configurations.src/sagemaker/lineage/artifact.py:285
859860
profiler_config (dict): Configuration for how profiling information is emitted
860861
with SageMaker Profiler. (default: ``None``).
862+
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
863+
The dict can contain 'EnableRemoteDebug'(bool).
864+
For example,
865+
866+
.. code:: python
867+
868+
remote_debug_config = {
869+
"EnableRemoteDebug": True,
870+
}
861871
environment (dict[str, str]) : Environment variables to be set for
862872
use during training job (default: ``None``)
863873
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -950,6 +960,7 @@ def train( # noqa: C901
950960
enable_sagemaker_metrics=enable_sagemaker_metrics,
951961
profiler_rule_configs=profiler_rule_configs,
952962
profiler_config=inferred_profiler_config,
963+
remote_debug_config=remote_debug_config,
953964
environment=environment,
954965
retry_strategy=retry_strategy,
955966
)
@@ -992,6 +1003,7 @@ def _get_train_request( # noqa: C901
9921003
enable_sagemaker_metrics=None,
9931004
profiler_rule_configs=None,
9941005
profiler_config=None,
1006+
remote_debug_config=None,
9951007
environment=None,
9961008
retry_strategy=None,
9971009
):
@@ -1103,6 +1115,15 @@ def _get_train_request( # noqa: C901
11031115
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
11041116
profiler_config(dict): Configuration for how profiling information is emitted with
11051117
SageMaker Profiler. (default: ``None``).
1118+
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
1119+
The dict can contain 'EnableRemoteDebug'(bool).
1120+
For example,
1121+
1122+
.. code:: python
1123+
1124+
remote_debug_config = {
1125+
"EnableRemoteDebug": True,
1126+
}
11061127
environment (dict[str, str]) : Environment variables to be set for
11071128
use during training job (default: ``None``)
11081129
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -1206,6 +1227,9 @@ def _get_train_request( # noqa: C901
12061227
if profiler_config is not None:
12071228
train_request["ProfilerConfig"] = profiler_config
12081229

1230+
if remote_debug_config is not None:
1231+
train_request["RemoteDebugConfig"] = remote_debug_config
1232+
12091233
if retry_strategy is not None:
12101234
train_request["RetryStrategy"] = retry_strategy
12111235

@@ -1217,6 +1241,7 @@ def update_training_job(
12171241
profiler_rule_configs=None,
12181242
profiler_config=None,
12191243
resource_config=None,
1244+
remote_debug_config=None,
12201245
):
12211246
"""Calls the UpdateTrainingJob API for the given job name and returns the response.
12221247
@@ -1228,6 +1253,15 @@ def update_training_job(
12281253
resource_config (dict): Configuration of the resources for the training job. You can
12291254
update the keep-alive period if the warm pool status is `Available`. No other fields
12301255
can be updated. (default: ``None``).
1256+
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
1257+
The dict can contain 'EnableRemoteDebug'(bool).
1258+
For example,
1259+
1260+
.. code:: python
1261+
1262+
remote_debug_config = {
1263+
"EnableRemoteDebug": True,
1264+
}
12311265
"""
12321266
# No injections from sagemaker_config because the UpdateTrainingJob API's resource_config
12331267
# object accepts fewer parameters than the CreateTrainingJob API, and none that the
@@ -1240,6 +1274,7 @@ def update_training_job(
12401274
profiler_rule_configs=profiler_rule_configs,
12411275
profiler_config=inferred_profiler_config,
12421276
resource_config=resource_config,
1277+
remote_debug_config=remote_debug_config,
12431278
)
12441279
LOGGER.info("Updating training job with name %s", job_name)
12451280
LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4))
@@ -1251,6 +1286,7 @@ def _get_update_training_job_request(
12511286
profiler_rule_configs=None,
12521287
profiler_config=None,
12531288
resource_config=None,
1289+
remote_debug_config=None,
12541290
):
12551291
"""Constructs a request compatible for updating an Amazon SageMaker training job.
12561292
@@ -1262,6 +1298,15 @@ def _get_update_training_job_request(
12621298
resource_config (dict): Configuration of the resources for the training job. You can
12631299
update the keep-alive period if the warm pool status is `Available`. No other fields
12641300
can be updated. (default: ``None``).
1301+
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
1302+
The dict can contain 'EnableRemoteDebug'(bool).
1303+
For example,
1304+
1305+
.. code:: python
1306+
1307+
remote_debug_config = {
1308+
"EnableRemoteDebug": True,
1309+
}
12651310
12661311
Returns:
12671312
Dict: an update training request dict
@@ -1279,6 +1324,9 @@ def _get_update_training_job_request(
12791324
if resource_config is not None:
12801325
update_training_job_request["ResourceConfig"] = resource_config
12811326

1327+
if remote_debug_config is not None:
1328+
update_training_job_request["RemoteDebugConfig"] = remote_debug_config
1329+
12821330
return update_training_job_request
12831331

12841332
def process(

0 commit comments

Comments
 (0)