25
25
26
26
import tempfile
27
27
from abc import ABC , abstractmethod
28
+ from typing import List , Union
29
+
28
30
from sagemaker import image_uris , s3 , utils
29
31
from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
30
32
@@ -971,7 +973,7 @@ def _run(
971
973
def run_pre_training_bias (
972
974
self ,
973
975
data_config ,
974
- bias_config ,
976
+ data_bias_config ,
975
977
methods = "all" ,
976
978
wait = True ,
977
979
logs = True ,
@@ -986,7 +988,7 @@ def run_pre_training_bias(
986
988
987
989
Args:
988
990
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
989
- bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
991
+ data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
990
992
methods (str or list[str]): Selects a subset of potential metrics:
991
993
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
992
994
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1022,7 +1024,7 @@ def run_pre_training_bias(
1022
1024
""" # noqa E501 # pylint: disable=c0301
1023
1025
analysis_config = _AnalysisConfigGenerator .bias_pre_training (
1024
1026
data_config ,
1025
- bias_config ,
1027
+ data_bias_config ,
1026
1028
methods
1027
1029
)
1028
1030
# when name is either not provided (is None) or an empty string ("")
@@ -1040,7 +1042,7 @@ def run_pre_training_bias(
1040
1042
def run_post_training_bias (
1041
1043
self ,
1042
1044
data_config ,
1043
- bias_config ,
1045
+ data_bias_config ,
1044
1046
model_config ,
1045
1047
model_predicted_label_config ,
1046
1048
methods = "all" ,
@@ -1060,7 +1062,7 @@ def run_post_training_bias(
1060
1062
1061
1063
Args:
1062
1064
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1063
- bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1065
+ data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1064
1066
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1065
1067
endpoint to be created.
1066
1068
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1103,7 +1105,7 @@ def run_post_training_bias(
1103
1105
""" # noqa E501 # pylint: disable=c0301
1104
1106
analysis_config = _AnalysisConfigGenerator .bias_post_training (
1105
1107
data_config ,
1106
- bias_config ,
1108
+ data_bias_config ,
1107
1109
model_predicted_label_config ,
1108
1110
methods ,
1109
1111
model_config
@@ -1314,10 +1316,10 @@ class _AnalysisConfigGenerator:
1314
1316
@classmethod
1315
1317
def explainability (
1316
1318
cls ,
1317
- data_config ,
1318
- model_config ,
1319
- model_scores ,
1320
- explainability_config
1319
+ data_config : DataConfig ,
1320
+ model_config : ModelConfig ,
1321
+ model_scores : ModelPredictedLabelConfig ,
1322
+ explainability_config : ExplainabilityConfig ,
1321
1323
):
1322
1324
analysis_config = data_config .get_config ()
1323
1325
predictor_config = model_config .get_predictor_config ()
@@ -1358,7 +1360,7 @@ def explainability(
1358
1360
return cls ._common (analysis_config )
1359
1361
1360
1362
@classmethod
1361
- def bias_pre_training (cls , data_config , bias_config , methods ):
1363
+ def bias_pre_training (cls , data_config : DataConfig , bias_config : BiasConfig , methods : Union [ str , List [ str ]] ):
1362
1364
analysis_config = {
1363
1365
** data_config .get_config (),
1364
1366
** bias_config .get_config (),
@@ -1369,11 +1371,11 @@ def bias_pre_training(cls, data_config, bias_config, methods):
1369
1371
@classmethod
1370
1372
def bias_post_training (
1371
1373
cls ,
1372
- data_config ,
1373
- bias_config ,
1374
- model_predicted_label_config ,
1375
- methods ,
1376
- model_config
1374
+ data_config : DataConfig ,
1375
+ bias_config : BiasConfig ,
1376
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1377
+ methods : Union [ str , List [ str ]] ,
1378
+ model_config : ModelConfig ,
1377
1379
):
1378
1380
analysis_config = {
1379
1381
** data_config .get_config (),
@@ -1391,12 +1393,12 @@ def bias_post_training(
1391
1393
@classmethod
1392
1394
def bias (
1393
1395
cls ,
1394
- data_config ,
1395
- bias_config ,
1396
- model_config ,
1397
- model_predicted_label_config ,
1398
- pre_training_methods = "all" ,
1399
- post_training_methods = "all" ,
1396
+ data_config : DataConfig ,
1397
+ bias_config : BiasConfig ,
1398
+ model_config : ModelConfig ,
1399
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1400
+ pre_training_methods : Union [ str , List [ str ]] = "all" ,
1401
+ post_training_methods : Union [ str , List [ str ]] = "all" ,
1400
1402
):
1401
1403
analysis_config = {
1402
1404
** data_config .get_config (),
0 commit comments