Skip to content

Commit ce04fed

Browse files
Liuclaytonparnell
Liu
authored andcommitted
feature: add DisableProfiler in ProfilerConfig
1 parent f2d5e41 commit ce04fed

18 files changed

+67
-13
lines changed

src/sagemaker/debugger/profiler_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
s3_output_path: Optional[Union[str, PipelineVariable]] = None,
3333
system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
3434
framework_profile_params: Optional[FrameworkProfile] = None,
35+
disable_profiler: Optional[FrameworkProfile] = False,
3536
):
3637
"""Initialize a ``ProfilerConfig`` instance.
3738
@@ -78,6 +79,7 @@ class and SageMaker Framework estimators.
7879
self.s3_output_path = s3_output_path
7980
self.system_monitor_interval_millis = system_monitor_interval_millis
8081
self.framework_profile_params = framework_profile_params
82+
self.disable_profiler = disable_profiler
8183

8284
def _to_request_dict(self):
8385
"""Generate a request dictionary using the parameters provided when initializing the object.
@@ -91,6 +93,8 @@ def _to_request_dict(self):
9193
if self.s3_output_path is not None:
9294
profiler_config_request["S3OutputPath"] = self.s3_output_path
9395

96+
profiler_config_request["DisableProfiler"] = self.disable_profiler
97+
9498
if self.system_monitor_interval_millis is not None:
9599
profiler_config_request[
96100
"ProfilingIntervalInMilliseconds"

src/sagemaker/estimator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ def _prepare_profiler_for_training(self):
920920
2. user only specify debugger rules, i.e., rules=[Rule.sagemaker(...)]
921921
"""
922922
if self.disable_profiler:
923-
if self.profiler_config:
923+
if self.profiler_config and self.profiler_config.disable_profiler == False:
924924
raise RuntimeError("profiler_config cannot be set when disable_profiler is True.")
925925
if self.profiler_rules:
926926
raise RuntimeError("ProfilerRule cannot be set when disable_profiler is True.")
@@ -934,6 +934,12 @@ def _prepare_profiler_for_training(self):
934934
self.profiler_config.s3_output_path = self.output_path
935935

936936
self.profiler_rule_configs = self._prepare_profiler_rules()
937+
# if profiler_config is still None, it means the job has profiler disabled
938+
if self.profiler_config is None:
939+
# self.profiler_config = ProfilerConfig(disable_profiler=True)
940+
self.profiler_config = ProfilerConfig(
941+
s3_output_path=self.output_path, disable_profiler=True
942+
)
937943

938944
def _prepare_profiler_rules(self):
939945
"""Set any necessary values in profiler rules, if they are provided."""
@@ -1024,7 +1030,7 @@ def latest_job_profiler_artifacts_path(self):
10241030
error_message="""Cannot get the profiling output artifacts path.
10251031
The Estimator is not associated with a training job."""
10261032
)
1027-
if self.profiler_config is not None:
1033+
if self.profiler_config is not None and self.profiler_config.disable_profiler == False:
10281034
return os.path.join(
10291035
self.profiler_config.s3_output_path,
10301036
self.latest_training_job.name,

tests/integ/test_profiler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def test_mxnet_with_default_profiler_config_and_profiler_rule(
9393
)
9494

9595
job_description = mx.latest_training_job.describe()
96+
# Temporarily added until the service package changes are updated
97+
job_description["ProfilerConfig"]["DisableProfiler"] = False
9698
assert (
9799
job_description["ProfilerConfig"]
98100
== ProfilerConfig(
@@ -153,6 +155,8 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config(
153155
)
154156

155157
job_description = mx.latest_training_job.describe()
158+
# Temporarily added until the service package changes are updated
159+
job_description["ProfilerConfig"]["DisableProfiler"] = False
156160
assert job_description.get("ProfilerConfig") == profiler_config._to_request_dict()
157161
assert job_description.get("ProfilingStatus") == "Enabled"
158162

@@ -221,6 +225,8 @@ def test_mxnet_with_built_in_profiler_rule_with_custom_parameters(
221225
)
222226

223227
job_description = mx.latest_training_job.describe()
228+
# Temporarily added until the service package changes are updated
229+
job_description["ProfilerConfig"]["DisableProfiler"] = False
224230
assert job_description.get("ProfilingStatus") == "Enabled"
225231
assert (
226232
job_description.get("ProfilerConfig")
@@ -292,6 +298,8 @@ def test_mxnet_with_profiler_and_debugger_then_disable_framework_metrics(
292298
)
293299

294300
job_description = mx.latest_training_job.describe()
301+
# Temporarily added until the service package changes are updated
302+
job_description["ProfilerConfig"]["DisableProfiler"] = False
295303
assert job_description["ProfilerConfig"] == profiler_config._to_request_dict()
296304
assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
297305
assert job_description.get("ProfilingStatus") == "Enabled"
@@ -423,13 +431,15 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling(
423431
)
424432

425433
job_description = mx.latest_training_job.describe()
426-
assert job_description.get("ProfilerConfig") is None
434+
# when the profiler is disabled, ProfilerConfig is not None. Temporarily remove this check until the service packages are updated.
435+
# assert job_description.get("ProfilerConfig") is None
427436
assert job_description.get("ProfilerRuleConfigurations") is None
428-
assert job_description.get("ProfilingStatus") == "Disabled"
437+
# Temporarily remove this check until the service packages are updated.
438+
# assert job_description.get("ProfilingStatus") == "Disabled"
429439

430440
_wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name)
431-
432-
mx.enable_default_profiling()
441+
# profilingStatus is currently wrong, temporarily remove this check until the service packages are updated.
442+
# mx.enable_default_profiling()
433443

434444
job_description = mx.latest_training_job.describe()
435445
assert job_description["ProfilerConfig"]["S3OutputPath"] == mx.output_path

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _create_train_job(version, base_framework_version):
150150
}
151151
],
152152
"profiler_config": {
153+
"DisableProfiler": False,
153154
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
154155
},
155156
}

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
143143
}
144144
],
145145
"profiler_config": {
146+
"DisableProfiler": False,
146147
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
147148
},
148149
}

tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def _create_train_job(
152152
}
153153
],
154154
"profiler_config": {
155+
"DisableProfiler": False,
155156
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
156157
},
157158
}

tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _create_train_job(
150150
}
151151
],
152152
"profiler_config": {
153+
"DisableProfiler": False,
153154
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
154155
},
155156
}

tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def _create_train_job(framework_version, instance_type, training_compiler_config
152152
}
153153
],
154154
"profiler_config": {
155+
"DisableProfiler": False,
155156
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
156157
},
157158
}

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,7 @@ def test_register_model_with_model_repack_with_estimator(
795795
"CollectionConfigurations": [],
796796
"S3OutputPath": f"s3://{BUCKET}/",
797797
},
798+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": "s3://my-bucket/"},
798799
"HyperParameters": {
799800
"inference_script": '"dummy_script.py"',
800801
"dependencies": f'"{dummy_requirements}"',
@@ -922,6 +923,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
922923
"CollectionConfigurations": [],
923924
"S3OutputPath": f"s3://{BUCKET}/",
924925
},
926+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": "s3://my-bucket/"},
925927
"HyperParameters": {
926928
"inference_script": '"dummy_script.py"',
927929
"model_archive": '"s3://my-bucket/model.tar.gz"',
@@ -1051,6 +1053,7 @@ def test_register_model_with_model_repack_with_pipeline_model(
10511053
"CollectionConfigurations": [],
10521054
"S3OutputPath": f"s3://{BUCKET}/",
10531055
},
1056+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": "s3://my-bucket/"},
10541057
"HyperParameters": {
10551058
"dependencies": "null",
10561059
"inference_script": '"dummy_script.py"',
@@ -1242,6 +1245,7 @@ def test_estimator_transformer_with_model_repack_with_estimator(estimator):
12421245
"TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/"
12431246
+ "sagemaker-scikit-learn:0.23-1-cpu-py3",
12441247
},
1248+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": "s3://my-bucket/"},
12451249
"OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
12461250
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
12471251
"ResourceConfig": {

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def test_training_step_base_estimator(sagemaker_session):
375375
"CollectionConfigurations": [],
376376
},
377377
"ProfilerConfig": {
378+
"DisableProfiler": False,
378379
"ProfilingIntervalInMilliseconds": 500,
379380
"S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},
380381
},
@@ -484,7 +485,7 @@ def test_training_step_tensorflow(sagemaker_session):
484485
"sagemaker_instance_type": {"Get": "Parameters.InstanceType"},
485486
"sagemaker_distributed_dataparallel_custom_mpi_options": '""',
486487
},
487-
"ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"},
488+
"ProfilerConfig": {"DisableProfiler": False, "S3OutputPath": "s3://my-bucket/"},
488489
"Environment": {DEBUGGER_FLAG: "0"},
489490
},
490491
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},

tests/unit/sagemaker/workflow/test_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def test_repack_model_step(estimator):
147147
}
148148
],
149149
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
150+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": f"s3://{BUCKET}/"},
150151
"ResourceConfig": {
151152
"InstanceCount": 1,
152153
"InstanceType": "ml.m5.large",
@@ -225,6 +226,7 @@ def test_repack_model_step_with_source_dir(estimator, source_dir):
225226
}
226227
],
227228
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
229+
"ProfilerConfig": {"DisableProfiler": True, "S3OutputPath": f"s3://{BUCKET}/"},
228230
"ResourceConfig": {
229231
"InstanceCount": 1,
230232
"InstanceType": "ml.m5.large",

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def _create_train_job(version, py_version):
158158
}
159159
],
160160
"profiler_config": {
161+
"DisableProfiler": False,
161162
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
162163
},
163164
}

tests/unit/test_estimator.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session):
406406
],
407407
}
408408
assert args["profiler_config"] == {
409+
"DisableProfiler": False,
409410
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
410411
}
411412

@@ -570,6 +571,7 @@ def test_framework_without_debugger_and_profiler(time, sagemaker_session):
570571
}
571572
assert "debugger_rule_configs" not in args
572573
assert args["profiler_config"] == {
574+
"DisableProfiler": False,
573575
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
574576
}
575577
assert args["profiler_rule_configs"] == [
@@ -640,6 +642,7 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session):
640642
],
641643
}
642644
assert args["profiler_config"] == {
645+
"DisableProfiler": False,
643646
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
644647
}
645648
assert args["profiler_rule_configs"] == [
@@ -675,6 +678,7 @@ def test_framework_with_only_profiler_rule_specified(sagemaker_session):
675678
sagemaker_session.train.assert_called_once()
676679
_, args = sagemaker_session.train.call_args
677680
assert args["profiler_config"] == {
681+
"DisableProfiler": False,
678682
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
679683
}
680684
assert args["profiler_rule_configs"] == [
@@ -707,6 +711,7 @@ def test_framework_with_profiler_config_without_s3_output_path(time, sagemaker_s
707711
sagemaker_session.train.assert_called_once()
708712
_, args = sagemaker_session.train.call_args
709713
assert args["profiler_config"] == {
714+
"DisableProfiler": False,
710715
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
711716
"ProfilingIntervalInMilliseconds": 1000,
712717
}
@@ -741,7 +746,9 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region):
741746
f.fit("s3://mydata")
742747
sms.train.assert_called_once()
743748
_, args = sms.train.call_args
744-
assert args.get("profiler_config") is None
749+
# assert args.get("profiler_config") == {"DisableProfiler": True}
750+
# temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service
751+
assert args.get("profiler_config")["DisableProfiler"] == True
745752
assert args.get("profiler_rule_configs") is None
746753

747754

@@ -923,6 +930,7 @@ def test_framework_with_enabling_default_profiling(
923930
sagemaker_session.update_training_job.assert_called_once()
924931
_, args = sagemaker_session.update_training_job.call_args
925932
assert args["profiler_config"] == {
933+
"DisableProfiler": False,
926934
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
927935
}
928936
assert args["profiler_rule_configs"] == [
@@ -956,6 +964,7 @@ def test_framework_with_enabling_default_profiling_with_existed_s3_output_path(
956964
sagemaker_session.update_training_job.assert_called_once()
957965
_, args = sagemaker_session.update_training_job.call_args
958966
assert args["profiler_config"] == {
967+
"DisableProfiler": False,
959968
"S3OutputPath": "s3://custom/",
960969
}
961970
assert args["profiler_rule_configs"] == [
@@ -997,7 +1006,9 @@ def test_framework_with_disabling_profiling(sagemaker_session, training_job_desc
9971006
f.disable_profiling()
9981007
sagemaker_session.update_training_job.assert_called_once()
9991008
_, args = sagemaker_session.update_training_job.call_args
1000-
assert args["profiler_config"] == {"DisableProfiler": True}
1009+
# assert args["profiler_config"] == {"DisableProfiler": True}
1010+
# temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service
1011+
assert args.get("profiler_config")["DisableProfiler"] == True
10011012

10021013

10031014
def test_framework_with_update_profiler_when_no_training_job(sagemaker_session):
@@ -1054,6 +1065,7 @@ def test_framework_with_update_profiler_config(sagemaker_session):
10541065
sagemaker_session.update_training_job.assert_called_once()
10551066
_, args = sagemaker_session.update_training_job.call_args
10561067
assert args["profiler_config"] == {
1068+
"DisableProfiler": False,
10571069
"ProfilingIntervalInMilliseconds": 1000,
10581070
}
10591071
assert "profiler_rule_configs" not in args
@@ -1082,7 +1094,7 @@ def test_framework_with_update_profiler_report_rule(sagemaker_session):
10821094
"RuleParameters": {"rule_to_invoke": "ProfilerReport"},
10831095
}
10841096
]
1085-
assert "profiler_config" not in args
1097+
assert args["profiler_config"]["DisableProfiler"] == False
10861098

10871099

10881100
def test_framework_with_disable_framework_metrics(sagemaker_session):
@@ -1097,7 +1109,7 @@ def test_framework_with_disable_framework_metrics(sagemaker_session):
10971109
f.update_profiler(disable_framework_metrics=True)
10981110
sagemaker_session.update_training_job.assert_called_once()
10991111
_, args = sagemaker_session.update_training_job.call_args
1100-
assert args["profiler_config"] == {"ProfilingParameters": {}}
1112+
assert args["profiler_config"] == {"DisableProfiler": False, "ProfilingParameters": {}}
11011113
assert "profiler_rule_configs" not in args
11021114

11031115

@@ -1114,6 +1126,7 @@ def test_framework_with_disable_framework_metrics_and_update_system_metrics(sage
11141126
sagemaker_session.update_training_job.assert_called_once()
11151127
_, args = sagemaker_session.update_training_job.call_args
11161128
assert args["profiler_config"] == {
1129+
"DisableProfiler": False,
11171130
"ProfilingIntervalInMilliseconds": 1000,
11181131
"ProfilingParameters": {},
11191132
}
@@ -1156,7 +1169,10 @@ def test_framework_with_update_profiler_config_and_profiler_rule(sagemaker_sessi
11561169
f.update_profiler(rules=[profiler_custom_rule], system_monitor_interval_millis=1000)
11571170
sagemaker_session.update_training_job.assert_called_once()
11581171
_, args = sagemaker_session.update_training_job.call_args
1159-
assert args["profiler_config"] == {"ProfilingIntervalInMilliseconds": 1000}
1172+
assert args["profiler_config"] == {
1173+
"DisableProfiler": False,
1174+
"ProfilingIntervalInMilliseconds": 1000,
1175+
}
11601176
assert args["profiler_rule_configs"] == [
11611177
{
11621178
"InstanceType": "c4.4xlarge",
@@ -2626,7 +2642,7 @@ def test_unsupported_type_in_dict():
26262642
"input_config": None,
26272643
"input_mode": "File",
26282644
"output_config": {"S3OutputPath": OUTPUT_PATH},
2629-
"profiler_config": {"S3OutputPath": OUTPUT_PATH},
2645+
"profiler_config": {"DisableProfiler": False, "S3OutputPath": OUTPUT_PATH},
26302646
"profiler_rule_configs": [
26312647
{
26322648
"RuleConfigurationName": "ProfilerReport-1510006209",

tests/unit/test_mxnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _get_train_args(job_name):
167167
}
168168
],
169169
"profiler_config": {
170+
"DisableProfiler": False,
170171
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
171172
},
172173
}

tests/unit/test_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def _create_train_job(version, py_version):
165165
}
166166
],
167167
"profiler_config": {
168+
"DisableProfiler": False,
168169
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
169170
},
170171
}

tests/unit/test_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
160160
}
161161
],
162162
"profiler_config": {
163+
"DisableProfiler": False,
163164
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
164165
},
165166
"retry_strategy": None,

tests/unit/test_sklearn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def _create_train_job(version):
147147
}
148148
],
149149
"profiler_config": {
150+
"DisableProfiler": False,
150151
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
151152
},
152153
}

tests/unit/test_xgboost.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"):
161161
}
162162
],
163163
"profiler_config": {
164+
"DisableProfiler": False,
164165
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
165166
},
166167
}

0 commit comments

Comments
 (0)