Skip to content

Commit ebd1952

Browse files
Merge branch 'master' into fix-integ-tests
2 parents db92237 + 04824f7 commit ebd1952

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

src/sagemaker/clarify.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
self,
3232
s3_data_input_path,
3333
s3_output_path,
34+
s3_analysis_config_output_path=None,
3435
label=None,
3536
headers=None,
3637
features=None,
@@ -43,6 +44,9 @@ def __init__(
4344
Args:
4445
s3_data_input_path (str): Dataset S3 prefix/object URI.
4546
s3_output_path (str): S3 prefix to store the output.
47+
s3_analysis_config_output_path (str): S3 prefix to store the analysis_config output.
48+
If this field is None, then the s3_output_path will be used
49+
to store the analysis_config output.
4650
label (str): Target attribute of the model required by bias metrics (optional for SHAP)
4751
Specified as column name or index for CSV dataset, or as JSONPath for JSONLines.
4852
headers (list[str]): A list of column names in the input dataset.
@@ -61,6 +65,7 @@ def __init__(
6165
)
6266
self.s3_data_input_path = s3_data_input_path
6367
self.s3_output_path = s3_output_path
68+
self.s3_analysis_config_output_path = s3_analysis_config_output_path
6469
self.s3_data_distribution_type = s3_data_distribution_type
6570
self.s3_compression_type = s3_compression_type
6671
self.label = label
@@ -473,7 +478,7 @@ def _run(
473478
json.dump(analysis_config, f)
474479
s3_analysis_config_file = _upload_analysis_config(
475480
analysis_config_file,
476-
data_config.s3_output_path,
481+
data_config.s3_analysis_config_output_path or data_config.s3_output_path,
477482
self.sagemaker_session,
478483
kms_key,
479484
)

tests/unit/test_clarify.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
ModelPredictedLabelConfig,
2525
SHAPConfig,
2626
)
27-
from sagemaker import image_uris
27+
from sagemaker import image_uris, Processor
2828

2929
JOB_NAME_PREFIX = "my-prefix"
3030
TIMESTAMP = "2021-06-17-22-29-54-685"
@@ -499,6 +499,59 @@ def test_post_training_bias(
499499
)
500500

501501

502+
@patch.object(Processor, "run")
503+
def test_run_on_s3_analysis_config_file(
504+
processor_run, sagemaker_session, clarify_processor, data_config
505+
):
506+
analysis_config = {
507+
"methods": {"post_training_bias": {"methods": "all"}},
508+
}
509+
with patch("sagemaker.clarify._upload_analysis_config", return_value=None) as mock_method:
510+
clarify_processor._run(
511+
data_config,
512+
analysis_config,
513+
True,
514+
True,
515+
"test",
516+
None,
517+
{"ExperimentName": "AnExperiment"},
518+
)
519+
analysis_config_file = mock_method.call_args[0][0]
520+
mock_method.assert_called_with(
521+
analysis_config_file, data_config.s3_output_path, sagemaker_session, None
522+
)
523+
524+
data_config_with_analysis_config_output = DataConfig(
525+
s3_data_input_path="s3://input/train.csv",
526+
s3_output_path="s3://output/analysis_test_result",
527+
s3_analysis_config_output_path="s3://analysis_config_output",
528+
label="Label",
529+
headers=[
530+
"Label",
531+
"F1",
532+
"F2",
533+
"F3",
534+
],
535+
dataset_type="text/csv",
536+
)
537+
clarify_processor._run(
538+
data_config_with_analysis_config_output,
539+
analysis_config,
540+
True,
541+
True,
542+
"test",
543+
None,
544+
{"ExperimentName": "AnExperiment"},
545+
)
546+
analysis_config_file = mock_method.call_args[0][0]
547+
mock_method.assert_called_with(
548+
analysis_config_file,
549+
data_config_with_analysis_config_output.s3_analysis_config_output_path,
550+
sagemaker_session,
551+
None,
552+
)
553+
554+
502555
def _run_test_shap(
503556
name_from_base,
504557
clarify_processor,

0 commit comments

Comments
 (0)