25
25
26
26
import tempfile
27
27
from abc import ABC , abstractmethod
28
- from typing import List , Union
28
+ from typing import List , Union , Dict
29
29
30
30
from sagemaker import image_uris , s3 , utils
31
31
from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
@@ -173,7 +173,11 @@ def __init__(
173
173
_set (joinsource , "joinsource_name_or_index" , self .analysis_config )
174
174
_set (facet_dataset_uri , "facet_dataset_uri" , self .analysis_config )
175
175
_set (facet_headers , "facet_headers" , self .analysis_config )
176
- _set (predicted_label_dataset_uri , "predicted_label_dataset_uri" , self .analysis_config )
176
+ _set (
177
+ predicted_label_dataset_uri ,
178
+ "predicted_label_dataset_uri" ,
179
+ self .analysis_config ,
180
+ )
177
181
_set (predicted_label_headers , "predicted_label_headers" , self .analysis_config )
178
182
_set (predicted_label , "predicted_label" , self .analysis_config )
179
183
_set (excluded_columns , "excluded_columns" , self .analysis_config )
@@ -239,7 +243,8 @@ def __init__(
239
243
assert len (facet_name ) > 0 , "Please provide at least one facet"
240
244
if facet_values_or_threshold is None :
241
245
facet_list = [
242
- {"name_or_index" : single_facet_name } for single_facet_name in facet_name
246
+ {"name_or_index" : single_facet_name }
247
+ for single_facet_name in facet_name
243
248
]
244
249
elif len (facet_values_or_threshold ) == len (facet_name ):
245
250
facet_list = []
@@ -492,7 +497,10 @@ def __init__(self, features=None, grid_resolution=15, top_k_features=10):
492
497
top_k_features (int): Sets the number of top SHAP attributes used to compute
493
498
partial dependence plots.
494
499
""" # noqa E501
495
- self .pdp_config = {"grid_resolution" : grid_resolution , "top_k_features" : top_k_features }
500
+ self .pdp_config = {
501
+ "grid_resolution" : grid_resolution ,
502
+ "top_k_features" : top_k_features ,
503
+ }
496
504
if features is not None :
497
505
self .pdp_config ["features" ] = features
498
506
@@ -825,9 +833,14 @@ def __init__(
825
833
image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
826
834
features. Default is None.
827
835
""" # noqa E501 # pylint: disable=c0301
828
- if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
836
+ if agg_method is not None and agg_method not in [
837
+ "mean_abs" ,
838
+ "median" ,
839
+ "mean_sq" ,
840
+ ]:
829
841
raise ValueError (
830
- f"Invalid agg_method { agg_method } ." f" Please choose mean_abs, median, or mean_sq."
842
+ f"Invalid agg_method { agg_method } ."
843
+ f" Please choose mean_abs, median, or mean_sq."
831
844
)
832
845
if num_clusters is not None and baseline is not None :
833
846
raise ValueError (
@@ -923,7 +936,9 @@ def __init__(
923
936
job_name_prefix (str): Processing job name prefix.
924
937
version (str): Clarify version to use.
925
938
""" # noqa E501 # pylint: disable=c0301
926
- container_uri = image_uris .retrieve ("clarify" , sagemaker_session .boto_region_name , version )
939
+ container_uri = image_uris .retrieve (
940
+ "clarify" , sagemaker_session .boto_region_name , version
941
+ )
927
942
self ._last_analysis_config = None
928
943
self .job_name_prefix = job_name_prefix
929
944
super (SageMakerClarifyProcessor , self ).__init__ (
@@ -996,7 +1011,8 @@ def _run(
996
1011
json .dump (analysis_config , f )
997
1012
s3_analysis_config_file = _upload_analysis_config (
998
1013
analysis_config_file ,
999
- data_config .s3_analysis_config_output_path or data_config .s3_output_path ,
1014
+ data_config .s3_analysis_config_output_path
1015
+ or data_config .s3_output_path ,
1000
1016
self .sagemaker_session ,
1001
1017
kms_key ,
1002
1018
)
@@ -1168,7 +1184,11 @@ def run_post_training_bias(
1168
1184
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
1169
1185
""" # noqa E501 # pylint: disable=c0301
1170
1186
analysis_config = _AnalysisConfigGenerator .bias_post_training (
1171
- data_config , data_bias_config , model_predicted_label_config , methods , model_config
1187
+ data_config ,
1188
+ data_bias_config ,
1189
+ model_predicted_label_config ,
1190
+ methods ,
1191
+ model_config ,
1172
1192
)
1173
1193
# when name is either not provided (is None) or an empty string ("")
1174
1194
job_name = job_name or utils .name_from_base (
@@ -1267,7 +1287,9 @@ def run_bias(
1267
1287
post_training_methods ,
1268
1288
)
1269
1289
# when name is either not provided (is None) or an empty string ("")
1270
- job_name = job_name or utils .name_from_base (self .job_name_prefix or "Clarify-Bias" )
1290
+ job_name = job_name or utils .name_from_base (
1291
+ self .job_name_prefix or "Clarify-Bias"
1292
+ )
1271
1293
return self ._run (
1272
1294
data_config ,
1273
1295
analysis_config ,
@@ -1450,8 +1472,8 @@ def run_bias_and_explainability(
1450
1472
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
1451
1473
Defaults to str "all" to run all metrics if left unspecified.
1452
1474
model_predicted_label_config (
1453
- int or
1454
- str or
1475
+ int or
1476
+ str or
1455
1477
:class:`~sagemaker.clarify.ModelPredictedLabelConfig`
1456
1478
):
1457
1479
Index or JSONPath to locate the predicted scores in the model output. This is not
@@ -1552,11 +1574,16 @@ def explainability(
1552
1574
1553
1575
@classmethod
1554
1576
def bias_pre_training (
1555
- cls , data_config : DataConfig , bias_config : BiasConfig , methods : Union [str , List [str ]]
1577
+ cls ,
1578
+ data_config : DataConfig ,
1579
+ bias_config : BiasConfig ,
1580
+ methods : Union [str , List [str ]],
1556
1581
):
1557
1582
"""Generates a config for Bias Pre Training"""
1558
1583
analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1559
- analysis_config = cls ._add_methods (analysis_config , pre_training_methods = methods )
1584
+ analysis_config = cls ._add_methods (
1585
+ analysis_config , pre_training_methods = methods
1586
+ )
1560
1587
return analysis_config
1561
1588
1562
1589
@classmethod
@@ -1570,7 +1597,9 @@ def bias_post_training(
1570
1597
):
1571
1598
"""Generates a config for Bias Post Training"""
1572
1599
analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1573
- analysis_config = cls ._add_methods (analysis_config , post_training_methods = methods )
1600
+ analysis_config = cls ._add_methods (
1601
+ analysis_config , post_training_methods = methods
1602
+ )
1574
1603
analysis_config = cls ._add_predictor (
1575
1604
analysis_config , model_config , model_predicted_label_config
1576
1605
)
@@ -1599,7 +1628,12 @@ def bias(
1599
1628
return analysis_config
1600
1629
1601
1630
@classmethod
1602
- def _add_predictor (cls , analysis_config , model_config , model_predicted_label_config ):
1631
+ def _add_predictor (
1632
+ cls ,
1633
+ analysis_config : Dict ,
1634
+ model_config : ModelConfig ,
1635
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1636
+ ):
1603
1637
"""Extends analysis config with predictor."""
1604
1638
analysis_config = {** analysis_config }
1605
1639
analysis_config ["predictor" ] = model_config .get_predictor_config ()
@@ -1618,10 +1652,12 @@ def _add_predictor(cls, analysis_config, model_config, model_predicted_label_con
1618
1652
@classmethod
1619
1653
def _add_methods (
1620
1654
cls ,
1621
- analysis_config ,
1622
- pre_training_methods = None ,
1623
- post_training_methods = None ,
1624
- explainability_config = None ,
1655
+ analysis_config : Dict ,
1656
+ pre_training_methods : Union [str , List [str ]] = "all" ,
1657
+ post_training_methods : Union [str , List [str ]] = "all" ,
1658
+ explainability_config : Union [
1659
+ ExplainabilityConfig , List [ExplainabilityConfig ]
1660
+ ] = None ,
1625
1661
report = True ,
1626
1662
):
1627
1663
"""Extends analysis config with methods."""
@@ -1640,22 +1676,35 @@ def _add_methods(
1640
1676
analysis_config ["methods" ] = {}
1641
1677
1642
1678
if report :
1643
- analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
1679
+ analysis_config ["methods" ]["report" ] = {
1680
+ "name" : "report" ,
1681
+ "title" : "Analysis Report" ,
1682
+ }
1644
1683
1645
1684
if pre_training_methods :
1646
- analysis_config ["methods" ]["pre_training_bias" ] = {"methods" : pre_training_methods }
1685
+ analysis_config ["methods" ]["pre_training_bias" ] = {
1686
+ "methods" : pre_training_methods
1687
+ }
1647
1688
1648
1689
if post_training_methods :
1649
- analysis_config ["methods" ]["post_training_bias" ] = {"methods" : post_training_methods }
1690
+ analysis_config ["methods" ]["post_training_bias" ] = {
1691
+ "methods" : post_training_methods
1692
+ }
1650
1693
1651
1694
if explainability_config is not None :
1652
- explainability_methods = cls ._merge_explainability_configs (explainability_config )
1653
- analysis_config ["methods" ] = {** analysis_config ["methods" ], ** explainability_methods }
1695
+ explainability_methods = cls ._merge_explainability_configs (
1696
+ explainability_config
1697
+ )
1698
+ analysis_config ["methods" ] = {
1699
+ ** analysis_config ["methods" ],
1700
+ ** explainability_methods ,
1701
+ }
1654
1702
return analysis_config
1655
1703
1656
1704
@classmethod
1657
1705
def _merge_explainability_configs (
1658
- cls , explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]]
1706
+ cls ,
1707
+ explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]],
1659
1708
):
1660
1709
"""Merges explainability configs, when more than one."""
1661
1710
if isinstance (explainability_config , list ):
@@ -1671,17 +1720,24 @@ def _merge_explainability_configs(
1671
1720
"shap" not in explainability_methods
1672
1721
and "features" not in explainability_methods ["pdp" ]
1673
1722
):
1674
- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1723
+ raise ValueError (
1724
+ "PDP features must be provided when ShapConfig is not provided"
1725
+ )
1675
1726
return explainability_methods
1676
1727
if (
1677
1728
isinstance (explainability_config , PDPConfig )
1678
- and "features" not in explainability_config .get_explainability_config ()["pdp" ]
1729
+ and "features"
1730
+ not in explainability_config .get_explainability_config ()["pdp" ]
1679
1731
):
1680
- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1732
+ raise ValueError (
1733
+ "PDP features must be provided when ShapConfig is not provided"
1734
+ )
1681
1735
return explainability_config .get_explainability_config ()
1682
1736
1683
1737
1684
- def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
1738
+ def _upload_analysis_config (
1739
+ analysis_config_file , s3_output_path , sagemaker_session , kms_key
1740
+ ):
1685
1741
"""Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
1686
1742
1687
1743
Args:
0 commit comments