Skip to content

Commit 4f1d329

Browse files
navinsoni, qidewenwhen
authored and committed
feature: Add support for ModelMonitor/Clarify integration in model building pipelines
* feature: Add support for QualityCheckStep and ClarifyCheckStep in model building pipelines
* feature: Adding Baseline metrics for model packages

Co-authored-by: qidewenwhen <[email protected]>
1 parent 639b3ef commit 4f1d329

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+4485
-42
lines changed

doc/api/inference/model_monitor.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,13 @@ Model Monitor
3131
:members:
3232
:undoc-members:
3333
:show-inheritance:
34+
35+
.. automodule:: sagemaker.model_metrics
36+
:members:
37+
:undoc-members:
38+
:show-inheritance:
39+
40+
.. automodule:: sagemaker.drift_check_baselines
41+
:members:
42+
:undoc-members:
43+
:show-inheritance:

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ Conditions
3232

3333
.. autoclass:: sagemaker.workflow.conditions.ConditionOr
3434

35+
CheckJobConfig
36+
--------------
37+
38+
.. autoclass:: sagemaker.workflow.check_job_config.CheckJobConfig
39+
3540
Entities
3641
--------
3742

@@ -128,3 +133,13 @@ Steps
128133
.. autoclass:: sagemaker.workflow.steps.CacheConfig
129134

130135
.. autoclass:: sagemaker.workflow.lambda_step.LambdaStep
136+
137+
.. autoclass:: sagemaker.workflow.steps.CompilationStep
138+
139+
.. autoclass:: sagemaker.workflow.quality_check_step.QualityCheckConfig
140+
141+
.. autoclass:: sagemaker.workflow.quality_check_step.QualityCheckStep
142+
143+
.. autoclass:: sagemaker.workflow.clarify_check_step.ClarifyCheckConfig
144+
145+
.. autoclass:: sagemaker.workflow.clarify_check_step.ClarifyCheckStep

src/sagemaker/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from sagemaker.local.local_session import LocalSession # noqa: F401
5050

5151
from sagemaker.model import Model, ModelPackage # noqa: F401
52-
from sagemaker.model_metrics import ModelMetrics, MetricsSource # noqa: F401
52+
from sagemaker.model_metrics import ModelMetrics, MetricsSource, FileSource # noqa: F401
5353
from sagemaker.pipeline import PipelineModel # noqa: F401
5454
from sagemaker.predictor import Predictor # noqa: F401
5555
from sagemaker.processing import Processor, ScriptProcessor # noqa: F401

