25
25
26
26
import tempfile
27
27
from abc import ABC , abstractmethod
28
+ from typing import List , Union
29
+
28
30
from sagemaker import image_uris , s3 , utils
29
31
from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
30
32
@@ -1034,7 +1036,7 @@ def _run(
1034
1036
def run_pre_training_bias (
1035
1037
self ,
1036
1038
data_config ,
1037
- bias_config ,
1039
+ data_bias_config ,
1038
1040
methods = "all" ,
1039
1041
wait = True ,
1040
1042
logs = True ,
@@ -1049,7 +1051,7 @@ def run_pre_training_bias(
1049
1051
1050
1052
Args:
1051
1053
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1052
- bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1054
+ data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1053
1055
methods (str or list[str]): Selects a subset of potential metrics:
1054
1056
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
1055
1057
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1085,7 +1087,7 @@ def run_pre_training_bias(
1085
1087
""" # noqa E501 # pylint: disable=c0301
1086
1088
analysis_config = _AnalysisConfigGenerator .bias_pre_training (
1087
1089
data_config ,
1088
- bias_config ,
1090
+ data_bias_config ,
1089
1091
methods
1090
1092
)
1091
1093
# when name is either not provided (is None) or an empty string ("")
@@ -1103,7 +1105,7 @@ def run_pre_training_bias(
1103
1105
def run_post_training_bias (
1104
1106
self ,
1105
1107
data_config ,
1106
- bias_config ,
1108
+ data_bias_config ,
1107
1109
model_config ,
1108
1110
model_predicted_label_config ,
1109
1111
methods = "all" ,
@@ -1123,7 +1125,7 @@ def run_post_training_bias(
1123
1125
1124
1126
Args:
1125
1127
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1126
- bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1128
+ data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1127
1129
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1128
1130
endpoint to be created.
1129
1131
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1166,7 +1168,7 @@ def run_post_training_bias(
1166
1168
""" # noqa E501 # pylint: disable=c0301
1167
1169
analysis_config = _AnalysisConfigGenerator .bias_post_training (
1168
1170
data_config ,
1169
- bias_config ,
1171
+ data_bias_config ,
1170
1172
model_predicted_label_config ,
1171
1173
methods ,
1172
1174
model_config
@@ -1377,10 +1379,10 @@ class _AnalysisConfigGenerator:
1377
1379
@classmethod
1378
1380
def explainability (
1379
1381
cls ,
1380
- data_config ,
1381
- model_config ,
1382
- model_scores ,
1383
- explainability_config
1382
+ data_config : DataConfig ,
1383
+ model_config : ModelConfig ,
1384
+ model_scores : ModelPredictedLabelConfig ,
1385
+ explainability_config : ExplainabilityConfig ,
1384
1386
):
1385
1387
analysis_config = data_config .get_config ()
1386
1388
predictor_config = model_config .get_predictor_config ()
@@ -1421,7 +1423,7 @@ def explainability(
1421
1423
return cls ._common (analysis_config )
1422
1424
1423
1425
@classmethod
1424
- def bias_pre_training (cls , data_config , bias_config , methods ):
1426
+ def bias_pre_training (cls , data_config : DataConfig , bias_config : BiasConfig , methods : Union [ str , List [ str ]] ):
1425
1427
analysis_config = {
1426
1428
** data_config .get_config (),
1427
1429
** bias_config .get_config (),
@@ -1432,11 +1434,11 @@ def bias_pre_training(cls, data_config, bias_config, methods):
1432
1434
@classmethod
1433
1435
def bias_post_training (
1434
1436
cls ,
1435
- data_config ,
1436
- bias_config ,
1437
- model_predicted_label_config ,
1438
- methods ,
1439
- model_config
1437
+ data_config : DataConfig ,
1438
+ bias_config : BiasConfig ,
1439
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1440
+ methods : Union [ str , List [ str ]] ,
1441
+ model_config : ModelConfig ,
1440
1442
):
1441
1443
analysis_config = {
1442
1444
** data_config .get_config (),
@@ -1454,12 +1456,12 @@ def bias_post_training(
1454
1456
@classmethod
1455
1457
def bias (
1456
1458
cls ,
1457
- data_config ,
1458
- bias_config ,
1459
- model_config ,
1460
- model_predicted_label_config ,
1461
- pre_training_methods = "all" ,
1462
- post_training_methods = "all" ,
1459
+ data_config : DataConfig ,
1460
+ bias_config : BiasConfig ,
1461
+ model_config : ModelConfig ,
1462
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1463
+ pre_training_methods : Union [ str , List [ str ]] = "all" ,
1464
+ post_training_methods : Union [ str , List [ str ]] = "all" ,
1463
1465
):
1464
1466
analysis_config = {
1465
1467
** data_config .get_config (),
0 commit comments