Skip to content

Commit 6f82a43

Browse files
committed
Added data types
1 parent 67708bf commit 6f82a43

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/sagemaker/clarify.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import tempfile
2727
from abc import ABC, abstractmethod
28-
from typing import List, Union
28+
from typing import List, Union, Dict
2929

3030
from sagemaker import image_uris, s3, utils
3131
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
@@ -1450,8 +1450,8 @@ def run_bias_and_explainability(
14501450
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
14511451
Defaults to str "all" to run all metrics if left unspecified.
14521452
model_predicted_label_config (
1453-
int or
1454-
str or
1453+
int or
1454+
str or
14551455
:class:`~sagemaker.clarify.ModelPredictedLabelConfig`
14561456
):
14571457
Index or JSONPath to locate the predicted scores in the model output. This is not
@@ -1599,7 +1599,13 @@ def bias(
15991599
return analysis_config
16001600

16011601
@classmethod
1602-
def _add_predictor(cls, analysis_config, model_config, model_predicted_label_config):
1602+
def _add_predictor(
1603+
cls,
1604+
analysis_config: Dict,
1605+
model_config: ModelConfig,
1606+
model_predicted_label_config:
1607+
ModelPredictedLabelConfig
1608+
):
16031609
"""Extends analysis config with predictor."""
16041610
analysis_config = {**analysis_config}
16051611
analysis_config["predictor"] = model_config.get_predictor_config()
@@ -1618,10 +1624,10 @@ def _add_predictor(cls, analysis_config, model_config, model_predicted_label_con
16181624
@classmethod
16191625
def _add_methods(
16201626
cls,
1621-
analysis_config,
1622-
pre_training_methods=None,
1623-
post_training_methods=None,
1624-
explainability_config=None,
1627+
analysis_config: Dict,
1628+
pre_training_methods: Union[str, List[str]] = "all",
1629+
post_training_methods: Union[str, List[str]] = "all",
1630+
explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]] = None,
16251631
report=True,
16261632
):
16271633
"""Extends analysis config with methods."""

0 commit comments

Comments
 (0)