|
25 | 25 |
|
26 | 26 | import tempfile
|
27 | 27 | from abc import ABC, abstractmethod
|
| 28 | +from typing import List, Union |
| 29 | + |
28 | 30 | from sagemaker import image_uris, s3, utils
|
29 | 31 | from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
|
30 | 32 |
|
@@ -63,7 +65,6 @@ def __init__(
|
63 | 65 | label (str): Target attribute of the model required by bias metrics.
|
64 | 66 | Specified as column name or index for CSV dataset or as JSONPath for JSONLines.
|
65 | 67 | *Required parameter* except for when the input dataset does not contain the label.
|
66 |
| - Cannot be used at the same time as ``predicted_label``. |
67 | 68 | features (str): JSONPath for locating the feature columns for bias metrics if the
|
68 | 69 | dataset format is JSONLines.
|
69 | 70 | dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
|
@@ -103,7 +104,7 @@ def __init__(
|
103 | 104 | predicted_label (str or int): Predicted label of the target attribute of the model
|
104 | 105 | required for running bias analysis. Specified as column name or index for CSV data.
|
105 | 106 | Clarify uses the predicted labels directly instead of making model inference API
|
106 |
| - calls. Cannot be used at the same time as ``label``. |
| 107 | + calls. |
107 | 108 | excluded_columns (list[int] or list[str]): A list of names or indices of the columns
|
108 | 109 | which are to be excluded from making model inference API calls.
|
109 | 110 |
|
@@ -922,6 +923,7 @@ def __init__(
|
922 | 923 | version (str): Clarify version to use.
|
923 | 924 | """ # noqa E501 # pylint: disable=c0301
|
924 | 925 | container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
|
| 926 | + self._last_analysis_config = None |
925 | 927 | self.job_name_prefix = job_name_prefix
|
926 | 928 | super(SageMakerClarifyProcessor, self).__init__(
|
927 | 929 | role,
|
@@ -983,10 +985,10 @@ def _run(
|
983 | 985 | the Trial Component will be unassociated.
|
984 | 986 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
|
985 | 987 | """
|
986 |
| - analysis_config["methods"]["report"] = { |
987 |
| - "name": "report", |
988 |
| - "title": "Analysis Report", |
989 |
| - } |
| 988 | + # for debugging: keep the analysis config on the instance so it can be inspected locally, without having to look for it in an S3 bucket |
| 989 | + self._last_analysis_config = analysis_config |
| 990 | + logger.info("Analysis Config: %s", analysis_config) |
| 991 | + |
990 | 992 | with tempfile.TemporaryDirectory() as tmpdirname:
|
991 | 993 | analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
|
992 | 994 | with open(analysis_config_file, "w") as f:
|
@@ -1083,14 +1085,13 @@ def run_pre_training_bias(
|
1083 | 1085 | the Trial Component will be unassociated.
|
1084 | 1086 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
|
1085 | 1087 | """ # noqa E501 # pylint: disable=c0301
|
1086 |
| - analysis_config = data_config.get_config() |
1087 |
| - analysis_config.update(data_bias_config.get_config()) |
1088 |
| - analysis_config["methods"] = {"pre_training_bias": {"methods": methods}} |
1089 |
| - if job_name is None: |
1090 |
| - if self.job_name_prefix: |
1091 |
| - job_name = utils.name_from_base(self.job_name_prefix) |
1092 |
| - else: |
1093 |
| - job_name = utils.name_from_base("Clarify-Pretraining-Bias") |
| 1088 | + analysis_config = _AnalysisConfigGenerator.bias_pre_training( |
| 1089 | + data_config, data_bias_config, methods |
| 1090 | + ) |
| 1091 | + # when name is either not provided (is None) or an empty string ("") |
| 1092 | + job_name = job_name or utils.name_from_base( |
| 1093 | + self.job_name_prefix or "Clarify-Pretraining-Bias" |
| 1094 | + ) |
1094 | 1095 | return self._run(
|
1095 | 1096 | data_config,
|
1096 | 1097 | analysis_config,
|
@@ -1165,21 +1166,13 @@ def run_post_training_bias(
|
1165 | 1166 | the Trial Component will be unassociated.
|
1166 | 1167 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
|
1167 | 1168 | """ # noqa E501 # pylint: disable=c0301
|
1168 |
| - analysis_config = data_config.get_config() |
1169 |
| - analysis_config.update(data_bias_config.get_config()) |
1170 |
| - ( |
1171 |
| - probability_threshold, |
1172 |
| - predictor_config, |
1173 |
| - ) = model_predicted_label_config.get_predictor_config() |
1174 |
| - predictor_config.update(model_config.get_predictor_config()) |
1175 |
| - analysis_config["methods"] = {"post_training_bias": {"methods": methods}} |
1176 |
| - analysis_config["predictor"] = predictor_config |
1177 |
| - _set(probability_threshold, "probability_threshold", analysis_config) |
1178 |
| - if job_name is None: |
1179 |
| - if self.job_name_prefix: |
1180 |
| - job_name = utils.name_from_base(self.job_name_prefix) |
1181 |
| - else: |
1182 |
| - job_name = utils.name_from_base("Clarify-Posttraining-Bias") |
| 1169 | + analysis_config = _AnalysisConfigGenerator.bias_post_training( |
| 1170 | + data_config, data_bias_config, model_predicted_label_config, methods, model_config |
| 1171 | + ) |
| 1172 | + # when name is either not provided (is None) or an empty string ("") |
| 1173 | + job_name = job_name or utils.name_from_base( |
| 1174 | + self.job_name_prefix or "Clarify-Posttraining-Bias" |
| 1175 | + ) |
1183 | 1176 | return self._run(
|
1184 | 1177 | data_config,
|
1185 | 1178 | analysis_config,
|
@@ -1264,28 +1257,16 @@ def run_bias(
|
1264 | 1257 | the Trial Component will be unassociated.
|
1265 | 1258 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
|
1266 | 1259 | """ # noqa E501 # pylint: disable=c0301
|
1267 |
| - analysis_config = data_config.get_config() |
1268 |
| - analysis_config.update(bias_config.get_config()) |
1269 |
| - analysis_config["predictor"] = model_config.get_predictor_config() |
1270 |
| - if model_predicted_label_config: |
1271 |
| - ( |
1272 |
| - probability_threshold, |
1273 |
| - predictor_config, |
1274 |
| - ) = model_predicted_label_config.get_predictor_config() |
1275 |
| - if predictor_config: |
1276 |
| - analysis_config["predictor"].update(predictor_config) |
1277 |
| - if probability_threshold is not None: |
1278 |
| - analysis_config["probability_threshold"] = probability_threshold |
1279 |
| - |
1280 |
| - analysis_config["methods"] = { |
1281 |
| - "pre_training_bias": {"methods": pre_training_methods}, |
1282 |
| - "post_training_bias": {"methods": post_training_methods}, |
1283 |
| - } |
1284 |
| - if job_name is None: |
1285 |
| - if self.job_name_prefix: |
1286 |
| - job_name = utils.name_from_base(self.job_name_prefix) |
1287 |
| - else: |
1288 |
| - job_name = utils.name_from_base("Clarify-Bias") |
| 1260 | + analysis_config = _AnalysisConfigGenerator.bias( |
| 1261 | + data_config, |
| 1262 | + bias_config, |
| 1263 | + model_config, |
| 1264 | + model_predicted_label_config, |
| 1265 | + pre_training_methods, |
| 1266 | + post_training_methods, |
| 1267 | + ) |
| 1268 | + # when name is either not provided (is None) or an empty string ("") |
| 1269 | + job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias") |
1289 | 1270 | return self._run(
|
1290 | 1271 | data_config,
|
1291 | 1272 | analysis_config,
|
@@ -1370,6 +1351,36 @@ def run_explainability(
|
1370 | 1351 | the Trial Component will be unassociated.
|
1371 | 1352 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
|
1372 | 1353 | """ # noqa E501 # pylint: disable=c0301
|
| 1354 | + analysis_config = _AnalysisConfigGenerator.explainability( |
| 1355 | + data_config, model_config, model_scores, explainability_config |
| 1356 | + ) |
| 1357 | + # when name is either not provided (is None) or an empty string ("") |
| 1358 | + job_name = job_name or utils.name_from_base( |
| 1359 | + self.job_name_prefix or "Clarify-Explainability" |
| 1360 | + ) |
| 1361 | + return self._run( |
| 1362 | + data_config, |
| 1363 | + analysis_config, |
| 1364 | + wait, |
| 1365 | + logs, |
| 1366 | + job_name, |
| 1367 | + kms_key, |
| 1368 | + experiment_config, |
| 1369 | + ) |
| 1370 | + |
| 1371 | + |
| 1372 | +class _AnalysisConfigGenerator: |
| 1373 | + """Creates analysis_config objects for different types of runs.""" |
| 1374 | + |
| 1375 | + @classmethod |
| 1376 | + def explainability( |
| 1377 | + cls, |
| 1378 | + data_config: DataConfig, |
| 1379 | + model_config: ModelConfig, |
| 1380 | + model_scores: ModelPredictedLabelConfig, |
| 1381 | + explainability_config: ExplainabilityConfig, |
| 1382 | + ): |
| 1383 | + """Generates a config for Explainability""" |
1373 | 1384 | analysis_config = data_config.get_config()
|
1374 | 1385 | predictor_config = model_config.get_predictor_config()
|
1375 | 1386 | if isinstance(model_scores, ModelPredictedLabelConfig):
|
@@ -1406,20 +1417,84 @@ def run_explainability(
|
1406 | 1417 | explainability_methods = explainability_config.get_explainability_config()
|
1407 | 1418 | analysis_config["methods"] = explainability_methods
|
1408 | 1419 | analysis_config["predictor"] = predictor_config
|
1409 |
| - if job_name is None: |
1410 |
| - if self.job_name_prefix: |
1411 |
| - job_name = utils.name_from_base(self.job_name_prefix) |
1412 |
| - else: |
1413 |
| - job_name = utils.name_from_base("Clarify-Explainability") |
1414 |
| - return self._run( |
1415 |
| - data_config, |
1416 |
| - analysis_config, |
1417 |
| - wait, |
1418 |
| - logs, |
1419 |
| - job_name, |
1420 |
| - kms_key, |
1421 |
| - experiment_config, |
1422 |
| - ) |
| 1420 | + return cls._common(analysis_config) |
| 1421 | + |
| 1422 | + @classmethod |
| 1423 | + def bias_pre_training( |
| 1424 | + cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]] |
| 1425 | + ): |
| 1426 | + """Generates a config for Bias Pre Training""" |
| 1427 | + analysis_config = { |
| 1428 | + **data_config.get_config(), |
| 1429 | + **bias_config.get_config(), |
| 1430 | + "methods": {"pre_training_bias": {"methods": methods}}, |
| 1431 | + } |
| 1432 | + return cls._common(analysis_config) |
| 1433 | + |
| 1434 | + @classmethod |
| 1435 | + def bias_post_training( |
| 1436 | + cls, |
| 1437 | + data_config: DataConfig, |
| 1438 | + bias_config: BiasConfig, |
| 1439 | + model_predicted_label_config: ModelPredictedLabelConfig, |
| 1440 | + methods: Union[str, List[str]], |
| 1441 | + model_config: ModelConfig, |
| 1442 | + ): |
| 1443 | + """Generates a config for Bias Post Training""" |
| 1444 | + analysis_config = { |
| 1445 | + **data_config.get_config(), |
| 1446 | + **bias_config.get_config(), |
| 1447 | + "predictor": {**model_config.get_predictor_config()}, |
| 1448 | + "methods": {"post_training_bias": {"methods": methods}}, |
| 1449 | + } |
| 1450 | + if model_predicted_label_config: |
| 1451 | + ( |
| 1452 | + probability_threshold, |
| 1453 | + predictor_config, |
| 1454 | + ) = model_predicted_label_config.get_predictor_config() |
| 1455 | + if predictor_config: |
| 1456 | + analysis_config["predictor"].update(predictor_config) |
| 1457 | + _set(probability_threshold, "probability_threshold", analysis_config) |
| 1458 | + return cls._common(analysis_config) |
| 1459 | + |
| 1460 | + @classmethod |
| 1461 | + def bias( |
| 1462 | + cls, |
| 1463 | + data_config: DataConfig, |
| 1464 | + bias_config: BiasConfig, |
| 1465 | + model_config: ModelConfig, |
| 1466 | + model_predicted_label_config: ModelPredictedLabelConfig, |
| 1467 | + pre_training_methods: Union[str, List[str]] = "all", |
| 1468 | + post_training_methods: Union[str, List[str]] = "all", |
| 1469 | + ): |
| 1470 | + """Generates a config for Bias""" |
| 1471 | + analysis_config = { |
| 1472 | + **data_config.get_config(), |
| 1473 | + **bias_config.get_config(), |
| 1474 | + "predictor": model_config.get_predictor_config(), |
| 1475 | + "methods": { |
| 1476 | + "pre_training_bias": {"methods": pre_training_methods}, |
| 1477 | + "post_training_bias": {"methods": post_training_methods}, |
| 1478 | + }, |
| 1479 | + } |
| 1480 | + if model_predicted_label_config: |
| 1481 | + ( |
| 1482 | + probability_threshold, |
| 1483 | + predictor_config, |
| 1484 | + ) = model_predicted_label_config.get_predictor_config() |
| 1485 | + if predictor_config: |
| 1486 | + analysis_config["predictor"].update(predictor_config) |
| 1487 | + _set(probability_threshold, "probability_threshold", analysis_config) |
| 1488 | + return cls._common(analysis_config) |
| 1489 | + |
| 1490 | + @staticmethod |
| 1491 | + def _common(analysis_config): |
| 1492 | + """Extends analysis config with common values""" |
| 1493 | + analysis_config["methods"]["report"] = { |
| 1494 | + "name": "report", |
| 1495 | + "title": "Analysis Report", |
| 1496 | + } |
| 1497 | + return analysis_config |
1423 | 1498 |
|
1424 | 1499 |
|
1425 | 1500 | def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
|
|
0 commit comments