diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 3268d87a17..2365a21789 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -21,7 +21,7 @@ import tempfile from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor -from sagemaker import image_uris, utils +from sagemaker import image_uris, s3, utils class DataConfig: @@ -405,9 +405,15 @@ def _run( analysis_config_file = os.path.join(tmpdirname, "analysis_config.json") with open(analysis_config_file, "w") as f: json.dump(analysis_config, f) + s3_analysis_config_file = _upload_analysis_config( + analysis_config_file, + data_config.s3_output_path, + self.sagemaker_session, + kms_key, + ) config_input = ProcessingInput( input_name="analysis_config", - source=analysis_config_file, + source=s3_analysis_config_file, destination=self._CLARIFY_CONFIG_INPUT, s3_data_type="S3Prefix", s3_input_mode="File", @@ -638,6 +644,30 @@ def run_explainability( self._run(data_config, analysis_config, wait, logs, job_name, kms_key) +def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): + """Uploads the local analysis_config_file to the s3_output_path. + + Args: + analysis_config_file (str): File path to the local analysis config file. + s3_output_path (str): S3 prefix to store the analysis config file. + sagemaker_session (:class:`~sagemaker.session.Session`): + Session object which manages interactions with Amazon SageMaker and + any other AWS services needed. If not specified, the processor creates + one using the default AWS configuration chain. + kms_key (str): The ARN of the KMS key that is used to encrypt the + user code file (default: None). + + Returns: + The S3 uri of the uploaded file. + """ + return s3.S3Uploader.upload( + local_path=analysis_config_file, + desired_s3_uri=s3_output_path, + sagemaker_session=sagemaker_session, + kms_key=kms_key, + ) + + def _set(value, key, dictionary): """Sets dictionary[key] = value if value is not None.""" if value is not None: diff --git a/tests/integ/test_clarify.py b/tests/integ/test_clarify.py index ec72da622d..1b47b993c7 100644 --- a/tests/integ/test_clarify.py +++ b/tests/integ/test_clarify.py @@ -112,10 +112,11 @@ def clarify_processor(sagemaker_session, cpu_instance_type): return processor -@pytest.fixture(scope="module") +@pytest.fixture def data_config(sagemaker_session, data_path, headers): - output_path = "s3://{}/{}".format( - sagemaker_session.default_bucket(), "linear_learner_analysis_result" + test_run = utils.unique_name_from_base("test_run") + output_path = "s3://{}/{}/{}".format( + sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run ) return DataConfig( s3_data_input_path=data_path, @@ -195,6 +196,7 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sag ) <= 1.0 ) + check_analysis_config(data_config, sagemaker_session, "pre_training_bias") def test_post_training_bias( @@ -227,6 +229,7 @@ def test_post_training_bias( ) <= 1.0 ) + check_analysis_config(data_config, sagemaker_session, "post_training_bias") def test_shap(clarify_processor, data_config, model_config, shap_config, sagemaker_session): @@ -250,3 +253,13 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak ) <= 1 ) + check_analysis_config(data_config, sagemaker_session, "shap") + + +def check_analysis_config(data_config, sagemaker_session, method): + analysis_config_json = s3.S3Downloader.read_file( + data_config.s3_output_path + "/analysis_config.json", + sagemaker_session, + ) + analysis_config = json.loads(analysis_config_json) + assert method in analysis_config["methods"]