Skip to content

feature: added _AnalysisConfigGenerator for clarify #3271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 139 additions & 63 deletions src/sagemaker/clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import tempfile
from abc import ABC, abstractmethod
from typing import List, Union

from sagemaker import image_uris, s3, utils
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor

Expand Down Expand Up @@ -922,6 +924,7 @@ def __init__(
version (str): Clarify version to use.
""" # noqa E501 # pylint: disable=c0301
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
self._last_analysis_config = None
self.job_name_prefix = job_name_prefix
super(SageMakerClarifyProcessor, self).__init__(
role,
Expand Down Expand Up @@ -983,10 +986,10 @@ def _run(
the Trial Component will be unassociated.
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
"""
analysis_config["methods"]["report"] = {
"name": "report",
"title": "Analysis Report",
}
# for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
self._last_analysis_config = analysis_config
logger.info("Analysis Config: %s", analysis_config)

with tempfile.TemporaryDirectory() as tmpdirname:
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
with open(analysis_config_file, "w") as f:
Expand Down Expand Up @@ -1083,14 +1086,13 @@ def run_pre_training_bias(
the Trial Component will be unassociated.
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
""" # noqa E501 # pylint: disable=c0301
analysis_config = data_config.get_config()
analysis_config.update(data_bias_config.get_config())
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
if job_name is None:
if self.job_name_prefix:
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
analysis_config = _AnalysisConfigGenerator.bias_pre_training(
data_config, data_bias_config, methods
)
# when name is either not provided (is None) or an empty string ("")
job_name = job_name or utils.name_from_base(
self.job_name_prefix or "Clarify-Pretraining-Bias"
)
return self._run(
data_config,
analysis_config,
Expand Down Expand Up @@ -1165,21 +1167,13 @@ def run_post_training_bias(
the Trial Component will be unassociated.
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
""" # noqa E501 # pylint: disable=c0301
analysis_config = data_config.get_config()
analysis_config.update(data_bias_config.get_config())
(
probability_threshold,
predictor_config,
) = model_predicted_label_config.get_predictor_config()
predictor_config.update(model_config.get_predictor_config())
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
analysis_config["predictor"] = predictor_config
_set(probability_threshold, "probability_threshold", analysis_config)
if job_name is None:
if self.job_name_prefix:
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
analysis_config = _AnalysisConfigGenerator.bias_post_training(
data_config, data_bias_config, model_predicted_label_config, methods, model_config
)
# when name is either not provided (is None) or an empty string ("")
job_name = job_name or utils.name_from_base(
self.job_name_prefix or "Clarify-Posttraining-Bias"
)
return self._run(
data_config,
analysis_config,
Expand Down Expand Up @@ -1264,28 +1258,16 @@ def run_bias(
the Trial Component will be unassociated.
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
""" # noqa E501 # pylint: disable=c0301
analysis_config = data_config.get_config()
analysis_config.update(bias_config.get_config())
analysis_config["predictor"] = model_config.get_predictor_config()
if model_predicted_label_config:
(
probability_threshold,
predictor_config,
) = model_predicted_label_config.get_predictor_config()
if predictor_config:
analysis_config["predictor"].update(predictor_config)
if probability_threshold is not None:
analysis_config["probability_threshold"] = probability_threshold

analysis_config["methods"] = {
"pre_training_bias": {"methods": pre_training_methods},
"post_training_bias": {"methods": post_training_methods},
}
if job_name is None:
if self.job_name_prefix:
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Bias")
analysis_config = _AnalysisConfigGenerator.bias(
data_config,
bias_config,
model_config,
model_predicted_label_config,
pre_training_methods,
post_training_methods,
)
# when name is either not provided (is None) or an empty string ("")
job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias")
return self._run(
data_config,
analysis_config,
Expand Down Expand Up @@ -1370,6 +1352,36 @@ def run_explainability(
the Trial Component will be unassociated.
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
""" # noqa E501 # pylint: disable=c0301
analysis_config = _AnalysisConfigGenerator.explainability(
data_config, model_config, model_scores, explainability_config
)
# when name is either not provided (is None) or an empty string ("")
job_name = job_name or utils.name_from_base(
self.job_name_prefix or "Clarify-Explainability"
)
return self._run(
data_config,
analysis_config,
wait,
logs,
job_name,
kms_key,
experiment_config,
)


class _AnalysisConfigGenerator:
"""Creates analysis_config objects for different type of runs."""

@classmethod
def explainability(
cls,
data_config: DataConfig,
model_config: ModelConfig,
model_scores: ModelPredictedLabelConfig,
explainability_config: ExplainabilityConfig,
):
"""Generates a config for Explainability"""
analysis_config = data_config.get_config()
predictor_config = model_config.get_predictor_config()
if isinstance(model_scores, ModelPredictedLabelConfig):
Expand Down Expand Up @@ -1406,20 +1418,84 @@ def run_explainability(
explainability_methods = explainability_config.get_explainability_config()
analysis_config["methods"] = explainability_methods
analysis_config["predictor"] = predictor_config
if job_name is None:
if self.job_name_prefix:
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Explainability")
return self._run(
data_config,
analysis_config,
wait,
logs,
job_name,
kms_key,
experiment_config,
)
return cls._common(analysis_config)

@classmethod
def bias_pre_training(
cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
):
"""Generates a config for Bias Pre Training"""
analysis_config = {
**data_config.get_config(),
**bias_config.get_config(),
"methods": {"pre_training_bias": {"methods": methods}},
}
return cls._common(analysis_config)

@classmethod
def bias_post_training(
cls,
data_config: DataConfig,
bias_config: BiasConfig,
model_predicted_label_config: ModelPredictedLabelConfig,
methods: Union[str, List[str]],
model_config: ModelConfig,
):
"""Generates a config for Bias Post Training"""
analysis_config = {
**data_config.get_config(),
**bias_config.get_config(),
"predictor": {**model_config.get_predictor_config()},
"methods": {"post_training_bias": {"methods": methods}},
}
if model_predicted_label_config:
(
probability_threshold,
predictor_config,
) = model_predicted_label_config.get_predictor_config()
if predictor_config:
analysis_config["predictor"].update(predictor_config)
_set(probability_threshold, "probability_threshold", analysis_config)
return cls._common(analysis_config)

@classmethod
def bias(
cls,
data_config: DataConfig,
bias_config: BiasConfig,
model_config: ModelConfig,
model_predicted_label_config: ModelPredictedLabelConfig,
pre_training_methods: Union[str, List[str]] = "all",
post_training_methods: Union[str, List[str]] = "all",
):
"""Generates a config for Bias"""
analysis_config = {
**data_config.get_config(),
**bias_config.get_config(),
"predictor": model_config.get_predictor_config(),
"methods": {
"pre_training_bias": {"methods": pre_training_methods},
"post_training_bias": {"methods": post_training_methods},
},
}
if model_predicted_label_config:
(
probability_threshold,
predictor_config,
) = model_predicted_label_config.get_predictor_config()
if predictor_config:
analysis_config["predictor"].update(predictor_config)
_set(probability_threshold, "probability_threshold", analysis_config)
return cls._common(analysis_config)

@staticmethod
def _common(analysis_config):
"""Extends analysis config with common values"""
analysis_config["methods"]["report"] = {
"name": "report",
"title": "Analysis Report",
}
return analysis_config


def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
Expand Down
Loading