Skip to content

Commit 5348eda

Browse files
authored
change: Use the output path to store the Clarify config file (#2138)
1 parent 0da3339 commit 5348eda

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

src/sagemaker/clarify.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tempfile
2222

2323
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
24-
from sagemaker import image_uris, utils
24+
from sagemaker import image_uris, s3, utils
2525

2626

2727
class DataConfig:
@@ -405,9 +405,15 @@ def _run(
405405
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
406406
with open(analysis_config_file, "w") as f:
407407
json.dump(analysis_config, f)
408+
s3_analysis_config_file = _upload_analysis_config(
409+
analysis_config_file,
410+
data_config.s3_output_path,
411+
self.sagemaker_session,
412+
kms_key,
413+
)
408414
config_input = ProcessingInput(
409415
input_name="analysis_config",
410-
source=analysis_config_file,
416+
source=s3_analysis_config_file,
411417
destination=self._CLARIFY_CONFIG_INPUT,
412418
s3_data_type="S3Prefix",
413419
s3_input_mode="File",
@@ -638,6 +644,30 @@ def run_explainability(
638644
self._run(data_config, analysis_config, wait, logs, job_name, kms_key)
639645

640646

647+
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
648+
"""Uploads the local analysis_config_file to the s3_output_path.
649+
650+
Args:
651+
analysis_config_file (str): File path to the local analysis config file.
652+
s3_output_path (str): S3 prefix to store the analysis config file.
653+
sagemaker_session (:class:`~sagemaker.session.Session`):
654+
Session object which manages interactions with Amazon SageMaker and
655+
any other AWS services needed. Passed through unchanged to
656+
``s3.S3Uploader.upload``.
657+
kms_key (str): The ARN of the KMS key that is used to encrypt the
658+
uploaded analysis config file. This function defines no default for it.
659+
660+
Returns:
661+
The S3 uri of the uploaded file, as returned by ``s3.S3Uploader.upload``.
662+
"""
663+
return s3.S3Uploader.upload(
664+
local_path=analysis_config_file,
665+
desired_s3_uri=s3_output_path,
666+
sagemaker_session=sagemaker_session,
667+
kms_key=kms_key,
668+
)
669+
670+
641671
def _set(value, key, dictionary):
642672
"""Sets dictionary[key] = value if value is not None."""
643673
if value is not None:

tests/integ/test_clarify.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,11 @@ def clarify_processor(sagemaker_session, cpu_instance_type):
112112
return processor
113113

114114

115-
@pytest.fixture(scope="module")
115+
@pytest.fixture
116116
def data_config(sagemaker_session, data_path, headers):
117-
output_path = "s3://{}/{}".format(
118-
sagemaker_session.default_bucket(), "linear_learner_analysis_result"
117+
test_run = utils.unique_name_from_base("test_run")
118+
output_path = "s3://{}/{}/{}".format(
119+
sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run
119120
)
120121
return DataConfig(
121122
s3_data_input_path=data_path,
@@ -195,6 +196,7 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sag
195196
)
196197
<= 1.0
197198
)
199+
check_analysis_config(data_config, sagemaker_session, "pre_training_bias")
198200

199201

200202
def test_post_training_bias(
@@ -227,6 +229,7 @@ def test_post_training_bias(
227229
)
228230
<= 1.0
229231
)
232+
check_analysis_config(data_config, sagemaker_session, "post_training_bias")
230233

231234

232235
def test_shap(clarify_processor, data_config, model_config, shap_config, sagemaker_session):
@@ -250,3 +253,13 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak
250253
)
251254
<= 1
252255
)
256+
check_analysis_config(data_config, sagemaker_session, "shap")
257+
258+
259+
def check_analysis_config(data_config, sagemaker_session, method):
# Test helper: reads "analysis_config.json" from under
# data_config.s3_output_path via s3.S3Downloader.read_file, parses it as
# JSON, and asserts that `method` is present in its "methods" entry.
260+
analysis_config_json = s3.S3Downloader.read_file(
261+
data_config.s3_output_path + "/analysis_config.json",
262+
sagemaker_session,
263+
)
264+
analysis_config = json.loads(analysis_config_json)
265+
assert method in analysis_config["methods"]

0 commit comments

Comments
 (0)