Skip to content

Commit 8a5fa72

Browse files
committed
change: s3_upload mode for CV jobs in clarify processing output
1 parent 9946d67 commit 8a5fa72

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/sagemaker/clarify.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import tempfile
2727
from abc import ABC, abstractmethod
2828
from typing import List, Union, Dict, Optional, Any
29-
29+
from enum import Enum
3030
from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex
3131

3232
from sagemaker import image_uris, s3, utils
@@ -302,6 +302,16 @@
302302
)
303303

304304

305+
class DatasetType(Enum):
306+
"""Enum to store different dataset types supported in the Analysis config file"""
307+
308+
TEXTCSV = "text/csv"
309+
JSONLINES = "application/jsonlines"
310+
JSON = "application/json"
311+
PARQUET = "application/x-parquet"
312+
IMAGE = "application/x-image"
313+
314+
305315
class DataConfig:
306316
"""Config object related to configurations of the input and output dataset."""
307317

@@ -1363,7 +1373,7 @@ def _run(
13631373
source=self._CLARIFY_OUTPUT,
13641374
destination=data_config.s3_output_path,
13651375
output_name="analysis_result",
1366-
s3_upload_mode="EndOfJob",
1376+
s3_upload_mode=ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
13671377
)
13681378

13691379
return super().run(
@@ -2083,6 +2093,30 @@ def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_sess
20832093
)
20842094

20852095

2096+
class ProcessingOutputHandler:
2097+
"""Handles the parameters sent in SagemakerProcessor.Processingoutput based on the dataset
2098+
type in analysis_config.
2099+
"""
2100+
2101+
class S3UploadMode(Enum):
2102+
"""Enum values for different uplaod modes to s3 bucket"""
2103+
2104+
CONTINUOUS = "Continuous"
2105+
ENDOFJOB = "EndOfJob"
2106+
2107+
@classmethod
2108+
def get_s3_upload_mode(cls, analysis_config: Dict[str, Any]) -> str:
2109+
"""
2110+
returns the s3_upload mode based on the shap_config values
2111+
"""
2112+
dataset_type = analysis_config["dataset_type"]
2113+
return (
2114+
ProcessingOutputHandler.S3UploadMode.CONTINUOUS.value
2115+
if dataset_type == DatasetType.IMAGE
2116+
else ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value
2117+
)
2118+
2119+
20862120
def _set(value, key, dictionary):
20872121
"""Sets dictionary[key] = value if value is not None."""
20882122
if value is not None:

tests/unit/test_clarify.py

+14
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
TextConfig,
3131
ImageConfig,
3232
_AnalysisConfigGenerator,
33+
DatasetType,
34+
ProcessingOutputHandler,
3335
)
3436

3537
JOB_NAME_PREFIX = "my-prefix"
@@ -1714,3 +1716,15 @@ def test_invalid_analysis_config(data_config, data_bias_config, model_config):
17141716
pre_training_methods="all",
17151717
post_training_methods="all",
17161718
)
1719+
1720+
1721+
class TestProcessingOutputHandler:
1722+
def test_get_s3_upload_mode_image(self):
1723+
analysis_config = {"dataset_type": DatasetType.IMAGE}
1724+
s3_upload_mode = ProcessingOutputHandler.get_s3_upload_mode(analysis_config)
1725+
assert s3_upload_mode == ProcessingOutputHandler.S3UploadMode.CONTINUOUS.value
1726+
1727+
def test_get_s3_upload_mode_text(self):
1728+
analysis_config = {"dataset_type": DatasetType.TEXTCSV}
1729+
s3_upload_mode = ProcessingOutputHandler.get_s3_upload_mode(analysis_config)
1730+
assert s3_upload_mode == ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value

0 commit comments

Comments
 (0)