Skip to content

feat: added ANALYSIS_CONFIG_SCHEMA_V1_0 in clarify #3325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def read_requirements(filename):
"packaging>=20.0",
"pandas",
"pathos",
"schema",
]

# Specific use case dependencies
Expand Down
272 changes: 272 additions & 0 deletions src/sagemaker/clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from abc import ABC, abstractmethod
from typing import List, Union, Dict, Optional, Any

from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex

from sagemaker import image_uris, s3, utils
from sagemaker.session import Session
from sagemaker.network import NetworkConfig
Expand All @@ -35,6 +37,271 @@
logger = logging.getLogger(__name__)


# Pattern for a valid SageMaker endpoint name (prefix): starts with an
# alphanumeric character, followed by alphanumerics optionally separated by
# runs of hyphens (matching the AWS API's documented EndpointName pattern
# ^[a-zA-Z0-9](-*[a-zA-Z0-9])*). The trailing "*" allows a one-character
# prefix such as "a"; without it, single-character prefixes — which the
# service accepts — were wrongly rejected by the schema's Regex check.
ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])*"

# Version 1.0 of the analysis-config schema. Used to validate the generated
# analysis config dict before it is written/uploaded by
# SageMakerClarifyProcessor._run (skipped when skip_early_validation is set).
# NOTE: `And(str, Use(str.lower), lambda s: s in (...))` entries normalize the
# value to lower case *before* the membership check, so matching is
# case-insensitive for those fields.
ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
    {
        SchemaOptional("version"): str,
        # MIME type of the main dataset; the only required top-level field
        # besides "methods".
        "dataset_type": And(
            str,
            Use(str.lower),
            lambda s: s
            in (
                "text/csv",
                "application/jsonlines",
                "application/sagemakercapturejson",
                "application/x-parquet",
                "application/x-image",
            ),
        ),
        SchemaOptional("dataset_uri"): str,
        SchemaOptional("headers"): [str],
        # Columns may be referenced either by name (str) or by index (int)
        # throughout this schema.
        SchemaOptional("label"): Or(str, int),
        # this field indicates user provides predicted_label in dataset
        SchemaOptional("predicted_label"): Or(str, int),
        SchemaOptional("features"): str,
        SchemaOptional("label_values_or_threshold"): [Or(int, float, str)],
        SchemaOptional("probability_threshold"): float,
        # Facets (attributes to analyze for bias), each named by column
        # name/index with optional sensitive values or thresholds.
        SchemaOptional("facet"): [
            {
                "name_or_index": Or(str, int),
                SchemaOptional("value_or_threshold"): [Or(int, float, str)],
            }
        ],
        SchemaOptional("facet_dataset_uri"): str,
        SchemaOptional("facet_headers"): [str],
        SchemaOptional("predicted_label_dataset_uri"): str,
        SchemaOptional("predicted_label_headers"): [str],
        SchemaOptional("excluded_columns"): [Or(int, str)],
        SchemaOptional("joinsource_name_or_index"): Or(str, int),
        SchemaOptional("group_variable"): Or(str, int),
        # The "methods" key itself is required, but every individual analysis
        # method inside it is optional.
        "methods": {
            SchemaOptional("shap"): {
                SchemaOptional("baseline"): Or(
                    # URI of the baseline data file
                    str,
                    # Inplace baseline data (a list of something)
                    [
                        Or(
                            # CSV row
                            [Or(int, float, str, None)],
                            # JSON row (any JSON object). As I write this only
                            # SageMaker JSONLines Dense Format ([1])
                            # is supported and the validation is NOT done
                            # by the schema but by the data loader.
                            # [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
                            {object: object},
                        )
                    ],
                ),
                SchemaOptional("num_clusters"): int,
                SchemaOptional("use_logit"): bool,
                SchemaOptional("num_samples"): int,
                SchemaOptional("agg_method"): And(
                    str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")
                ),
                SchemaOptional("save_local_shap_values"): bool,
                # Text explainability settings; "granularity" and "language"
                # are required once "text_config" is present.
                SchemaOptional("text_config"): {
                    "granularity": And(
                        str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")
                    ),
                    # Accepted full language names interleaved with their
                    # short codes (e.g. "english"/"en"); either form passes.
                    "language": And(
                        str,
                        Use(str.lower),
                        lambda s: s
                        in (
                            "chinese",
                            "zh",
                            "danish",
                            "da",
                            "dutch",
                            "nl",
                            "english",
                            "en",
                            "french",
                            "fr",
                            "german",
                            "de",
                            "greek",
                            "el",
                            "italian",
                            "it",
                            "japanese",
                            "ja",
                            "lithuanian",
                            "lt",
                            "multi-language",
                            "xx",
                            "norwegian bokmål",
                            "nb",
                            "polish",
                            "pl",
                            "portuguese",
                            "pt",
                            "romanian",
                            "ro",
                            "russian",
                            "ru",
                            "spanish",
                            "es",
                            "afrikaans",
                            "af",
                            "albanian",
                            "sq",
                            "arabic",
                            "ar",
                            "armenian",
                            "hy",
                            "basque",
                            "eu",
                            "bengali",
                            "bn",
                            "bulgarian",
                            "bg",
                            "catalan",
                            "ca",
                            "croatian",
                            "hr",
                            "czech",
                            "cs",
                            "estonian",
                            "et",
                            "finnish",
                            "fi",
                            "gujarati",
                            "gu",
                            "hebrew",
                            "he",
                            "hindi",
                            "hi",
                            "hungarian",
                            "hu",
                            "icelandic",
                            "is",
                            "indonesian",
                            "id",
                            "irish",
                            "ga",
                            "kannada",
                            "kn",
                            "kyrgyz",
                            "ky",
                            "latvian",
                            "lv",
                            "ligurian",
                            "lij",
                            "luxembourgish",
                            "lb",
                            "macedonian",
                            "mk",
                            "malayalam",
                            "ml",
                            "marathi",
                            "mr",
                            "nepali",
                            "ne",
                            "persian",
                            "fa",
                            "sanskrit",
                            "sa",
                            "serbian",
                            "sr",
                            "setswana",
                            "tn",
                            "sinhala",
                            "si",
                            "slovak",
                            "sk",
                            "slovenian",
                            "sl",
                            "swedish",
                            "sv",
                            "tagalog",
                            "tl",
                            "tamil",
                            "ta",
                            "tatar",
                            "tt",
                            "telugu",
                            "te",
                            "thai",
                            "th",
                            "turkish",
                            "tr",
                            "ukrainian",
                            "uk",
                            "urdu",
                            "ur",
                            "vietnamese",
                            "vi",
                            "yoruba",
                            "yo",
                        ),
                    ),
                    SchemaOptional("max_top_tokens"): int,
                },
                # Image explainability settings; all fields optional.
                SchemaOptional("image_config"): {
                    SchemaOptional("num_segments"): int,
                    SchemaOptional("segment_compactness"): int,
                    SchemaOptional("feature_extraction_method"): str,
                    SchemaOptional("model_type"): str,
                    SchemaOptional("max_objects"): int,
                    SchemaOptional("iou_threshold"): float,
                    SchemaOptional("context"): float,
                    SchemaOptional("debug"): {
                        SchemaOptional("image_names"): [str],
                        SchemaOptional("class_ids"): [int],
                        SchemaOptional("sample_from"): int,
                        SchemaOptional("sample_to"): int,
                    },
                },
                SchemaOptional("seed"): int,
            },
            # Bias methods accept a single method name or a list of names;
            # the "methods" sub-key is required when the section is present.
            SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])},
            SchemaOptional("post_training_bias"): {"methods": Or(str, [str])},
            # Partial dependence plots: "grid_resolution" is required when
            # the "pdp" section is present.
            SchemaOptional("pdp"): {
                "grid_resolution": int,
                SchemaOptional("features"): [Or(str, int)],
                SchemaOptional("top_k_features"): int,
            },
            SchemaOptional("report"): {"name": str, SchemaOptional("title"): str},
        },
        # Settings for the (shadow) predictor endpoint used to get inferences.
        SchemaOptional("predictor"): {
            SchemaOptional("endpoint_name"): str,
            # Prefix must satisfy the endpoint-name regex defined above.
            SchemaOptional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)),
            SchemaOptional("model_name"): str,
            SchemaOptional("target_model"): str,
            SchemaOptional("instance_type"): str,
            SchemaOptional("initial_instance_count"): int,
            SchemaOptional("accelerator_type"): str,
            # MIME type of the payload sent to the model endpoint.
            SchemaOptional("content_type"): And(
                str,
                Use(str.lower),
                lambda s: s
                in (
                    "text/csv",
                    "application/jsonlines",
                    "image/jpeg",
                    "image/jpg",
                    "image/png",
                    "application/x-npy",
                ),
            ),
            # MIME type expected back from the model endpoint.
            SchemaOptional("accept_type"): And(
                str,
                Use(str.lower),
                lambda s: s in ("text/csv", "application/jsonlines", "application/json"),
            ),
            SchemaOptional("label"): Or(str, int),
            SchemaOptional("probability"): Or(str, int),
            SchemaOptional("label_headers"): [Or(str, int)],
            SchemaOptional("content_template"): Or(str, {str: str}),
            SchemaOptional("custom_attributes"): str,
        },
    }
)