src/sagemaker/clarify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
5959
"ShardedByS3Key".
6060
s3_compression_type (str): Valid options are "None" or "Gzip".
61-
joinsource (str): The name or index of the column in the dataset that acts an
61+
joinsource (str): The name or index of the column in the dataset that acts as an
6262
identifier column (for instance, while performing a join). This column is only
6363
used as an identifier, and not used for any other computations. This is an
6464
optional field in all cases except when the dataset contains more than one file,
src/sagemaker/drift_check_baselines.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This file contains code related to drift check baselines"""
14+
from __future__ import absolute_import
15+
16+
17+
class DriftCheckBaselines(object):
18+
"""Accepts drift check baselines parameters for conversion to request dict."""
19+
20+
def __init__(
21+
self,
22+
model_statistics=None,
23+
model_constraints=None,
24+
model_data_statistics=None,
25+
model_data_constraints=None,
26+
bias_config_file=None,
27+
bias_pre_training_constraints=None,
28+
bias_post_training_constraints=None,
29+
explainability_constraints=None,
30+
explainability_config_file=None,
31+
):
32+
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.
33+
34+
Args:
35+
model_statistics (MetricsSource): A metric source object that represents
36+
model statistics (default: None).
37+
model_constraints (MetricsSource): A metric source object that represents
38+
model constraints (default: None).
39+
model_data_statistics (MetricsSource): A metric source object that represents
40+
model data statistics (default: None).
41+
model_data_constraints (MetricsSource): A metric source object that represents
42+
model data constraints (default: None).
43+
bias_config_file (FileSource): A file source object that represents bias config
44+
(default: None).
45+
bias_pre_training_constraints (MetricsSource):
46+
A metric source object that represents Pre-training constraints (default: None).
47+
bias_post_training_constraints (MetricsSource):
48+
A metric source object that represents Post-training constraints (default: None).
49+
explainability_constraints (MetricsSource):
50+
A metric source object that represents explainability constraints (default: None).
51+
explainability_config_file (FileSource): A file source object that represents
52+
explainability config (default: None).
53+
"""
54+
self.model_statistics = model_statistics
55+
self.model_constraints = model_constraints
56+
self.model_data_statistics = model_data_statistics
57+
self.model_data_constraints = model_data_constraints
58+
self.bias_config_file = bias_config_file
59+
self.bias_pre_training_constraints = bias_pre_training_constraints
60+
self.bias_post_training_constraints = bias_post_training_constraints
61+
self.explainability_constraints = explainability_constraints
62+
self.explainability_config_file = explainability_config_file
63+
64+
def _to_request_dict(self):
65+
"""Generates a request dictionary using the parameters provided to the class."""
66+
drift_check_baselines_request = {}
67+
68+
model_quality = {}
69+
if self.model_statistics is not None:
70+
model_quality["Statistics"] = self.model_statistics._to_request_dict()
71+
if self.model_constraints is not None:
72+
model_quality["Constraints"] = self.model_constraints._to_request_dict()
73+
if model_quality:
74+
drift_check_baselines_request["ModelQuality"] = model_quality
75+
76+
model_data_quality = {}
77+
if self.model_data_statistics is not None:
78+
model_data_quality["Statistics"] = self.model_data_statistics._to_request_dict()
79+
if self.model_data_constraints is not None:
80+
model_data_quality["Constraints"] = self.model_data_constraints._to_request_dict()
81+
if model_data_quality:
82+
drift_check_baselines_request["ModelDataQuality"] = model_data_quality
83+
84+
bias = {}
85+
if self.bias_config_file is not None:
86+
bias["ConfigFile"] = self.bias_config_file._to_request_dict()
87+
if self.bias_pre_training_constraints is not None:
88+
bias["PreTrainingConstraints"] = self.bias_pre_training_constraints._to_request_dict()
89+
if self.bias_post_training_constraints is not None:
90+
bias["PostTrainingConstraints"] = self.bias_post_training_constraints._to_request_dict()
91+
if bias:
92+
drift_check_baselines_request["Bias"] = bias
93+
94+
explainability = {}
95+
if self.explainability_constraints is not None:
96+
explainability["Constraints"] = self.explainability_constraints._to_request_dict()
97+
if self.explainability_config_file is not None:
98+
explainability["ConfigFile"] = self.explainability_config_file._to_request_dict()
99+
if explainability:
100+
drift_check_baselines_request["Explainability"] = explainability
101+
102+
return drift_check_baselines_request

