Skip to content

Commit 555e0b7

Browse files
keshav-chandakKeshav Chandak
and
Keshav Chandak
authored
feature: added monitor batch transform step (pipeline) (#3398)
Co-authored-by: Keshav Chandak <[email protected]>
1 parent e6ceef0 commit 555e0b7

File tree

3 files changed

+1159
-1
lines changed

3 files changed

+1159
-1
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The `MonitorBatchTransform` definition for SageMaker Pipelines Workflows"""
14+
from __future__ import absolute_import
15+
import logging
16+
from typing import Union, Optional, List
17+
18+
from sagemaker.session import Session
19+
from sagemaker.workflow.pipeline_context import _JobStepArguments
20+
from sagemaker.workflow.entities import PipelineVariable
21+
from sagemaker.workflow.step_collections import StepCollection
22+
from sagemaker.workflow.quality_check_step import (
23+
QualityCheckStep,
24+
QualityCheckConfig,
25+
DataQualityCheckConfig,
26+
)
27+
from sagemaker.workflow.clarify_check_step import (
28+
ClarifyCheckStep,
29+
ClarifyCheckConfig,
30+
ModelExplainabilityCheckConfig,
31+
)
32+
from sagemaker.workflow.steps import Step
33+
from sagemaker.workflow.check_job_config import CheckJobConfig
34+
from sagemaker.workflow.steps import TransformStep
35+
from sagemaker.workflow.utilities import validate_step_args_input
36+
37+
38+
class MonitorBatchTransformStep(StepCollection):
39+
"""Creates a Transformer step with Quality or Clarify check step
40+
41+
Used to monitor the inputs and outputs of the batch transform job.
42+
"""
43+
44+
def __init__(
45+
self,
46+
name: str,
47+
transform_step_args: _JobStepArguments,
48+
monitor_configuration: Union[QualityCheckConfig, ClarifyCheckConfig],
49+
check_job_configuration: CheckJobConfig,
50+
monitor_before_transform: bool = False,
51+
fail_on_violation: Union[bool, PipelineVariable] = True,
52+
supplied_baseline_statistics: Union[str, PipelineVariable] = None,
53+
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
54+
display_name: Optional[str] = None,
55+
description: Optional[str] = None,
56+
):
57+
"""Construct a step collection of `TransformStep`, `QualityCheckStep` or `ClarifyCheckStep`
58+
59+
Args:
60+
name (str): The name of the `MonitorBatchTransformStep`.
61+
The corresponding transform step will be named `{name}-transform`;
62+
and the corresponding check step will be named `{name}-monitoring`
63+
transform_step_args (_JobStepArguments): the transform step transform arguments.
64+
monitor_configuration (Union[
65+
`sagemaker.workflow.quality_check_step.QualityCheckConfig`,
66+
`sagemaker.workflow.quality_check_step.ClarifyCheckConfig`
67+
]): the monitoring configuration used for run model monitoring.
68+
check_job_configuration (`sagemaker.workflow.check_job_config.CheckJobConfig`):
69+
the check job (processing job) cluster resource configuration.
70+
monitor_before_transform (bool): If to run data quality or model explainability
71+
monitoring type, a true value of this flag indicates
72+
running the check step before the transform job.
73+
fail_on_violation (Union[bool, PipelineVariable]): A opt-out flag to not to fail the
74+
check step when a violation is detected.
75+
supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
76+
to the supplied statistics object representing the statistics JSON file
77+
which will be used for drift to check (default: None).
78+
supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path
79+
to the supplied constraints object representing the constraints JSON file
80+
which will be used for drift to check (default: None).
81+
display_name (str): The display name of the `MonitorBatchTransformStep`.
82+
The display name provides better UI readability.
83+
The corresponding transform step will be
84+
named `{display_name}-transform`; and the corresponding check step
85+
will be named `{display_name}-monitoring` (default: None).
86+
description (str): The description of the `MonitorBatchTransformStep` (default: None).
87+
"""
88+
self.name: str = name
89+
self.steps: List[Step] = []
90+
91+
validate_step_args_input(
92+
step_args=transform_step_args,
93+
expected_caller={
94+
Session.transform.__name__,
95+
},
96+
error_message="The transform_step_args of MonitorBatchTransformStep"
97+
"must be obtained from transformer.transform()",
98+
)
99+
transform_step = TransformStep(
100+
name=f"{name}-transform",
101+
display_name=f"{display_name}-transform" if display_name else None,
102+
description=description,
103+
step_args=transform_step_args,
104+
)
105+
106+
self.steps.append(transform_step)
107+
108+
monitoring_step_name = f"{name}-monitoring"
109+
monitoring_step_display_name = f"{display_name}-monitoring" if display_name else None
110+
if isinstance(monitor_configuration, QualityCheckConfig):
111+
monitoring_step = QualityCheckStep(
112+
name=monitoring_step_name,
113+
display_name=monitoring_step_display_name,
114+
description=description,
115+
quality_check_config=monitor_configuration,
116+
check_job_config=check_job_configuration,
117+
skip_check=False,
118+
supplied_baseline_statistics=supplied_baseline_statistics,
119+
supplied_baseline_constraints=supplied_baseline_constraints,
120+
fail_on_violation=fail_on_violation,
121+
)
122+
elif isinstance(monitor_configuration, ClarifyCheckConfig):
123+
if supplied_baseline_statistics:
124+
logging.warning(
125+
"supplied_baseline_statistics will be ignored if monitor_configuration "
126+
"is a ClarifyCheckConfig"
127+
)
128+
monitoring_step = ClarifyCheckStep(
129+
name=monitoring_step_name,
130+
display_name=monitoring_step_display_name,
131+
description=description,
132+
clarify_check_config=monitor_configuration,
133+
check_job_config=check_job_configuration,
134+
skip_check=False,
135+
supplied_baseline_constraints=supplied_baseline_constraints,
136+
fail_on_violation=fail_on_violation,
137+
)
138+
else:
139+
raise ValueError(
140+
f"Unrecognized monitoring configuration: {monitor_configuration}"
141+
f"Should be an instance of either QualityCheckConfig or ClarifyCheckConfig"
142+
)
143+
144+
self.steps.append(monitoring_step)
145+
146+
if monitor_before_transform and not (
147+
isinstance(
148+
monitor_configuration, (DataQualityCheckConfig, ModelExplainabilityCheckConfig)
149+
)
150+
):
151+
raise ValueError(
152+
"monitor_before_transform only take effect when the monitor_configuration "
153+
"is one of [DataQualityCheckConfig, ModelExplainabilityCheckConfig]"
154+
)
155+
156+
if monitor_before_transform:
157+
transform_step.add_depends_on([monitoring_step])
158+
else:
159+
monitoring_step.add_depends_on([transform_step])

tests/data/xgboost_abalone/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def output_fn(predictions, content_type):
5454
"""
5555
After invoking predict_fn, the model server invokes `output_fn`.
5656
"""
57-
if content_type == "text/csv":
57+
if content_type == "text/csv" or content_type == "application/json":
5858
return ",".join(str(x) for x in predictions[0])
5959
else:
6060
raise ValueError("Content type {} is not supported.".format(content_type))

0 commit comments

Comments
 (0)