Skip to content

Commit e0b59ef

Browse files
Merge branch 'master' into stacicho-pipelines-local-add-notebook
2 parents c7c7900 + 0e9c10e commit e0b59ef

File tree

9 files changed

+743
-125
lines changed

9 files changed

+743
-125
lines changed

src/sagemaker/clarify.py

Lines changed: 347 additions & 103 deletions
Large diffs are not rendered by default.

src/sagemaker/estimator.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,29 @@ def _prepare_debugger_for_training(self):
823823
self.debugger_hook_config.s3_output_path = self.output_path
824824
self.debugger_rule_configs = self._prepare_debugger_rules()
825825
self._prepare_collection_configs()
826+
self._validate_and_set_debugger_configs()
827+
if not self.debugger_hook_config:
828+
if self.environment is None:
829+
self.environment = {}
830+
self.environment[DEBUGGER_FLAG] = "0"
831+
832+
def _validate_and_set_debugger_configs(self):
833+
"""Set defaults for debugging."""
834+
region_supports_debugger = _region_supports_debugger(
835+
self.sagemaker_session.boto_region_name
836+
)
837+
838+
if region_supports_debugger:
839+
if self.debugger_hook_config in [None, {}]:
840+
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
841+
else:
842+
if self.debugger_hook_config is not False and self.debugger_hook_config:
843+
# when user set debugger config in a unsupported region
844+
raise ValueError(
845+
"Current region does not support debugger but debugger hook config is set!"
846+
)
847+
# disable debugger in unsupported regions
848+
self.debugger_hook_config = False
826849

827850
def _prepare_debugger_rules(self):
828851
"""Set any necessary values in debugger rules, if they are provided."""
@@ -1766,6 +1789,8 @@ def enable_default_profiling(self):
17661789
Debugger monitoring is disabled.
17671790
"""
17681791
self._ensure_latest_training_job()
1792+
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
1793+
raise ValueError("Current region does not support profiler / debugger!")
17691794

17701795
training_job_details = self.latest_training_job.describe()
17711796

@@ -1799,6 +1824,8 @@ def disable_profiling(self):
17991824
18001825
"""
18011826
self._ensure_latest_training_job()
1827+
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
1828+
raise ValueError("Current region does not support profiler / debugger!")
18021829

18031830
training_job_details = self.latest_training_job.describe()
18041831

