
Commit b3ef11d

Author: Keshav Chandak
feature: Added transform with monitoring pipeline step in transformer
1 parent 1c55297 · commit b3ef11d

File tree

2 files changed: +220 −4 lines

src/sagemaker/transformer.py (+155 −3)
@@ -14,13 +14,16 @@
 from __future__ import absolute_import
 
 from typing import Union, Optional, List, Dict
-from botocore import exceptions
+import logging
+import copy
+import time
 
+from botocore import exceptions
 from sagemaker.job import _Job
-from sagemaker.session import Session
+from sagemaker.session import Session, get_execution_role
 from sagemaker.inputs import BatchDataCaptureConfig
 from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.utils import base_name_from_image, name_from_base
 

@@ -247,6 +250,155 @@ def transform(
         if wait:
             self.latest_transform_job.wait(logs=logs)
 
+    def transform_with_monitoring(
+        self,
+        monitoring_config,
+        monitoring_resource_config,
+        data: str,
+        data_type: str = "S3Prefix",
+        content_type: str = None,
+        compression_type: str = None,
+        split_type: str = None,
+        input_filter: str = None,
+        output_filter: str = None,
+        join_source: str = None,
+        model_client_config: Dict[str, str] = None,
+        batch_data_capture_config: BatchDataCaptureConfig = None,
+        monitor_before_transform: bool = False,
+        supplied_baseline_statistics: str = None,
+        supplied_baseline_constraints: str = None,
+        wait: bool = True,
+        pipeline_name: str = None,
+        role: str = None,
+    ):
+        """Runs a transform job with a monitoring job.
+
+        Note that this function will not start a transform job immediately;
+        instead, it will create a SageMaker Pipeline and execute it.
+        If you provide an existing pipeline_name, no new pipeline will be created; otherwise,
+        each transform_with_monitoring call will create a new pipeline and execute it.
+
+        Args:
+            monitoring_config (Union[
+                `sagemaker.workflow.quality_check_step.QualityCheckConfig`,
+                `sagemaker.workflow.quality_check_step.ClarifyCheckConfig`
+            ]): the monitoring configuration used to run model monitoring.
+            monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
+                the check job (processing job) cluster resource configuration.
+            data (str): Input data location in S3 for the transform job.
+            data_type (str): What the S3 location defines (default: 'S3Prefix').
+                Valid values:
+                * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
+                    will be used as inputs for the transform job.
+                * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
+                    object to use as an input for the transform job.
+            content_type (str): MIME type of the input data (default: None).
+            compression_type (str): Compression type of the input data, if
+                compressed (default: None). Valid values: 'Gzip', None.
+            split_type (str): The record delimiter for the input object
+                (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and
+                'TFRecord'.
+            input_filter (str): A JSONPath to select a portion of the input to
+                pass to the algorithm container for inference. If you omit the
+                field, it gets the value '$', representing the entire input.
+                For CSV data, each row is taken as a JSON array,
+                so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
+                CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
+                See `Supported JSONPath Operators
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
+                for a table of supported JSONPath operators.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
+                Some examples: "$[1:]", "$.features" (default: None).
+            output_filter (str): A JSONPath to select a portion of the
+                joined/original output to return as the output.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
+                Some examples: "$[1:]", "$.prediction" (default: None).
+            join_source (str): The source of data to be joined to the transform
+                output. It can be set to 'Input', meaning the entire input record
+                will be joined to the inference result. You can use OutputFilter
+                to select the useful portion before uploading to S3 (default:
+                None). Valid values: Input, None.
+            model_client_config (dict[str, str]): Model configuration.
+                Dictionary contains two optional keys,
+                'InvocationsTimeoutInSeconds' and 'InvocationsMaxRetries'
+                (default: ``None``).
+            batch_data_capture_config (BatchDataCaptureConfig): Configuration object which
+                specifies the configurations related to batch data capture for the transform job
+                (default: ``None``).
+            monitor_before_transform (bool): For the data quality
+                or model explainability monitoring types,
+                a true value of this flag indicates running the check step before the transform job.
+            supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
+                to the supplied statistics object representing the statistics JSON file
+                which will be used for the drift check (default: None).
+            supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path
+                to the supplied constraints object representing the constraints JSON file
+                which will be used for the drift check (default: None).
+            wait (bool): Whether to wait for the pipeline execution to complete (default: True).
+            pipeline_name (str): The name of the Pipeline for the monitoring and transform steps.
+            role (str): Execution role used to create and run the pipeline (default: None).
+        """
+
+        transformer = self
+        if not isinstance(self.sagemaker_session, PipelineSession):
+            sagemaker_session = self.sagemaker_session
+            self.sagemaker_session = None
+            transformer = copy.deepcopy(self)
+            transformer.sagemaker_session = PipelineSession()
+            self.sagemaker_session = sagemaker_session
+
+        transform_step_args = transformer.transform(
+            data=data,
+            data_type=data_type,
+            content_type=content_type,
+            compression_type=compression_type,
+            split_type=split_type,
+            input_filter=input_filter,
+            output_filter=output_filter,
+            batch_data_capture_config=batch_data_capture_config,
+            join_source=join_source,
+            model_client_config=model_client_config,
+        )
+
+        from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep
+
+        monitoring_batch_step = MonitorBatchTransformStep(
+            name="MonitorBatchTransformStep",
+            display_name="MonitorBatchTransformStep",
+            description="",
+            transform_step_args=transform_step_args,
+            monitor_configuration=monitoring_config,
+            check_job_configuration=monitoring_resource_config,
+            monitor_before_transform=monitor_before_transform,
+            supplied_baseline_constraints=supplied_baseline_constraints,
+            supplied_baseline_statistics=supplied_baseline_statistics,
+        )
+
+        pipeline_name = (
+            pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}"
+        )
+        # if the pipeline exists, just start the execution
+        from sagemaker.workflow.pipeline import Pipeline
+
+        pipeline = Pipeline(
+            name=pipeline_name,
+            steps=[monitoring_batch_step],
+            sagemaker_session=transformer.sagemaker_session,
+        )
+        pipeline.upsert(role_arn=role if role else get_execution_role())
+        execution = pipeline.start()
+        if wait:
+            logging.info("Waiting for transform with monitoring to execute ...")
+            execution.wait()
+        return execution
+
     def delete_model(self):
         """Delete the corresponding SageMaker model for this Transformer."""
         self.sagemaker_session.delete_model(self.model_name)
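
For reference, a minimal usage sketch of the new transform_with_monitoring API (not part of this commit): the model name, S3 paths, and role ARN below are placeholders, and a data quality check is used as the monitoring configuration, though any QualityCheckConfig or ClarifyCheckConfig could be passed the same way.

from sagemaker.model_monitor import DatasetFormat
from sagemaker.transformer import Transformer
from sagemaker.workflow.check_job_config import CheckJobConfig
from sagemaker.workflow.quality_check_step import DataQualityCheckConfig

# Placeholder role and S3 locations -- substitute your own account's values.
role_arn = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"
bucket = "s3://my-bucket"

transformer = Transformer(
    model_name="my-existing-model",  # assumes a model already created in SageMaker
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path=f"{bucket}/transform-output",
)

# Cluster resources for the monitoring (check) processing job.
check_job_config = CheckJobConfig(
    role=role_arn,
    instance_count=1,
    instance_type="ml.c5.xlarge",
)

# Data quality monitoring configuration with a CSV baseline dataset.
data_quality_config = DataQualityCheckConfig(
    baseline_dataset=f"{bucket}/baseline/dataset.csv",
    dataset_format=DatasetFormat.csv(header=False),
    output_s3_uri=f"{bucket}/monitoring-output",
)

# Creates (or reuses) a pipeline with a MonitorBatchTransformStep and starts an execution.
execution = transformer.transform_with_monitoring(
    monitoring_config=data_quality_config,
    monitoring_resource_config=check_job_config,
    data=f"{bucket}/transform-input",
    content_type="text/csv",
    supplied_baseline_statistics=f"{bucket}/baseline/statistics.json",
    supplied_baseline_constraints=f"{bucket}/baseline/constraints.json",
    role=role_arn,
)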

tests/integ/test_transformer.py (+65 −1)
@@ -25,6 +25,7 @@
 from sagemaker.transformer import Transformer
 from sagemaker.estimator import Estimator
 from sagemaker.inputs import BatchDataCaptureConfig
+from sagemaker.xgboost import XGBoostModel
 from sagemaker.utils import unique_name_from_base
 from tests.integ import (
     datasets,
@@ -36,7 +37,7 @@
 from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer
 from tests.integ.vpc_test_utils import get_or_create_vpc_resources
 
-from sagemaker.model_monitor import DatasetFormat, Statistics
+from sagemaker.model_monitor import DatasetFormat, Statistics, Constraints
 
 from sagemaker.workflow.check_job_config import CheckJobConfig
 from sagemaker.workflow.quality_check_step import (
@@ -645,3 +646,66 @@ def _create_transformer_and_transform_job(
         job_name=unique_name_from_base("test-transform"),
     )
     return transformer
+
+
+def test_transformer_and_monitoring_job(
+    pipeline_session,
+    sagemaker_session,
+    role,
+    pipeline_name,
+    check_job_config,
+    data_bias_check_config,
+):
+    xgb_model_data_s3 = pipeline_session.upload_data(
+        path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"),
+        key_prefix="integ-test-data/xgboost/model",
+    )
+    data_bias_supplied_baseline_constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(
+            DATA_DIR, "pipeline/clarify_check_step/data_bias/good_cases/analysis.json"
+        ),
+        sagemaker_session=sagemaker_session,
+    ).file_s3_uri
+
+    xgb_model = XGBoostModel(
+        model_data=xgb_model_data_s3,
+        framework_version="1.3-1",
+        role=role,
+        sagemaker_session=sagemaker_session,
+        entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"),
+        enable_network_isolation=True,
+    )
+
+    xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE)
+
+    transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform"
+    transformer = Transformer(
+        model_name=xgb_model.name,
+        strategy="SingleRecord",
+        instance_type="ml.m5.xlarge",
+        instance_count=1,
+        output_path=transform_output,
+        sagemaker_session=pipeline_session,
+    )
+
+    transform_input = pipeline_session.upload_data(
+        path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
+        key_prefix="integ-test-data/xgboost_abalone/abalone",
+    )
+
+    execution = transformer.transform_with_monitoring(
+        monitoring_config=data_bias_check_config,
+        monitoring_resource_config=check_job_config,
+        data=transform_input,
+        content_type="text/libsvm",
+        supplied_baseline_constraints=data_bias_supplied_baseline_constraints,
+        role=role,
+    )
+
+    execution_steps = execution.list_steps()
+    assert len(execution_steps) == 2
+
+    for execution_step in execution_steps:
+        assert execution_step["StepStatus"] == "Succeeded"
+
+    xgb_model.delete_model()
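
The new integration test depends on pytest fixtures (pipeline_session, check_job_config, data_bias_check_config) that are defined elsewhere in the test suite and are not part of this diff. A rough sketch of what such fixtures might look like, with placeholder S3 paths and facet settings:

import pytest

from sagemaker.clarify import BiasConfig, DataConfig
from sagemaker.workflow.check_job_config import CheckJobConfig
from sagemaker.workflow.clarify_check_step import DataBiasCheckConfig
from sagemaker.workflow.pipeline_context import PipelineSession


@pytest.fixture
def pipeline_session():
    # A PipelineSession defers job creation, so transformer.transform() returns step arguments.
    return PipelineSession()


@pytest.fixture
def check_job_config(role, pipeline_session):
    # Processing cluster used by the quality/clarify check step.
    return CheckJobConfig(
        role=role,
        instance_count=1,
        instance_type="ml.c5.xlarge",
        sagemaker_session=pipeline_session,
    )


@pytest.fixture
def data_bias_check_config():
    # Placeholder dataset location and facet column; the real suite points at its own test data.
    data_config = DataConfig(
        s3_data_input_path="s3://my-bucket/clarify/input-dataset.csv",
        s3_output_path="s3://my-bucket/clarify/output",
        label=0,
        dataset_type="text/csv",
    )
    bias_config = BiasConfig(
        label_values_or_threshold=[1],
        facet_name=1,
    )
    return DataBiasCheckConfig(data_config=data_config, data_bias_config=bias_config)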
