Skip to content

Commit 1a51f96

Browse files
committed
added data types in _AnalysisConfigGenerator methods
1 parent e3ff867 commit 1a51f96

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

src/sagemaker/clarify.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import tempfile
2727
from abc import ABC, abstractmethod
28+
from typing import List, Union
29+
2830
from sagemaker import image_uris, s3, utils
2931
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3032

@@ -971,7 +973,7 @@ def _run(
971973
def run_pre_training_bias(
972974
self,
973975
data_config,
974-
bias_config,
976+
data_bias_config,
975977
methods="all",
976978
wait=True,
977979
logs=True,
@@ -986,7 +988,7 @@ def run_pre_training_bias(
986988
987989
Args:
988990
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
989-
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
991+
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
990992
methods (str or list[str]): Selects a subset of potential metrics:
991993
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
992994
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1022,7 +1024,7 @@ def run_pre_training_bias(
10221024
""" # noqa E501 # pylint: disable=c0301
10231025
analysis_config = _AnalysisConfigGenerator.bias_pre_training(
10241026
data_config,
1025-
bias_config,
1027+
data_bias_config,
10261028
methods
10271029
)
10281030
# when name is either not provided (is None) or an empty string ("")
@@ -1040,7 +1042,7 @@ def run_pre_training_bias(
10401042
def run_post_training_bias(
10411043
self,
10421044
data_config,
1043-
bias_config,
1045+
data_bias_config,
10441046
model_config,
10451047
model_predicted_label_config,
10461048
methods="all",
@@ -1060,7 +1062,7 @@ def run_post_training_bias(
10601062
10611063
Args:
10621064
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1063-
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1065+
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
10641066
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
10651067
endpoint to be created.
10661068
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1103,7 +1105,7 @@ def run_post_training_bias(
11031105
""" # noqa E501 # pylint: disable=c0301
11041106
analysis_config = _AnalysisConfigGenerator.bias_post_training(
11051107
data_config,
1106-
bias_config,
1108+
data_bias_config,
11071109
model_predicted_label_config,
11081110
methods,
11091111
model_config
@@ -1314,10 +1316,10 @@ class _AnalysisConfigGenerator:
13141316
@classmethod
13151317
def explainability(
13161318
cls,
1317-
data_config,
1318-
model_config,
1319-
model_scores,
1320-
explainability_config
1319+
data_config: DataConfig,
1320+
model_config: ModelConfig,
1321+
model_scores: ModelPredictedLabelConfig,
1322+
explainability_config: ExplainabilityConfig,
13211323
):
13221324
analysis_config = data_config.get_config()
13231325
predictor_config = model_config.get_predictor_config()
@@ -1358,7 +1360,7 @@ def explainability(
13581360
return cls._common(analysis_config)
13591361

13601362
@classmethod
1361-
def bias_pre_training(cls, data_config, bias_config, methods):
1363+
def bias_pre_training(cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]):
13621364
analysis_config = {
13631365
**data_config.get_config(),
13641366
**bias_config.get_config(),
@@ -1369,11 +1371,11 @@ def bias_pre_training(cls, data_config, bias_config, methods):
13691371
@classmethod
13701372
def bias_post_training(
13711373
cls,
1372-
data_config,
1373-
bias_config,
1374-
model_predicted_label_config,
1375-
methods,
1376-
model_config
1374+
data_config: DataConfig,
1375+
bias_config: BiasConfig,
1376+
model_predicted_label_config: ModelPredictedLabelConfig,
1377+
methods: Union[str, List[str]],
1378+
model_config: ModelConfig,
13771379
):
13781380
analysis_config = {
13791381
**data_config.get_config(),
@@ -1391,12 +1393,12 @@ def bias_post_training(
13911393
@classmethod
13921394
def bias(
13931395
cls,
1394-
data_config,
1395-
bias_config,
1396-
model_config,
1397-
model_predicted_label_config,
1398-
pre_training_methods="all",
1399-
post_training_methods="all",
1396+
data_config: DataConfig,
1397+
bias_config: BiasConfig,
1398+
model_config: ModelConfig,
1399+
model_predicted_label_config: ModelPredictedLabelConfig,
1400+
pre_training_methods: Union[str, List[str]] = "all",
1401+
post_training_methods: Union[str, List[str]] = "all",
14001402
):
14011403
analysis_config = {
14021404
**data_config.get_config(),

tests/unit/test_clarify.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,18 +1382,18 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
13821382
)
13831383
expected = {'dataset_type': 'text/csv',
13841384
'facet': [{'name_or_index': 'F1'}],
1385-
'group_variable': 'F2',
1386-
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
1387-
'joinsource_name_or_index': 'F4',
1388-
'label': 'Label',
1389-
'label_values_or_threshold': [1],
1390-
'methods': {
1391-
'report': {'name': 'report', 'title': 'Analysis Report'},
1392-
'post_training_bias': {'methods': 'all'},
1393-
'pre_training_bias': {'methods': 'all'}},
1394-
'predictor': {'initial_instance_count': 1,
1395-
'instance_type': 'ml.c5.xlarge',
1396-
'label_headers': ['success'],
1397-
'model_name': 'xgboost-model',
1398-
'probability': 'pr'}}
1385+
'group_variable': 'F2',
1386+
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
1387+
'joinsource_name_or_index': 'F4',
1388+
'label': 'Label',
1389+
'label_values_or_threshold': [1],
1390+
'methods': {
1391+
'report': {'name': 'report', 'title': 'Analysis Report'},
1392+
'post_training_bias': {'methods': 'all'},
1393+
'pre_training_bias': {'methods': 'all'}},
1394+
'predictor': {'initial_instance_count': 1,
1395+
'instance_type': 'ml.c5.xlarge',
1396+
'label_headers': ['success'],
1397+
'model_name': 'xgboost-model',
1398+
'probability': 'pr'}}
13991399
assert actual == expected

0 commit comments

Comments
 (0)