Skip to content

Commit e502748

Browse files
committed
added data types in _AnalysisConfigGenerator methods
1 parent e3ff867 commit e502748

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

src/sagemaker/clarify.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import tempfile
2727
from abc import ABC, abstractmethod
28+
from typing import List, Union
29+
2830
from sagemaker import image_uris, s3, utils
2931
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3032

@@ -1314,10 +1316,10 @@ class _AnalysisConfigGenerator:
13141316
@classmethod
13151317
def explainability(
13161318
cls,
1317-
data_config,
1318-
model_config,
1319-
model_scores,
1320-
explainability_config
1319+
data_config: DataConfig,
1320+
model_config: ModelConfig,
1321+
model_scores: ModelPredictedLabelConfig,
1322+
explainability_config: ExplainabilityConfig,
13211323
):
13221324
analysis_config = data_config.get_config()
13231325
predictor_config = model_config.get_predictor_config()
@@ -1358,7 +1360,7 @@ def explainability(
13581360
return cls._common(analysis_config)
13591361

13601362
@classmethod
1361-
def bias_pre_training(cls, data_config, bias_config, methods):
1363+
def bias_pre_training(cls, data_config: DataConfig, bias_config: BiasConfig, methods: List[str]):
13621364
analysis_config = {
13631365
**data_config.get_config(),
13641366
**bias_config.get_config(),
@@ -1369,11 +1371,11 @@ def bias_pre_training(cls, data_config, bias_config, methods):
13691371
@classmethod
13701372
def bias_post_training(
13711373
cls,
1372-
data_config,
1373-
bias_config,
1374-
model_predicted_label_config,
1375-
methods,
1376-
model_config
1374+
data_config: DataConfig,
1375+
bias_config: BiasConfig,
1376+
model_predicted_label_config: ModelPredictedLabelConfig,
1377+
methods: List[str],
1378+
model_config: ModelConfig,
13771379
):
13781380
analysis_config = {
13791381
**data_config.get_config(),
@@ -1391,12 +1393,12 @@ def bias_post_training(
13911393
@classmethod
13921394
def bias(
13931395
cls,
1394-
data_config,
1395-
bias_config,
1396-
model_config,
1397-
model_predicted_label_config,
1398-
pre_training_methods="all",
1399-
post_training_methods="all",
1396+
data_config: DataConfig,
1397+
bias_config: BiasConfig,
1398+
model_config: ModelConfig,
1399+
model_predicted_label_config: ModelPredictedLabelConfig,
1400+
pre_training_methods: Union[str, List[str]] = "all",
1401+
post_training_methods: Union[str, List[str]] = "all",
14001402
):
14011403
analysis_config = {
14021404
**data_config.get_config(),

0 commit comments

Comments
 (0)