diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 706e82aa8a..2f7f3dc53d 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -26,7 +26,7 @@
 import tempfile
 from abc import ABC, abstractmethod
 from typing import List, Union, Dict, Optional, Any
-
+from enum import Enum
 from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex
 
 from sagemaker import image_uris, s3, utils
@@ -304,6 +304,16 @@
 )
 
 
+class DatasetType(Enum):
+    """Enum to store the dataset types supported in the analysis config file."""
+
+    TEXTCSV = "text/csv"
+    JSONLINES = "application/jsonlines"
+    JSON = "application/json"
+    PARQUET = "application/x-parquet"
+    IMAGE = "application/x-image"
+
+
 class DataConfig:
     """Config object related to configurations of the input and output dataset."""
 
@@ -1451,7 +1461,7 @@ def _run(
             source=self._CLARIFY_OUTPUT,
             destination=data_config.s3_output_path,
             output_name="analysis_result",
-            s3_upload_mode="EndOfJob",
+            s3_upload_mode=ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
         )
 
         return super().run(
@@ -2171,6 +2181,33 @@ def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_sess
     )
 
 
+class ProcessingOutputHandler:
+    """Class to handle the parameters for ``sagemaker.processing.ProcessingOutput``."""
+
+    class S3UploadMode(Enum):
+        """Enum values for the different upload modes to the s3 bucket."""
+
+        CONTINUOUS = "Continuous"
+        ENDOFJOB = "EndOfJob"
+
+    @classmethod
+    def get_s3_upload_mode(cls, analysis_config: Dict[str, Any]) -> str:
+        """Fetches the s3_upload_mode based on the dataset_type in the analysis config.
+
+        Args:
+            analysis_config (dict): Config following the analysis_config.json format.
+
+        Returns:
+            The s3_upload_mode type for the processing output.
+        """
+        dataset_type = analysis_config["dataset_type"]
+        return (
+            ProcessingOutputHandler.S3UploadMode.CONTINUOUS.value
+            if dataset_type == DatasetType.IMAGE.value
+            else ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value
+        )
+
+
 def _set(value, key, dictionary):
     """Sets dictionary[key] = value if value is not None."""
     if value is not None:
diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py
index 22b6fc2051..32793de977 100644
--- a/src/sagemaker/workflow/clarify_check_step.py
+++ b/src/sagemaker/workflow/clarify_check_step.py
@@ -29,6 +29,7 @@
     ModelConfig,
     ModelPredictedLabelConfig,
     SHAPConfig,
+    ProcessingOutputHandler,
     _upload_analysis_config,
     SageMakerClarifyProcessor,
     _set,
@@ -391,7 +392,7 @@ def _generate_processing_job_parameters(
             source=SageMakerClarifyProcessor._CLARIFY_OUTPUT,
             destination=data_config.s3_output_path,
             output_name="analysis_result",
-            s3_upload_mode="EndOfJob",
+            s3_upload_mode=ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
         )
 
         return dict(config_input=config_input, data_input=data_input, result_output=result_output)
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index 0c80cfe004..714f9d316c 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -30,6 +30,8 @@
     TextConfig,
     ImageConfig,
     _AnalysisConfigGenerator,
+    DatasetType,
+    ProcessingOutputHandler,
 )
 
 JOB_NAME_PREFIX = "my-prefix"
@@ -1786,3 +1788,15 @@ def test_invalid_analysis_config(data_config, data_bias_config, model_config):
         pre_training_methods="all",
         post_training_methods="all",
     )
+
+
+class TestProcessingOutputHandler:
+    def test_get_s3_upload_mode_image(self):
+        analysis_config = {"dataset_type": DatasetType.IMAGE.value}
+        s3_upload_mode = ProcessingOutputHandler.get_s3_upload_mode(analysis_config)
+        assert s3_upload_mode == ProcessingOutputHandler.S3UploadMode.CONTINUOUS.value
+
+    def test_get_s3_upload_mode_text(self):
+        analysis_config = {"dataset_type": DatasetType.TEXTCSV.value}
+        s3_upload_mode = ProcessingOutputHandler.get_s3_upload_mode(analysis_config)
+        assert s3_upload_mode == ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value
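
A minimal usage sketch of the behavior this change introduces (standalone, not part of the diff; it assumes a `sagemaker.clarify` built from this branch is importable): `get_s3_upload_mode` returns `"Continuous"` only for image datasets and keeps `"EndOfJob"` for every other dataset type.

```python
from sagemaker.clarify import DatasetType, ProcessingOutputHandler

# Image datasets get the "Continuous" upload mode; all other dataset types
# keep the previous "EndOfJob" behavior.
image_config = {"dataset_type": DatasetType.IMAGE.value}  # "application/x-image"
csv_config = {"dataset_type": DatasetType.TEXTCSV.value}  # "text/csv"

assert ProcessingOutputHandler.get_s3_upload_mode(image_config) == "Continuous"
assert ProcessingOutputHandler.get_s3_upload_mode(csv_config) == "EndOfJob"
```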