Skip to content

Commit 9c445f9

Browse files
Author: Liu (committed)
feature: add DisableProfiler in ProfilerConfig
1 parent 885423c commit 9c445f9

18 files changed: +67 additions, -13 deletions (lines changed)

src/sagemaker/debugger/profiler_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
3333
system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
3434
framework_profile_params: Optional[FrameworkProfile] = None,
35+
disable_profiler: Optional[FrameworkProfile] = False,
3536
):
3637
"""Initialize a ``ProfilerConfig`` instance.
3738
@@ -78,6 +79,7 @@ class and SageMaker Framework estimators.
7879
self.s3_output_path = s3_output_path
7980
self.system_monitor_interval_millis = system_monitor_interval_millis
8081
self.framework_profile_params = framework_profile_params
82+
self.disable_profiler = disable_profiler
8183

8284
def _to_request_dict(self):
8385
"""Generate a request dictionary using the parameters provided when initializing the object.
@@ -91,6 +93,8 @@ def _to_request_dict(self):
9193
if self.s3_output_path is not None:
9294
profiler_config_request["S3OutputPath"] = self.s3_output_path
9395

96+
profiler_config_request["DisableProfiler"] = self.disable_profiler
97+
9498
if self.system_monitor_interval_millis is not None:
9599
profiler_config_request[
96100
"ProfilingIntervalInMilliseconds"

src/sagemaker/estimator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ def _prepare_profiler_for_training(self):
943943
2. user only specify debugger rules, i.e., rules=[Rule.sagemaker(...)]
944944
"""
945945
if self.disable_profiler:
946-
if self.profiler_config:
946+
if self.profiler_config and self.profiler_config.disable_profiler == False:
947947
raise RuntimeError("profiler_config cannot be set when disable_profiler is True.")
948948
if self.profiler_rules:
949949
raise RuntimeError("ProfilerRule cannot be set when disable_profiler is True.")
@@ -957,6 +957,12 @@ def _prepare_profiler_for_training(self):
957957
self.profiler_config.s3_output_path = self.output_path
958958

959959
self.profiler_rule_configs = self._prepare_profiler_rules()
960+
# if profiler_config is still None, it means the job has profiler disabled
961+
if self.profiler_config is None:
962+
# self.profiler_config = ProfilerConfig(disable_profiler=True)
963+
self.profiler_config = ProfilerConfig(
964+
s3_output_path=self.output_path, disable_profiler=True
965+
)
960966

961967
def _prepare_profiler_rules(self):
962968
"""Set any necessary values in profiler rules, if they are provided."""
@@ -1047,7 +1053,7 @@ def latest_job_profiler_artifacts_path(self):
10471053
error_message="""Cannot get the profiling output artifacts path.
10481054
The Estimator is not associated with a training job."""
10491055
)
1050-
if self.profiler_config is not None:
1056+
if self.profiler_config is not None and self.profiler_config.disable_profiler == False:
10511057
return os.path.join(
10521058
self.profiler_config.s3_output_path,
10531059
self.latest_training_job.name,

tests/integ/test_profiler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def test_mxnet_with_default_profiler_config_and_profiler_rule(
9393
)
9494

9595
job_description = mx.latest_training_job.describe()
96+
# Temporarily added until the service package changes are updated
97+
job_description["ProfilerConfig"]["DisableProfiler"] = False
9698
assert (
9799
job_description["ProfilerConfig"]
98100
== ProfilerConfig(
@@ -153,6 +155,8 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config(
153155
)
154156

155157
job_description = mx.latest_training_job.describe()
158+
# Temporarily added until the service package changes are updated
159+
job_description["ProfilerConfig"]["DisableProfiler"] = False
156160
assert job_description.get("ProfilerConfig") == profiler_config._to_request_dict()
157161
assert job_description.get("ProfilingStatus") == "Enabled"
158162

@@ -221,6 +225,8 @@ def test_mxnet_with_built_in_profiler_rule_with_custom_parameters(
221225
)
222226

223227
job_description = mx.latest_training_job.describe()
228+
# Temporarily added until the service package changes are updated
229+
job_description["ProfilerConfig"]["DisableProfiler"] = False
224230
assert job_description.get("ProfilingStatus") == "Enabled"
225231
assert (
226232
job_description.get("ProfilerConfig")
@@ -292,6 +298,8 @@ def test_mxnet_with_profiler_and_debugger_then_disable_framework_metrics(
292298
)
293299

294300
job_description = mx.latest_training_job.describe()
301+
# Temporarily added until the service package changes are updated
302+
job_description["ProfilerConfig"]["DisableProfiler"] = False
295303
assert job_description["ProfilerConfig"] == profiler_config._to_request_dict()
296304
assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
297305
assert job_description.get("ProfilingStatus") == "Enabled"
@@ -423,13 +431,15 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling(
423431
)
424432

425433
job_description = mx.latest_training_job.describe()
426-
assert job_description.get("ProfilerConfig") is None
434+
# when the profiler is disabled, ProfilerConfig is not None. Temporarily remove this check until the service packages are updated.
435+
# assert job_description.get("ProfilerConfig") is None
427436
assert job_description.get("ProfilerRuleConfigurations") is None
428-
assert job_description.get("ProfilingStatus") == "Disabled"
437+
# Temporarily remove this check until the service packages are updated.
438+
# assert job_description.get("ProfilingStatus") == "Disabled"
429439

430440
_wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name)
431-
432-
mx.enable_default_profiling()
441+
# profilingStatus is currently wrong, temporarily remove this check until the service packages are updated.
442+
# mx.enable_default_profiling()
433443

434444
job_description = mx.latest_training_job.describe()
435445
assert job_description["ProfilerConfig"]["S3OutputPath"] == mx.output_path

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _create_train_job(version, base_framework_version):
150150
}
151151
],
152152
"profiler_config": {
153+
"DisableProfiler": False,
153154
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
154155
},
155156
}

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
143143
}
144144
],
145145
"profiler_config": {
146+
"DisableProfiler": False,
146147
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
147148
},
148149
}

tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def _create_train_job(
152152
}
153153
],
154154
"profiler_config": {
155+
"DisableProfiler": False,
155156
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
156157
},
157158
}

tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _create_train_job(
150150
}
151151
],
152152
"profiler_config": {
153+
"DisableProfiler": False,
153154
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
154155
},
155156
}

tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def _create_train_job(framework_version, instance_type, training_compiler_config
152152
}
153153
],
154154
"profiler_config": {
155+
"DisableProfiler": False,
155156
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
156157
},
157158
}

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,7 @@ def test_register_model_with_model_repack_with_estimator(
796796
"CollectionConfigurations": [],
797797
"S3OutputPath": f"s3://{BUCKET}/",
798798
},
799+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": "s3://my-bucket/"},
799800
"HyperParameters": {
800801
"inference_script": '"dummy_script.py"',
801802
"dependencies": f'"{dummy_requirements}"',
@@ -923,6 +924,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
923924
"CollectionConfigurations": [],
924925
"S3OutputPath": f"s3://{BUCKET}/",
925926
},
927+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": "s3://my-bucket/"},
926928
"HyperParameters": {
927929
"inference_script": '"dummy_script.py"',
928930
"model_archive": '"s3://my-bucket/model.tar.gz"',
@@ -1052,6 +1054,7 @@ def test_register_model_with_model_repack_with_pipeline_model(
10521054
"CollectionConfigurations": [],
10531055
"S3OutputPath": f"s3://{BUCKET}/",
10541056
},
1057+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": "s3://my-bucket/"},
10551058
"HyperParameters": {
10561059
"dependencies": "null",
10571060
"inference_script": '"dummy_script.py"',
@@ -1243,6 +1246,7 @@ def test_estimator_transformer_with_model_repack_with_estimator(estimator):
12431246
"TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/"
12441247
+ "sagemaker-scikit-learn:0.23-1-cpu-py3",
12451248
},
1249+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": "s3://my-bucket/"},
12461250
"OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
12471251
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
12481252
"ResourceConfig": {

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ def test_training_step_base_estimator(sagemaker_session):
374374
"CollectionConfigurations": [],
375375
},
376376
"ProfilerConfig": {
377+
"DisableProfiler": False,
377378
"ProfilingIntervalInMilliseconds": 500,
378379
"S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},
379380
},
@@ -483,7 +484,7 @@ def test_training_step_tensorflow(sagemaker_session):
483484
"sagemaker_instance_type": {"Get": "Parameters.InstanceType"},
484485
"sagemaker_distributed_dataparallel_custom_mpi_options": '""',
485486
},
486-
"ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"},
487+
"ProfilerConfig": {"DisableProfiler": False, "S3OutputPath": "s3://my-bucket/"},
487488
},
488489
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
489490
}

tests/unit/sagemaker/workflow/test_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def test_repack_model_step(estimator):
157157
}
158158
],
159159
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
160+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": f"s3://{BUCKET}/"},
160161
"ResourceConfig": {
161162
"InstanceCount": 1,
162163
"InstanceType": "ml.m5.large",
@@ -238,6 +239,7 @@ def test_repack_model_step_with_source_dir(estimator, source_dir):
238239
}
239240
],
240241
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
242+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": f"s3://{BUCKET}/"},
241243
"ResourceConfig": {
242244
"InstanceCount": 1,
243245
"InstanceType": "ml.m5.large",

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def _create_train_job(version, py_version):
158158
}
159159
],
160160
"profiler_config": {
161+
"DisableProfiler": False,
161162
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
162163
},
163164
}

tests/unit/test_estimator.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session):
410410
],
411411
}
412412
assert args["profiler_config"] == {
413+
"DisableProfiler": False,
413414
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
414415
}
415416

@@ -574,6 +575,7 @@ def test_framework_without_debugger_and_profiler(time, sagemaker_session):
574575
}
575576
assert "debugger_rule_configs" not in args
576577
assert args["profiler_config"] == {
578+
"DisableProfiler": False,
577579
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
578580
}
579581
assert args["profiler_rule_configs"] == [
@@ -644,6 +646,7 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session):
644646
],
645647
}
646648
assert args["profiler_config"] == {
649+
"DisableProfiler": False,
647650
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
648651
}
649652
assert args["profiler_rule_configs"] == [
@@ -679,6 +682,7 @@ def test_framework_with_only_profiler_rule_specified(sagemaker_session):
679682
sagemaker_session.train.assert_called_once()
680683
_, args = sagemaker_session.train.call_args
681684
assert args["profiler_config"] == {
685+
"DisableProfiler": False,
682686
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
683687
}
684688
assert args["profiler_rule_configs"] == [
@@ -711,6 +715,7 @@ def test_framework_with_profiler_config_without_s3_output_path(time, sagemaker_s
711715
sagemaker_session.train.assert_called_once()
712716
_, args = sagemaker_session.train.call_args
713717
assert args["profiler_config"] == {
718+
"DisableProfiler": False,
714719
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
715720
"ProfilingIntervalInMilliseconds": 1000,
716721
}
@@ -745,7 +750,9 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region):
745750
f.fit("s3://mydata")
746751
sms.train.assert_called_once()
747752
_, args = sms.train.call_args
748-
assert args.get("profiler_config") is None
753+
# assert args.get("profiler_config") == {"DisableProfiler": True}
754+
# temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service
755+
assert args.get("profiler_config")["DisableProfiler"] == True
749756
assert args.get("profiler_rule_configs") is None
750757

751758

@@ -927,6 +934,7 @@ def test_framework_with_enabling_default_profiling(
927934
sagemaker_session.update_training_job.assert_called_once()
928935
_, args = sagemaker_session.update_training_job.call_args
929936
assert args["profiler_config"] == {
937+
"DisableProfiler": False,
930938
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
931939
}
932940
assert args["profiler_rule_configs"] == [
@@ -960,6 +968,7 @@ def test_framework_with_enabling_default_profiling_with_existed_s3_output_path(
960968
sagemaker_session.update_training_job.assert_called_once()
961969
_, args = sagemaker_session.update_training_job.call_args
962970
assert args["profiler_config"] == {
971+
"DisableProfiler": False,
963972
"S3OutputPath": "s3://custom/",
964973
}
965974
assert args["profiler_rule_configs"] == [
@@ -1001,7 +1010,9 @@ def test_framework_with_disabling_profiling(sagemaker_session, training_job_desc
10011010
f.disable_profiling()
10021011
sagemaker_session.update_training_job.assert_called_once()
10031012
_, args = sagemaker_session.update_training_job.call_args
1004-
assert args["profiler_config"] == {"DisableProfiler": True}
1013+
# assert args["profiler_config"] == {"DisableProfiler": True}
1014+
# temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service
1015+
assert args.get("profiler_config")["DisableProfiler"] == True
10051016

10061017

10071018
def test_framework_with_update_profiler_when_no_training_job(sagemaker_session):
@@ -1058,6 +1069,7 @@ def test_framework_with_update_profiler_config(sagemaker_session):
10581069
sagemaker_session.update_training_job.assert_called_once()
10591070
_, args = sagemaker_session.update_training_job.call_args
10601071
assert args["profiler_config"] == {
1072+
"DisableProfiler": False,
10611073
"ProfilingIntervalInMilliseconds": 1000,
10621074
}
10631075
assert "profiler_rule_configs" not in args
@@ -1086,7 +1098,7 @@ def test_framework_with_update_profiler_report_rule(sagemaker_session):
10861098
"RuleParameters": {"rule_to_invoke": "ProfilerReport"},
10871099
}
10881100
]
1089-
assert "profiler_config" not in args
1101+
assert args["profiler_config"]["DisableProfiler"] == False
10901102

10911103

10921104
def test_framework_with_disable_framework_metrics(sagemaker_session):
@@ -1101,7 +1113,7 @@ def test_framework_with_disable_framework_metrics(sagemaker_session):
11011113
f.update_profiler(disable_framework_metrics=True)
11021114
sagemaker_session.update_training_job.assert_called_once()
11031115
_, args = sagemaker_session.update_training_job.call_args
1104-
assert args["profiler_config"] == {"ProfilingParameters": {}}
1116+
assert args["profiler_config"] == {"DisableProfiler": False, "ProfilingParameters": {}}
11051117
assert "profiler_rule_configs" not in args
11061118

11071119

@@ -1118,6 +1130,7 @@ def test_framework_with_disable_framework_metrics_and_update_system_metrics(sage
11181130
sagemaker_session.update_training_job.assert_called_once()
11191131
_, args = sagemaker_session.update_training_job.call_args
11201132
assert args["profiler_config"] == {
1133+
"DisableProfiler": False,
11211134
"ProfilingIntervalInMilliseconds": 1000,
11221135
"ProfilingParameters": {},
11231136
}
@@ -1160,7 +1173,10 @@ def test_framework_with_update_profiler_config_and_profiler_rule(sagemaker_sessi
11601173
f.update_profiler(rules=[profiler_custom_rule], system_monitor_interval_millis=1000)
11611174
sagemaker_session.update_training_job.assert_called_once()
11621175
_, args = sagemaker_session.update_training_job.call_args
1163-
assert args["profiler_config"] == {"ProfilingIntervalInMilliseconds": 1000}
1176+
assert args["profiler_config"] == {
1177+
"DisableProfiler": False,
1178+
"ProfilingIntervalInMilliseconds": 1000,
1179+
}
11641180
assert args["profiler_rule_configs"] == [
11651181
{
11661182
"InstanceType": "c4.4xlarge",
@@ -2630,7 +2646,7 @@ def test_unsupported_type_in_dict():
26302646
"input_config": None,
26312647
"input_mode": "File",
26322648
"output_config": {"S3OutputPath": OUTPUT_PATH},
2633-
"profiler_config": {"S3OutputPath": OUTPUT_PATH},
2649+
"profiler_config": {"DisableProfiler": False, "S3OutputPath": OUTPUT_PATH},
26342650
"profiler_rule_configs": [
26352651
{
26362652
"RuleConfigurationName": "ProfilerReport-1510006209",

tests/unit/test_mxnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _get_train_args(job_name):
167167
}
168168
],
169169
"profiler_config": {
170+
"DisableProfiler": False,
170171
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
171172
},
172173
}

tests/unit/test_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def _create_train_job(version, py_version):
165165
}
166166
],
167167
"profiler_config": {
168+
"DisableProfiler": False,
168169
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
169170
},
170171
}

tests/unit/test_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
160160
}
161161
],
162162
"profiler_config": {
163+
"DisableProfiler": False,
163164
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
164165
},
165166
"retry_strategy": None,

tests/unit/test_sklearn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def _create_train_job(version):
147147
}
148148
],
149149
"profiler_config": {
150+
"DisableProfiler": False,
150151
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
151152
},
152153
}

tests/unit/test_xgboost.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"):
161161
}
162162
],
163163
"profiler_config": {
164+
"DisableProfiler": False,
164165
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
165166
},
166167
}

0 commit comments

Comments
 (0)