class DataConfig:
"""Config object related to configurations of the input and output dataset."""

Expand Down Expand Up @@ -926,6 +1193,7 @@ def __init__(
network_config: Optional[NetworkConfig] = None,
job_name_prefix: Optional[str] = None,
version: Optional[str] = None,
skip_early_validation: bool = False,
):
"""Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.

Expand Down Expand Up @@ -967,10 +1235,12 @@ def __init__(
inter-container traffic, security group IDs, and subnets.
job_name_prefix (str): Processing job name prefix.
version (str): Clarify version to use.
skip_early_validation (bool): To skip schema validation of the generated analysis_schema.json.
""" # noqa E501 # pylint: disable=c0301
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
self._last_analysis_config = None
self.job_name_prefix = job_name_prefix
self.skip_early_validation = skip_early_validation
super(SageMakerClarifyProcessor, self).__init__(
role,
container_uri,
Expand Down Expand Up @@ -1034,6 +1304,8 @@ def _run(
# for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
self._last_analysis_config = analysis_config
logger.info("Analysis Config: %s", analysis_config)
if not self.skip_early_validation:
ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)

with tempfile.TemporaryDirectory() as tmpdirname:
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,7 @@ def test_run_on_s3_analysis_config_file(
processor_run, sagemaker_session, clarify_processor, data_config
):
analysis_config = {
"dataset_type": "text/csv",
"methods": {"post_training_bias": {"methods": "all"}},
}
with patch("sagemaker.clarify._upload_analysis_config", return_value=None) as mock_method:
Expand Down