Skip to content

Commit ee04c97

Browse files
aws-byeldosnavinsoni
authored andcommitted
added data types in _AnalysisConfigGenerator methods
1 parent e1fb7a2 commit ee04c97

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

src/sagemaker/clarify.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import tempfile
2727
from abc import ABC, abstractmethod
28+
from typing import List, Union
29+
2830
from sagemaker import image_uris, s3, utils
2931
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3032

@@ -1034,7 +1036,7 @@ def _run(
10341036
def run_pre_training_bias(
10351037
self,
10361038
data_config,
1037-
bias_config,
1039+
data_bias_config,
10381040
methods="all",
10391041
wait=True,
10401042
logs=True,
@@ -1049,7 +1051,7 @@ def run_pre_training_bias(
10491051
10501052
Args:
10511053
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1052-
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1054+
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
10531055
methods (str or list[str]): Selects a subset of potential metrics:
10541056
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
10551057
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1085,7 +1087,7 @@ def run_pre_training_bias(
10851087
""" # noqa E501 # pylint: disable=c0301
10861088
analysis_config = _AnalysisConfigGenerator.bias_pre_training(
10871089
data_config,
1088-
bias_config,
1090+
data_bias_config,
10891091
methods
10901092
)
10911093
# when name is either not provided (is None) or an empty string ("")
@@ -1103,7 +1105,7 @@ def run_pre_training_bias(
11031105
def run_post_training_bias(
11041106
self,
11051107
data_config,
1106-
bias_config,
1108+
data_bias_config,
11071109
model_config,
11081110
model_predicted_label_config,
11091111
methods="all",
@@ -1123,7 +1125,7 @@ def run_post_training_bias(
11231125
11241126
Args:
11251127
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1126-
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1128+
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
11271129
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
11281130
endpoint to be created.
11291131
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1166,7 +1168,7 @@ def run_post_training_bias(
11661168
""" # noqa E501 # pylint: disable=c0301
11671169
analysis_config = _AnalysisConfigGenerator.bias_post_training(
11681170
data_config,
1169-
bias_config,
1171+
data_bias_config,
11701172
model_predicted_label_config,
11711173
methods,
11721174
model_config
@@ -1377,10 +1379,10 @@ class _AnalysisConfigGenerator:
13771379
@classmethod
13781380
def explainability(
13791381
cls,
1380-
data_config,
1381-
model_config,
1382-
model_scores,
1383-
explainability_config
1382+
data_config: DataConfig,
1383+
model_config: ModelConfig,
1384+
model_scores: ModelPredictedLabelConfig,
1385+
explainability_config: ExplainabilityConfig,
13841386
):
13851387
analysis_config = data_config.get_config()
13861388
predictor_config = model_config.get_predictor_config()
@@ -1421,7 +1423,7 @@ def explainability(
14211423
return cls._common(analysis_config)
14221424

14231425
@classmethod
1424-
def bias_pre_training(cls, data_config, bias_config, methods):
1426+
def bias_pre_training(cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]):
14251427
analysis_config = {
14261428
**data_config.get_config(),
14271429
**bias_config.get_config(),
@@ -1432,11 +1434,11 @@ def bias_pre_training(cls, data_config, bias_config, methods):
14321434
@classmethod
14331435
def bias_post_training(
14341436
cls,
1435-
data_config,
1436-
bias_config,
1437-
model_predicted_label_config,
1438-
methods,
1439-
model_config
1437+
data_config: DataConfig,
1438+
bias_config: BiasConfig,
1439+
model_predicted_label_config: ModelPredictedLabelConfig,
1440+
methods: Union[str, List[str]],
1441+
model_config: ModelConfig,
14401442
):
14411443
analysis_config = {
14421444
**data_config.get_config(),
@@ -1454,12 +1456,12 @@ def bias_post_training(
14541456
@classmethod
14551457
def bias(
14561458
cls,
1457-
data_config,
1458-
bias_config,
1459-
model_config,
1460-
model_predicted_label_config,
1461-
pre_training_methods="all",
1462-
post_training_methods="all",
1459+
data_config: DataConfig,
1460+
bias_config: BiasConfig,
1461+
model_config: ModelConfig,
1462+
model_predicted_label_config: ModelPredictedLabelConfig,
1463+
pre_training_methods: Union[str, List[str]] = "all",
1464+
post_training_methods: Union[str, List[str]] = "all",
14631465
):
14641466
analysis_config = {
14651467
**data_config.get_config(),

tests/unit/test_clarify.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,18 +1382,18 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
13821382
)
13831383
expected = {'dataset_type': 'text/csv',
13841384
'facet': [{'name_or_index': 'F1'}],
1385-
'group_variable': 'F2',
1386-
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
1387-
'joinsource_name_or_index': 'F4',
1388-
'label': 'Label',
1389-
'label_values_or_threshold': [1],
1390-
'methods': {
1391-
'report': {'name': 'report', 'title': 'Analysis Report'},
1392-
'post_training_bias': {'methods': 'all'},
1393-
'pre_training_bias': {'methods': 'all'}},
1394-
'predictor': {'initial_instance_count': 1,
1395-
'instance_type': 'ml.c5.xlarge',
1396-
'label_headers': ['success'],
1397-
'model_name': 'xgboost-model',
1398-
'probability': 'pr'}}
1385+
'group_variable': 'F2',
1386+
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
1387+
'joinsource_name_or_index': 'F4',
1388+
'label': 'Label',
1389+
'label_values_or_threshold': [1],
1390+
'methods': {
1391+
'report': {'name': 'report', 'title': 'Analysis Report'},
1392+
'post_training_bias': {'methods': 'all'},
1393+
'pre_training_bias': {'methods': 'all'}},
1394+
'predictor': {'initial_instance_count': 1,
1395+
'instance_type': 'ml.c5.xlarge',
1396+
'label_headers': ['success'],
1397+
'model_name': 'xgboost-model',
1398+
'probability': 'pr'}}
13991399
assert actual == expected

0 commit comments

Comments
 (0)