Skip to content

Change: Use the output path to store the Clarify config file #2138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/sagemaker/clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tempfile

from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
from sagemaker import image_uris, utils
from sagemaker import image_uris, s3, utils


class DataConfig:
Expand Down Expand Up @@ -405,9 +405,15 @@ def _run(
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
with open(analysis_config_file, "w") as f:
json.dump(analysis_config, f)
s3_analysis_config_file = _upload_analysis_config(
analysis_config_file,
data_config.s3_output_path,
self.sagemaker_session,
kms_key,
)
config_input = ProcessingInput(
input_name="analysis_config",
source=analysis_config_file,
source=s3_analysis_config_file,
destination=self._CLARIFY_CONFIG_INPUT,
s3_data_type="S3Prefix",
s3_input_mode="File",
Expand Down Expand Up @@ -638,6 +644,30 @@ def run_explainability(
self._run(data_config, analysis_config, wait, logs, job_name, kms_key)


def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
    """Upload the local analysis config file under the given S3 output path.

    Args:
        analysis_config_file (str): Path of the analysis config file on local disk.
        s3_output_path (str): S3 URI passed to the uploader as the desired
            destination for the analysis config file.
        sagemaker_session (:class:`~sagemaker.session.Session`): Session object
            handed to ``S3Uploader.upload`` to perform the upload.
        kms_key (str): KMS key forwarded to ``S3Uploader.upload``; presumably
            used to encrypt the uploaded analysis config file (may be None).

    Returns:
        The value returned by ``S3Uploader.upload``, i.e. the S3 uri of the
        uploaded file.
    """
    uploaded_uri = s3.S3Uploader.upload(
        local_path=analysis_config_file,
        desired_s3_uri=s3_output_path,
        kms_key=kms_key,
        sagemaker_session=sagemaker_session,
    )
    return uploaded_uri


def _set(value, key, dictionary):
"""Sets dictionary[key] = value if value is not None."""
if value is not None:
Expand Down
19 changes: 16 additions & 3 deletions tests/integ/test_clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,11 @@ def clarify_processor(sagemaker_session, cpu_instance_type):
return processor


@pytest.fixture(scope="module")
@pytest.fixture
def data_config(sagemaker_session, data_path, headers):
output_path = "s3://{}/{}".format(
sagemaker_session.default_bucket(), "linear_learner_analysis_result"
test_run = utils.unique_name_from_base("test_run")
output_path = "s3://{}/{}/{}".format(
sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run
)
return DataConfig(
s3_data_input_path=data_path,
Expand Down Expand Up @@ -195,6 +196,7 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sag
)
<= 1.0
)
check_analysis_config(data_config, sagemaker_session, "pre_training_bias")


def test_post_training_bias(
Expand Down Expand Up @@ -227,6 +229,7 @@ def test_post_training_bias(
)
<= 1.0
)
check_analysis_config(data_config, sagemaker_session, "post_training_bias")


def test_shap(clarify_processor, data_config, model_config, shap_config, sagemaker_session):
Expand All @@ -250,3 +253,13 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak
)
<= 1
)
check_analysis_config(data_config, sagemaker_session, "shap")


def check_analysis_config(data_config, sagemaker_session, method):
    """Assert that the uploaded analysis config lists ``method``.

    Reads ``analysis_config.json`` from directly under
    ``data_config.s3_output_path``, parses it as JSON and asserts that
    ``method`` is contained in its ``"methods"`` entry.

    Args:
        data_config: Object exposing the ``s3_output_path`` the job wrote to.
        sagemaker_session: Session passed to ``S3Downloader.read_file``.
        method (str): Analysis method name expected in the config.
    """
    config_uri = data_config.s3_output_path + "/analysis_config.json"
    raw_config = s3.S3Downloader.read_file(config_uri, sagemaker_session)
    configured_methods = json.loads(raw_config)["methods"]
    assert method in configured_methods