|
21 | 21 | import tempfile
|
22 | 22 |
|
23 | 23 | from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
|
24 |
| -from sagemaker import image_uris, utils |
| 24 | +from sagemaker import image_uris, s3, utils |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class DataConfig:
|
@@ -405,9 +405,15 @@ def _run(
|
405 | 405 | analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
|
406 | 406 | with open(analysis_config_file, "w") as f:
|
407 | 407 | json.dump(analysis_config, f)
|
| 408 | + s3_analysis_config_file = _upload_analysis_config( |
| 409 | + analysis_config_file, |
| 410 | + data_config.s3_output_path, |
| 411 | + self.sagemaker_session, |
| 412 | + kms_key, |
| 413 | + ) |
408 | 414 | config_input = ProcessingInput(
|
409 | 415 | input_name="analysis_config",
|
410 |
| - source=analysis_config_file, |
| 416 | + source=s3_analysis_config_file, |
411 | 417 | destination=self._CLARIFY_CONFIG_INPUT,
|
412 | 418 | s3_data_type="S3Prefix",
|
413 | 419 | s3_input_mode="File",
|
@@ -638,6 +644,30 @@ def run_explainability(
|
638 | 644 | self._run(data_config, analysis_config, wait, logs, job_name, kms_key)
|
639 | 645 |
|
640 | 646 |
|
| 647 | +def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): |
| 648 | + """Uploads the local analysis_config_file to the s3_output_path. |
| 649 | +
|
| 650 | + Args: |
| 651 | + analysis_config_file (str): File path to the local analysis config file. |
| 652 | + s3_output_path (str): S3 prefix to store the analysis config file. |
| 653 | + sagemaker_session (:class:`~sagemaker.session.Session`): |
| 654 | + Session object which manages interactions with Amazon SageMaker and |
| 655 | + any other AWS services needed. If not specified, the processor creates |
| 656 | + one using the default AWS configuration chain. |
| 657 | + kms_key (str): The ARN of the KMS key that is used to encrypt the |
| 658 | + user code file (default: None). |
| 659 | +
|
| 660 | + Returns: |
| 661 | + The S3 uri of the uploaded file. |
| 662 | + """ |
| 663 | + return s3.S3Uploader.upload( |
| 664 | + local_path=analysis_config_file, |
| 665 | + desired_s3_uri=s3_output_path, |
| 666 | + sagemaker_session=sagemaker_session, |
| 667 | + kms_key=kms_key, |
| 668 | + ) |
| 669 | + |
| 670 | + |
641 | 671 | def _set(value, key, dictionary):
|
642 | 672 | """Sets dictionary[key] = value if value is not None."""
|
643 | 673 | if value is not None:
|
|
0 commit comments