Skip to content

Commit bef1e58

Browse files
Merge branch 'master' into mm-batch-support-on-demand
2 parents 480dcf6 + b27423d commit bef1e58

24 files changed

+710
-228
lines changed

doc/amazon_sagemaker_model_building_pipeline.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,8 @@ There are a number of properties for a pipeline execution that can only be resol
741741
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_EXECUTION_ARN`: The execution ARN for an execution.
742742
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_NAME`: The name of the pipeline.
743743
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_ARN`: The ARN of the pipeline.
744+
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.TRAINING_JOB_NAME`: The name of the training job launched by the training step.
745+
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PROCESSING_JOB_NAME`: The name of the processing job launched by the processing step.
744746
745747
You can use these execution variables as you see fit. The following example uses the :code:`START_DATETIME` execution variable to construct a processing output path:
746748

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Execution Variables
5252
.. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariable
5353

5454
.. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariables
55-
:members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN
55+
:members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN, TRAINING_JOB_NAME, PROCESSING_JOB_NAME
5656

5757
Functions
5858
---------

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import tempfile
19+
from typing import Union
1920

2021
from six.moves.urllib.parse import urlparse
2122

@@ -27,6 +28,7 @@
2728
from sagemaker.estimator import EstimatorBase, _TrainingJob
2829
from sagemaker.inputs import FileSystemInput, TrainingInput
2930
from sagemaker.utils import sagemaker_timestamp
31+
from sagemaker.workflow.entities import PipelineVariable
3032
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3133

3234
logger = logging.getLogger(__name__)
@@ -304,7 +306,12 @@ class RecordSet(object):
304306
"""Placeholder docstring"""
305307

306308
def __init__(
307-
self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train"
309+
self,
310+
s3_data: Union[str, PipelineVariable],
311+
num_records: int,
312+
feature_dim: int,
313+
s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
314+
channel: Union[str, PipelineVariable] = "train",
308315
):
309316
"""A collection of Amazon :class:`~Record` objects serialized and stored in S3.
310317

src/sagemaker/clarify.py

Lines changed: 140 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import tempfile
2727
from abc import ABC, abstractmethod
28+
from typing import List, Union
29+
2830
from sagemaker import image_uris, s3, utils
2931
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3032

@@ -63,7 +65,6 @@ def __init__(
6365
label (str): Target attribute of the model required by bias metrics.
6466
Specified as column name or index for CSV dataset or as JSONPath for JSONLines.
6567
*Required parameter* except for when the input dataset does not contain the label.
66-
Cannot be used at the same time as ``predicted_label``.
6768
features (str): JSONPath for locating the feature columns for bias metrics if the
6869
dataset format is JSONLines.
6970
dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
@@ -103,7 +104,7 @@ def __init__(
103104
predicted_label (str or int): Predicted label of the target attribute of the model
104105
required for running bias analysis. Specified as column name or index for CSV data.
105106
Clarify uses the predicted labels directly instead of making model inference API
106-
calls. Cannot be used at the same time as ``label``.
107+
calls.
107108
excluded_columns (list[int] or list[str]): A list of names or indices of the columns
108109
which are to be excluded from making model inference API calls.
109110
@@ -922,6 +923,7 @@ def __init__(
922923
version (str): Clarify version to use.
923924
""" # noqa E501 # pylint: disable=c0301
924925
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
926+
self._last_analysis_config = None
925927
self.job_name_prefix = job_name_prefix
926928
super(SageMakerClarifyProcessor, self).__init__(
927929
role,
@@ -983,10 +985,10 @@ def _run(
983985
the Trial Component will be unassociated.
984986
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
985987
"""
986-
analysis_config["methods"]["report"] = {
987-
"name": "report",
988-
"title": "Analysis Report",
989-
}
988+
# for debugging: keep the config on the instance so it can be inspected locally, i.e. without looking for it in an S3 bucket
989+
self._last_analysis_config = analysis_config
990+
logger.info("Analysis Config: %s", analysis_config)
991+
990992
with tempfile.TemporaryDirectory() as tmpdirname:
991993
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
992994
with open(analysis_config_file, "w") as f:
@@ -1083,14 +1085,13 @@ def run_pre_training_bias(
10831085
the Trial Component will be unassociated.
10841086
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
10851087
""" # noqa E501 # pylint: disable=c0301
1086-
analysis_config = data_config.get_config()
1087-
analysis_config.update(data_bias_config.get_config())
1088-
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
1089-
if job_name is None:
1090-
if self.job_name_prefix:
1091-
job_name = utils.name_from_base(self.job_name_prefix)
1092-
else:
1093-
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
1088+
analysis_config = _AnalysisConfigGenerator.bias_pre_training(
1089+
data_config, data_bias_config, methods
1090+
)
1091+
# generate a default job name when job_name is either not provided (is None) or an empty string ("")
1092+
job_name = job_name or utils.name_from_base(
1093+
self.job_name_prefix or "Clarify-Pretraining-Bias"
1094+
)
10941095
return self._run(
10951096
data_config,
10961097
analysis_config,
@@ -1165,21 +1166,13 @@ def run_post_training_bias(
11651166
the Trial Component will be unassociated.
11661167
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
11671168
""" # noqa E501 # pylint: disable=c0301
1168-
analysis_config = data_config.get_config()
1169-
analysis_config.update(data_bias_config.get_config())
1170-
(
1171-
probability_threshold,
1172-
predictor_config,
1173-
) = model_predicted_label_config.get_predictor_config()
1174-
predictor_config.update(model_config.get_predictor_config())
1175-
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
1176-
analysis_config["predictor"] = predictor_config
1177-
_set(probability_threshold, "probability_threshold", analysis_config)
1178-
if job_name is None:
1179-
if self.job_name_prefix:
1180-
job_name = utils.name_from_base(self.job_name_prefix)
1181-
else:
1182-
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
1169+
analysis_config = _AnalysisConfigGenerator.bias_post_training(
1170+
data_config, data_bias_config, model_predicted_label_config, methods, model_config
1171+
)
1172+
# generate a default job name when job_name is either not provided (is None) or an empty string ("")
1173+
job_name = job_name or utils.name_from_base(
1174+
self.job_name_prefix or "Clarify-Posttraining-Bias"
1175+
)
11831176
return self._run(
11841177
data_config,
11851178
analysis_config,
@@ -1264,28 +1257,16 @@ def run_bias(
12641257
the Trial Component will be unassociated.
12651258
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
12661259
""" # noqa E501 # pylint: disable=c0301
1267-
analysis_config = data_config.get_config()
1268-
analysis_config.update(bias_config.get_config())
1269-
analysis_config["predictor"] = model_config.get_predictor_config()
1270-
if model_predicted_label_config:
1271-
(
1272-
probability_threshold,
1273-
predictor_config,
1274-
) = model_predicted_label_config.get_predictor_config()
1275-
if predictor_config:
1276-
analysis_config["predictor"].update(predictor_config)
1277-
if probability_threshold is not None:
1278-
analysis_config["probability_threshold"] = probability_threshold
1279-
1280-
analysis_config["methods"] = {
1281-
"pre_training_bias": {"methods": pre_training_methods},
1282-
"post_training_bias": {"methods": post_training_methods},
1283-
}
1284-
if job_name is None:
1285-
if self.job_name_prefix:
1286-
job_name = utils.name_from_base(self.job_name_prefix)
1287-
else:
1288-
job_name = utils.name_from_base("Clarify-Bias")
1260+
analysis_config = _AnalysisConfigGenerator.bias(
1261+
data_config,
1262+
bias_config,
1263+
model_config,
1264+
model_predicted_label_config,
1265+
pre_training_methods,
1266+
post_training_methods,
1267+
)
1268+
# generate a default job name when job_name is either not provided (is None) or an empty string ("")
1269+
job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias")
12891270
return self._run(
12901271
data_config,
12911272
analysis_config,
@@ -1370,6 +1351,36 @@ def run_explainability(
13701351
the Trial Component will be unassociated.
13711352
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
13721353
""" # noqa E501 # pylint: disable=c0301
1354+
analysis_config = _AnalysisConfigGenerator.explainability(
1355+
data_config, model_config, model_scores, explainability_config
1356+
)
1357+
# generate a default job name when job_name is either not provided (is None) or an empty string ("")
1358+
job_name = job_name or utils.name_from_base(
1359+
self.job_name_prefix or "Clarify-Explainability"
1360+
)
1361+
return self._run(
1362+
data_config,
1363+
analysis_config,
1364+
wait,
1365+
logs,
1366+
job_name,
1367+
kms_key,
1368+
experiment_config,
1369+
)
1370+
1371+
1372+
class _AnalysisConfigGenerator:
1373+
"""Creates analysis_config objects for different types of runs."""
1374+
1375+
@classmethod
1376+
def explainability(
1377+
cls,
1378+
data_config: DataConfig,
1379+
model_config: ModelConfig,
1380+
model_scores: ModelPredictedLabelConfig,
1381+
explainability_config: ExplainabilityConfig,
1382+
):
1383+
"""Generates a config for Explainability"""
13731384
analysis_config = data_config.get_config()
13741385
predictor_config = model_config.get_predictor_config()
13751386
if isinstance(model_scores, ModelPredictedLabelConfig):
@@ -1406,20 +1417,84 @@ def run_explainability(
14061417
explainability_methods = explainability_config.get_explainability_config()
14071418
analysis_config["methods"] = explainability_methods
14081419
analysis_config["predictor"] = predictor_config
1409-
if job_name is None:
1410-
if self.job_name_prefix:
1411-
job_name = utils.name_from_base(self.job_name_prefix)
1412-
else:
1413-
job_name = utils.name_from_base("Clarify-Explainability")
1414-
return self._run(
1415-
data_config,
1416-
analysis_config,
1417-
wait,
1418-
logs,
1419-
job_name,
1420-
kms_key,
1421-
experiment_config,
1422-
)
1420+
return cls._common(analysis_config)
1421+
1422+
@classmethod
1423+
def bias_pre_training(
1424+
cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
1425+
):
1426+
"""Generates a config for Bias Pre Training"""
1427+
analysis_config = {
1428+
**data_config.get_config(),
1429+
**bias_config.get_config(),
1430+
"methods": {"pre_training_bias": {"methods": methods}},
1431+
}
1432+
return cls._common(analysis_config)
1433+
1434+
@classmethod
1435+
def bias_post_training(
1436+
cls,
1437+
data_config: DataConfig,
1438+
bias_config: BiasConfig,
1439+
model_predicted_label_config: ModelPredictedLabelConfig,
1440+
methods: Union[str, List[str]],
1441+
model_config: ModelConfig,
1442+
):
1443+
"""Generates a config for Bias Post Training"""
1444+
analysis_config = {
1445+
**data_config.get_config(),
1446+
**bias_config.get_config(),
1447+
"predictor": {**model_config.get_predictor_config()},
1448+
"methods": {"post_training_bias": {"methods": methods}},
1449+
}
1450+
if model_predicted_label_config:
1451+
(
1452+
probability_threshold,
1453+
predictor_config,
1454+
) = model_predicted_label_config.get_predictor_config()
1455+
if predictor_config:
1456+
analysis_config["predictor"].update(predictor_config)
1457+
_set(probability_threshold, "probability_threshold", analysis_config)
1458+
return cls._common(analysis_config)
1459+
1460+
@classmethod
1461+
def bias(
1462+
cls,
1463+
data_config: DataConfig,
1464+
bias_config: BiasConfig,
1465+
model_config: ModelConfig,
1466+
model_predicted_label_config: ModelPredictedLabelConfig,
1467+
pre_training_methods: Union[str, List[str]] = "all",
1468+
post_training_methods: Union[str, List[str]] = "all",
1469+
):
1470+
"""Generates a config for Bias"""
1471+
analysis_config = {
1472+
**data_config.get_config(),
1473+
**bias_config.get_config(),
1474+
"predictor": model_config.get_predictor_config(),
1475+
"methods": {
1476+
"pre_training_bias": {"methods": pre_training_methods},
1477+
"post_training_bias": {"methods": post_training_methods},
1478+
},
1479+
}
1480+
if model_predicted_label_config:
1481+
(
1482+
probability_threshold,
1483+
predictor_config,
1484+
) = model_predicted_label_config.get_predictor_config()
1485+
if predictor_config:
1486+
analysis_config["predictor"].update(predictor_config)
1487+
_set(probability_threshold, "probability_threshold", analysis_config)
1488+
return cls._common(analysis_config)
1489+
1490+
@staticmethod
1491+
def _common(analysis_config):
1492+
"""Extends analysis config with common values"""
1493+
analysis_config["methods"]["report"] = {
1494+
"name": "report",
1495+
"title": "Analysis Report",
1496+
}
1497+
return analysis_config
14231498

14241499

14251500
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):

0 commit comments

Comments
 (0)