@@ -1852,6 +1879,8 @@ def update_profiler(
18521879
18531880
"""
18541881
self._ensure_latest_training_job()
1882+
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
1883+
raise ValueError("Current region does not support profiler / debugger!")
18551884

18561885
if (
18571886
not rules
@@ -2872,13 +2901,7 @@ def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None:
28722901

28732902
def _validate_and_set_debugger_configs(self):
28742903
"""Set defaults for debugging."""
2875-
if self.debugger_hook_config is None and _region_supports_debugger(
2876-
self.sagemaker_session.boto_region_name
2877-
):
2878-
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
2879-
elif not self.debugger_hook_config:
2880-
# set hook config to False if _region_supports_debugger is False
2881-
self.debugger_hook_config = False
2904+
super(Framework, self)._validate_and_set_debugger_configs()
28822905

28832906
# Disable debugger if checkpointing is enabled by the customer
28842907
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
@@ -2901,11 +2924,6 @@ def _validate_and_set_debugger_configs(self):
29012924
)
29022925
self.debugger_hook_config = False
29032926

2904-
if self.debugger_hook_config is False:
2905-
if self.environment is None:
2906-
self.environment = {}
2907-
self.environment[DEBUGGER_FLAG] = "0"
2908-
29092927
def _model_source_dir(self):
29102928
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.
29112929

src/sagemaker/fw_utils.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,26 @@
5353
"only one worker per host regardless of the number of GPUs."
5454
)
5555

56-
DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
57-
PROFILER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
56+
DEBUGGER_UNSUPPORTED_REGIONS = (
57+
"us-iso-east-1",
58+
"ap-southeast-3",
59+
"ap-southeast-4",
60+
"eu-south-2",
61+
"me-central-1",
62+
"ap-south-2",
63+
"eu-central-2",
64+
"us-gov-east-1",
65+
)
66+
PROFILER_UNSUPPORTED_REGIONS = (
67+
"us-iso-east-1",
68+
"ap-southeast-3",
69+
"ap-southeast-4",
70+
"eu-south-2",
71+
"me-central-1",
72+
"ap-south-2",
73+
"eu-central-2",
74+
"us-gov-east-1",
75+
)
5876

5977
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
6078
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = (

tests/integ/test_clarify.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ def data_path_no_label_index(training_set_no_label):
138138
def data_path_label_index(training_set_label_index):
139139
features, label, index = training_set_label_index
140140
data = pd.concat(
141-
[pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(index)], axis=1, sort=False
141+
[pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(index)],
142+
axis=1,
143+
sort=False,
142144
)
143145
with tempfile.TemporaryDirectory() as tmpdirname:
144146
filename = os.path.join(tmpdirname, "train_label_index.csv")
@@ -151,7 +153,12 @@ def data_path_label_index(training_set_label_index):
151153
def data_path_label_index_6col(training_set_label_index):
152154
features, label, index = training_set_label_index
153155
data = pd.concat(
154-
[pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(features), pd.DataFrame(index)],
156+
[
157+
pd.DataFrame(label),
158+
pd.DataFrame(features),
159+
pd.DataFrame(features),
160+
pd.DataFrame(index),
161+
],
155162
axis=1,
156163
sort=False,
157164
)
@@ -551,7 +558,10 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sag
551558

552559

553560
def test_pre_training_bias_facets_not_included(
554-
clarify_processor, data_config_facets_not_included, data_bias_config, sagemaker_session
561+
clarify_processor,
562+
data_config_facets_not_included,
563+
data_bias_config,
564+
sagemaker_session,
555565
):
556566
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
557567
clarify_processor.run_pre_training_bias(
@@ -643,7 +653,9 @@ def test_post_training_bias_facets_not_included_excluded_columns(
643653
<= 1.0
644654
)
645655
check_analysis_config(
646-
data_config_facets_not_included_multiple_files, sagemaker_session, "post_training_bias"
656+
data_config_facets_not_included_multiple_files,
657+
sagemaker_session,
658+
"post_training_bias",
647659
)
648660

649661

@@ -704,6 +716,50 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak
704716
check_analysis_config(data_config, sagemaker_session, "shap")
705717

706718

719+
def test_bias_and_explainability(
720+
clarify_processor,
721+
data_config,
722+
model_config,
723+
shap_config,
724+
data_bias_config,
725+
sagemaker_session,
726+
):
727+
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
728+
clarify_processor.run_bias_and_explainability(
729+
data_config,
730+
model_config,
731+
shap_config,
732+
data_bias_config,
733+
pre_training_methods="all",
734+
post_training_methods="all",
735+
model_predicted_label_config="score",
736+
job_name=utils.unique_name_from_base("clarify-bias-and-explainability"),
737+
wait=True,
738+
)
739+
analysis_result_json = s3.S3Downloader.read_file(
740+
data_config.s3_output_path + "/analysis.json",
741+
sagemaker_session,
742+
)
743+
analysis_result = json.loads(analysis_result_json)
744+
assert (
745+
math.fabs(
746+
analysis_result["explanations"]["kernel_shap"]["label0"]["global_shap_values"]["F2"]
747+
)
748+
<= 1
749+
)
750+
check_analysis_config(data_config, sagemaker_session, "shap")
751+
752+
assert (
753+
math.fabs(
754+
analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][
755+
"value"
756+
]
757+
)
758+
<= 1.0
759+
)
760+
check_analysis_config(data_config, sagemaker_session, "post_training_bias")
761+
762+
707763
def check_analysis_config(data_config, sagemaker_session, method):
708764
analysis_config_json = s3.S3Downloader.read_file(
709765
data_config.s3_output_path + "/analysis_config.json",

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,7 @@ def test_fit_ps(time, strftime, sagemaker_session):
483483
expected_train_args = _create_train_job("1.11", ps=True, py_version="py2")
484484
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
485485
expected_train_args["hyperparameters"][TensorFlow.LAUNCH_PS_ENV_NAME] = json.dumps(True)
486+
expected_train_args["environment"] = {"USE_SMDEBUG": "0"}
486487

487488
actual_train_args = sagemaker_session.method_calls[0][2]
488489
assert actual_train_args == expected_train_args

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,10 @@ def test_training_step_base_estimator(sagemaker_session):
370370
},
371371
"RoleArn": ROLE,
372372
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
373+
"DebugHookConfig": {
374+
"S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},
375+
"CollectionConfigurations": [],
376+
},
373377
"ProfilerConfig": {
374378
"ProfilingIntervalInMilliseconds": 500,
375379
"S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},

0 commit comments

Comments
 (0)