src/sagemaker/estimator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,7 @@ def register(
977977
description=None,
978978
compile_model_family=None,
979979
model_name=None,
980+
drift_check_baselines=None,
980981
**kwargs,
981982
):
982983
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1005,6 +1006,7 @@ def register(
10051006
compile_model_family (str): Instance family for compiled model, if specified, a compiled
10061007
model will be used (default: None).
10071008
model_name (str): User defined model name (default: None).
1009+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
10081010
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
10091011
``create_model()`` to accept ``**kwargs`` to customize model creation during
10101012
deploy. For more, see the implementation docs.
@@ -1034,6 +1036,7 @@ def register(
10341036
marketplace_cert,
10351037
approval_status,
10361038
description,
1039+
drift_check_baselines=drift_check_baselines,
10371040
)
10381041

10391042
@property
@@ -1920,7 +1923,13 @@ def training_image_uri(self):
19201923
return self.image_uri
19211924

19221925
def set_hyperparameters(self, **kwargs):
1923-
"""Placeholder docstring"""
1926+
"""Sets the hyperparameter dictionary to use for training.
1927+
1928+
The hyperparameters are made accessible as a dict[str, str] to the
1929+
training code on SageMaker. For convenience, this accepts other types
1930+
for keys and values, but ``str()`` will be called to convert them before
1931+
training.
1932+
"""
19241933
for k, v in kwargs.items():
19251934
self.hyperparam_dict[k] = v
19261935

src/sagemaker/huggingface/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def register(
186186
marketplace_cert=False,
187187
approval_status=None,
188188
description=None,
189+
drift_check_baselines=None,
189190
):
190191
"""Creates a model package for creating SageMaker models or listing on Marketplace.
191192
@@ -212,6 +213,7 @@ def register(
212213
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
213214
or "PendingManualApproval". Defaults to ``PendingManualApproval``.
214215
description (str): Model Package description. Defaults to ``None``.
216+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
215217
216218
Returns:
217219
A `sagemaker.model.ModelPackage` instance.
@@ -239,6 +241,7 @@ def register(
239241
marketplace_cert,
240242
approval_status,
241243
description,
244+
drift_check_baselines=drift_check_baselines,
242245
)
243246

244247
def prepare_container_def(self, instance_type=None, accelerator_type=None):

src/sagemaker/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def register(
146146
marketplace_cert=False,
147147
approval_status=None,
148148
description=None,
149+
drift_check_baselines=None,
149150
):
150151
"""Creates a model package for creating SageMaker models or listing on Marketplace.
151152
@@ -170,6 +171,7 @@ def register(
170171
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
171172
or "PendingManualApproval" (default: "PendingManualApproval").
172173
description (str): Model Package description (default: None).
174+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
173175
174176
Returns:
175177
A `sagemaker.model.ModelPackage` instance.
@@ -191,6 +193,7 @@ def register(
191193
marketplace_cert,
192194
approval_status,
193195
description,
196+
drift_check_baselines=drift_check_baselines,
194197
)
195198
model_package = self.sagemaker_session.create_model_package_from_containers(
196199
**model_pkg_args

src/sagemaker/model_metrics.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""This file contains code related to model metrics, including metric source."""
13+
"""This file contains code related to model metrics, including metric source and file source."""
1414
from __future__ import absolute_import
1515

1616

@@ -25,22 +25,36 @@ def __init__(
2525
model_data_constraints=None,
2626
bias=None,
2727
explainability=None,
28+
bias_pre_training=None,
29+
bias_post_training=None,
2830
):
2931
"""Initialize a ``ModelMetrics`` instance and turn parameters into dict.
3032
31-
# TODO: flesh out docstrings
3233
Args:
33-
model_constraints (MetricsSource):
34-
model_data_constraints (MetricsSource):
35-
model_data_statistics (MetricsSource):
36-
bias (MetricsSource):
37-
explainability (MetricsSource):
34+
model_statistics (MetricsSource): A metric source object that represents
35+
model statistics (default: None).
36+
model_constraints (MetricsSource): A metric source object that represents
37+
model constraints (default: None).
38+
model_data_statistics (MetricsSource): A metric source object that represents
39+
model data statistics (default: None).
40+
model_data_constraints (MetricsSource): A metric source object that represents
41+
model data constraints (default: None).
42+
bias (MetricsSource): A metric source object that represents bias report
43+
(default: None).
44+
explainability (MetricsSource): A metric source object that represents
45+
explainability report (default: None).
46+
bias_pre_training (MetricsSource): A metric source object that represents
47+
Pre-training report (default: None).
48+
bias_post_training (MetricsSource): A metric source object that represents
49+
Post-training report (default: None).
3850
"""
3951
self.model_statistics = model_statistics
4052
self.model_constraints = model_constraints
4153
self.model_data_statistics = model_data_statistics
4254
self.model_data_constraints = model_data_constraints
4355
self.bias = bias
56+
self.bias_pre_training = bias_pre_training
57+
self.bias_post_training = bias_post_training
4458
self.explainability = explainability
4559

4660
def _to_request_dict(self):
@@ -63,10 +77,20 @@ def _to_request_dict(self):
6377
if model_data_quality:
6478
model_metrics_request["ModelDataQuality"] = model_data_quality
6579

80+
bias = {}
6681
if self.bias is not None:
67-
model_metrics_request["Bias"] = self.bias._to_request_dict()
82+
bias["Report"] = self.bias._to_request_dict()
83+
if self.bias_pre_training is not None:
84+
bias["PreTrainingReport"] = self.bias_pre_training._to_request_dict()
85+
if self.bias_post_training is not None:
86+
bias["PostTrainingReport"] = self.bias_post_training._to_request_dict()
87+
model_metrics_request["Bias"] = bias
88+
89+
explainability = {}
6890
if self.explainability is not None:
69-
model_metrics_request["Explainability"] = self.explainability._to_request_dict()
91+
explainability["Report"] = self.explainability._to_request_dict()
92+
model_metrics_request["Explainability"] = explainability
93+
7094
return model_metrics_request
7195

7296

@@ -81,11 +105,10 @@ def __init__(
81105
):
82106
"""Initialize a ``MetricsSource`` instance and turn parameters into dict.
83107
84-
# TODO: flesh out docstrings
85108
Args:
86-
content_type (str):
87-
s3_uri (str):
88-
content_digest (str):
109+
content_type (str): Specifies the type of content in S3 URI
110+
s3_uri (str): The S3 URI of the metric
111+
content_digest (str): The digest of the metric (default: None)
89112
"""
90113
self.content_type = content_type
91114
self.s3_uri = s3_uri
@@ -97,3 +120,33 @@ def _to_request_dict(self):
97120
if self.content_digest is not None:
98121
metrics_source_request["ContentDigest"] = self.content_digest
99122
return metrics_source_request
123+
124+
125+
class FileSource(object):
126+
"""Accepts file source parameters for conversion to request dict."""
127+
128+
def __init__(
129+
self,
130+
s3_uri,
131+
content_digest=None,
132+
content_type=None,
133+
):
134+
"""Initialize a ``FileSource`` instance and turn parameters into dict.
135+
136+
Args:
137+
s3_uri (str): The S3 URI of the file
138+
content_digest (str): The digest of the file (default: None)
139+
content_type (str): Specifies the type of content in S3 URI (default: None)
140+
"""
141+
self.content_type = content_type
142+
self.s3_uri = s3_uri
143+
self.content_digest = content_digest
144+
145+
def _to_request_dict(self):
146+
"""Generates a request dictionary using the parameters provided to the class."""
147+
file_source_request = {"S3Uri": self.s3_uri}
148+
if self.content_digest is not None:
149+
file_source_request["ContentDigest"] = self.content_digest
150+
if self.content_type is not None:
151+
file_source_request["ContentType"] = self.content_type
152+
return file_source_request

src/sagemaker/mxnet/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def register(
157157
marketplace_cert=False,
158158
approval_status=None,
159159
description=None,
160+
drift_check_baselines=None,
160161
):
161162
"""Creates a model package for creating SageMaker models or listing on Marketplace.
162163
@@ -181,6 +182,7 @@ def register(
181182
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
182183
or "PendingManualApproval" (default: "PendingManualApproval").
183184
description (str): Model Package description (default: None).
185+
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
184186
185187
Returns:
186188
A `sagemaker.model.ModelPackage` instance.
@@ -208,6 +210,7 @@ def register(
208210
marketplace_cert,
209211
approval_status,
210212
description,
213+
drift_check_baselines=drift_check_baselines,
211214
)
212215

213216
def prepare_container_def(self, instance_type=None, accelerator_type=None):

0 commit comments

Comments
 (0)