Skip to content

Commit b334d4e

Browse files
authored
change: change s3UploadMode of sagemaker clarify processing output for computer vision jobs. (#3754)
1 parent 409546d commit b334d4e

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

src/sagemaker/clarify.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import tempfile
2727
from abc import ABC, abstractmethod
2828
from typing import List, Union, Dict, Optional, Any
29-
29+
from enum import Enum
3030
from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex
3131

3232
from sagemaker import image_uris, s3, utils
@@ -304,6 +304,16 @@
304304
)
305305

306306

307+
class DatasetType(Enum):
    """Enum of the dataset types supported in the analysis config file.

    Each member's value is the content-type string that is compared against
    the ``dataset_type`` entry of an analysis config.
    """

    TEXTCSV = "text/csv"
    JSONLINES = "application/jsonlines"
    JSON = "application/json"
    PARQUET = "application/x-parquet"
    IMAGE = "application/x-image"
315+
316+
307317
class DataConfig:
308318
"""Config object related to configurations of the input and output dataset."""
309319

@@ -1451,7 +1461,7 @@ def _run(
14511461
source=self._CLARIFY_OUTPUT,
14521462
destination=data_config.s3_output_path,
14531463
output_name="analysis_result",
1454-
s3_upload_mode="EndOfJob",
1464+
s3_upload_mode=ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
14551465
)
14561466

14571467
return super().run(
@@ -2171,6 +2181,33 @@ def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_sess
21712181
)
21722182

21732183

2184+
class ProcessingOutputHandler:
    """Class to handle the parameters for the SageMaker ``ProcessingOutput``."""

    class S3UploadMode(Enum):
        """Enum values for the different upload modes to the S3 bucket."""

        CONTINUOUS = "Continuous"
        ENDOFJOB = "EndOfJob"

    @classmethod
    def get_s3_upload_mode(cls, analysis_config: Dict[str, Any]) -> str:
        """Fetch the s3_upload_mode based on the ``dataset_type`` of the analysis config.

        Args:
            analysis_config (dict): Config following the analysis_config.json format.

        Returns:
            str: ``"Continuous"`` when ``dataset_type`` equals the image dataset
            type, ``"EndOfJob"`` for every other (or a missing) dataset type.
        """
        # Use .get so a config without "dataset_type" falls back to EndOfJob
        # (the value that used to be hard-coded) instead of raising KeyError.
        dataset_type = analysis_config.get("dataset_type")
        if dataset_type == DatasetType.IMAGE.value:
            return cls.S3UploadMode.CONTINUOUS.value
        return cls.S3UploadMode.ENDOFJOB.value
2209+
2210+
21742211
def _set(value, key, dictionary):
21752212
"""Sets dictionary[key] = value if value is not None."""
21762213
if value is not None:

src/sagemaker/workflow/clarify_check_step.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ModelConfig,
3030
ModelPredictedLabelConfig,
3131
SHAPConfig,
32+
ProcessingOutputHandler,
3233
_upload_analysis_config,
3334
SageMakerClarifyProcessor,
3435
_set,
@@ -391,7 +392,7 @@ def _generate_processing_job_parameters(
391392
source=SageMakerClarifyProcessor._CLARIFY_OUTPUT,
392393
destination=data_config.s3_output_path,
393394
output_name="analysis_result",
394-
s3_upload_mode="EndOfJob",
395+
s3_upload_mode=ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
395396
)
396397
return dict(config_input=config_input, data_input=data_input, result_output=result_output)
397398

tests/unit/test_clarify.py

+14
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
TextConfig,
3131
ImageConfig,
3232
_AnalysisConfigGenerator,
33+
DatasetType,
34+
ProcessingOutputHandler,
3335
)
3436

3537
JOB_NAME_PREFIX = "my-prefix"
@@ -1786,3 +1788,15 @@ def test_invalid_analysis_config(data_config, data_bias_config, model_config):
17861788
pre_training_methods="all",
17871789
post_training_methods="all",
17881790
)
1791+
1792+
1793+
class TestProcessingOutputHandler:
    """Unit tests for ``ProcessingOutputHandler.get_s3_upload_mode``."""

    def test_get_s3_upload_mode_image(self):
        """The image dataset type maps to the Continuous upload mode."""
        analysis_config = {"dataset_type": DatasetType.IMAGE.value}
        s3_upload_mode = ProcessingOutputHandler.get_s3_upload_mode(analysis_config)
        assert s3_upload_mode == ProcessingOutputHandler.S3UploadMode.CONTINUOUS.value

    def test_get_s3_upload_mode_text(self):
        """The text/csv dataset type maps to the EndOfJob upload mode."""
        analysis_config = {"dataset_type": DatasetType.TEXTCSV.value}
        s3_upload_mode = ProcessingOutputHandler.get_s3_upload_mode(analysis_config)
        assert s3_upload_mode == ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value

    def test_get_s3_upload_mode_non_image(self):
        """Every dataset type other than image maps to the EndOfJob upload mode."""
        for dataset_type in DatasetType:
            if dataset_type is DatasetType.IMAGE:
                continue
            analysis_config = {"dataset_type": dataset_type.value}
            s3_upload_mode = ProcessingOutputHandler.get_s3_upload_mode(analysis_config)
            assert s3_upload_mode == ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value

0 commit comments

Comments
 (0)