
Commit 0427aa7

keshav-chandak (Keshav Chandak) authored and Namrata Madan committed
feature: Added transform with monitoring pipeline step in transformer (aws#3438)
Co-authored-by: Keshav Chandak <[email protected]>
1 parent 26b2369 commit 0427aa7

File tree

2 files changed: +220, -4 lines

src/sagemaker/transformer.py (+155, -3)
@@ -14,14 +14,17 @@
 from __future__ import absolute_import
 
 from typing import Union, Optional, List, Dict
-from botocore import exceptions
+import logging
+import copy
+import time
 
+from botocore import exceptions
 from sagemaker.job import _Job
-from sagemaker.session import Session
+from sagemaker.session import Session, get_execution_role
 from sagemaker.inputs import BatchDataCaptureConfig
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.functions import Join
-from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.utils import base_name_from_image, name_from_base
@@ -266,6 +269,155 @@ def transform(
         if wait:
             self.latest_transform_job.wait(logs=logs)
 
+    def transform_with_monitoring(
+        self,
+        monitoring_config,
+        monitoring_resource_config,
+        data: str,
+        data_type: str = "S3Prefix",
+        content_type: str = None,
+        compression_type: str = None,
+        split_type: str = None,
+        input_filter: str = None,
+        output_filter: str = None,
+        join_source: str = None,
+        model_client_config: Dict[str, str] = None,
+        batch_data_capture_config: BatchDataCaptureConfig = None,
+        monitor_before_transform: bool = False,
+        supplied_baseline_statistics: str = None,
+        supplied_baseline_constraints: str = None,
+        wait: bool = True,
+        pipeline_name: str = None,
+        role: str = None,
+    ):
+        """Runs a transform job together with a monitoring job.
+
+        Note that this function will not start a transform job immediately;
+        instead, it will create a SageMaker Pipeline and execute it.
+        If you provide an existing pipeline_name, no new pipeline will be created; otherwise,
+        each transform_with_monitoring call creates and executes a new pipeline.
+
+        Args:
+            monitoring_config (Union[
+                `sagemaker.workflow.quality_check_step.QualityCheckConfig`,
+                `sagemaker.workflow.clarify_check_step.ClarifyCheckConfig`
+            ]): the monitoring configuration used to run model monitoring.
+            monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
+                the check job (processing job) cluster resource configuration.
+            data (str): Input data location in S3 for the transform job.
+            data_type (str): What the S3 location defines (default: 'S3Prefix').
+                Valid values:
+                * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
+                    will be used as inputs for the transform job.
+                * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
+                    object to use as an input for the transform job.
+            content_type (str): MIME type of the input data (default: None).
+            compression_type (str): Compression type of the input data, if
+                compressed (default: None). Valid values: 'Gzip', None.
+            split_type (str): The record delimiter for the input object
+                (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and
+                'TFRecord'.
+            input_filter (str): A JSONPath to select a portion of the input to
+                pass to the algorithm container for inference. If you omit the
+                field, it gets the value '$', representing the entire input.
+                For CSV data, each row is taken as a JSON array,
+                so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
+                CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
+                See `Supported JSONPath Operators
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
+                for a table of supported JSONPath operators.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
+                Some examples: "$[1:]", "$.features" (default: None).
+            output_filter (str): A JSONPath to select a portion of the
+                joined/original output to return as the output.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
+                Some examples: "$[1:]", "$.prediction" (default: None).
+            join_source (str): The source of data to be joined to the transform
+                output. It can be set to 'Input' meaning the entire input record
+                will be joined to the inference result. You can use OutputFilter
+                to select the useful portion before uploading to S3. (default:
+                None). Valid values: Input, None.
+            model_client_config (dict[str, str]): Model configuration.
+                Dictionary contains two optional keys,
+                'InvocationsTimeoutInSeconds' and 'InvocationsMaxRetries'.
+                (default: ``None``).
+            batch_data_capture_config (BatchDataCaptureConfig): Configuration object which
+                specifies the configurations related to the batch data capture for the transform job
+                (default: ``None``).
+            monitor_before_transform (bool): If running a data quality
+                or model explainability monitoring type,
+                a true value of this flag indicates running the check step before the transform job.
+            supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
+                to the supplied statistics object representing the statistics JSON file
+                which will be used for the drift check (default: None).
+            supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path
+                to the supplied constraints object representing the constraints JSON file
+                which will be used for the drift check (default: None).
+            wait (bool): Whether to wait for the pipeline execution to complete (default: True).
+            pipeline_name (str): The name of the Pipeline for the monitoring and transform step.
+            role (str): The execution role used to create and run the pipeline
+                (default: None, in which case the result of `get_execution_role()` is used).
+        """
+
+        transformer = self
+        # If this transformer is bound to a regular Session, work on a deep copy bound
+        # to a PipelineSession so that the transform() call below returns pipeline step
+        # arguments instead of starting a transform job immediately.
+        if not isinstance(self.sagemaker_session, PipelineSession):
+            sagemaker_session = self.sagemaker_session
+            self.sagemaker_session = None
+            transformer = copy.deepcopy(self)
+            transformer.sagemaker_session = PipelineSession()
+            self.sagemaker_session = sagemaker_session
+
+        transform_step_args = transformer.transform(
+            data=data,
+            data_type=data_type,
+            content_type=content_type,
+            compression_type=compression_type,
+            split_type=split_type,
+            input_filter=input_filter,
+            output_filter=output_filter,
+            batch_data_capture_config=batch_data_capture_config,
+            join_source=join_source,
+            model_client_config=model_client_config,
+        )
+
+        from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep
+
+        monitoring_batch_step = MonitorBatchTransformStep(
+            name="MonitorBatchTransformStep",
+            display_name="MonitorBatchTransformStep",
+            description="",
+            transform_step_args=transform_step_args,
+            monitor_configuration=monitoring_config,
+            check_job_configuration=monitoring_resource_config,
+            monitor_before_transform=monitor_before_transform,
+            supplied_baseline_constraints=supplied_baseline_constraints,
+            supplied_baseline_statistics=supplied_baseline_statistics,
+        )
+
+        pipeline_name = (
+            pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}"
+        )
+        # If the pipeline already exists, upsert updates it and a new execution is started.
+        from sagemaker.workflow.pipeline import Pipeline
+
+        pipeline = Pipeline(
+            name=pipeline_name,
+            steps=[monitoring_batch_step],
+            sagemaker_session=transformer.sagemaker_session,
+        )
+        pipeline.upsert(role_arn=role if role else get_execution_role())
+        execution = pipeline.start()
+        if wait:
+            logging.info("Waiting for transform with monitoring to execute ...")
+            execution.wait()
+        return execution
+
     def delete_model(self):
         """Delete the corresponding SageMaker model for this Transformer."""
         self.sagemaker_session.delete_model(self.model_name)
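For context, here is a minimal usage sketch of the new method (the role ARN, S3 URIs, model name, and pipeline name below are illustrative placeholders, not part of this commit). It pairs a data quality check with a batch transform and runs the check before the transform job:

    from sagemaker.transformer import Transformer
    from sagemaker.model_monitor import DatasetFormat
    from sagemaker.workflow.check_job_config import CheckJobConfig
    from sagemaker.workflow.quality_check_step import DataQualityCheckConfig

    # Cluster resources for the check (processing) job.
    check_job_config = CheckJobConfig(
        role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        volume_size_in_gb=30,
    )

    # Data quality monitoring configuration with a supplied baseline dataset.
    data_quality_config = DataQualityCheckConfig(
        baseline_dataset="s3://my-bucket/baselines/training-dataset.csv",
        dataset_format=DatasetFormat.csv(header=False),
        output_s3_uri="s3://my-bucket/monitoring/data-quality",
    )

    transformer = Transformer(
        model_name="my-model",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        output_path="s3://my-bucket/transform-output",
    )

    # Upserts a pipeline containing a MonitorBatchTransformStep and starts an execution;
    # monitor_before_transform=True runs the quality check before the transform job.
    execution = transformer.transform_with_monitoring(
        monitoring_config=data_quality_config,
        monitoring_resource_config=check_job_config,
        data="s3://my-bucket/transform-input",
        content_type="text/csv",
        monitor_before_transform=True,
        pipeline_name="my-transform-monitoring-pipeline",
        role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",
    )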

tests/integ/test_transformer.py (+65, -1)
@@ -25,6 +25,7 @@
 from sagemaker.transformer import Transformer
 from sagemaker.estimator import Estimator
 from sagemaker.inputs import BatchDataCaptureConfig
+from sagemaker.xgboost import XGBoostModel
 from sagemaker.utils import unique_name_from_base
 from tests.integ import (
     datasets,
@@ -36,7 +37,7 @@
 from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer
 from tests.integ.vpc_test_utils import get_or_create_vpc_resources
 
-from sagemaker.model_monitor import DatasetFormat, Statistics
+from sagemaker.model_monitor import DatasetFormat, Statistics, Constraints
 
 from sagemaker.workflow.check_job_config import CheckJobConfig
 from sagemaker.workflow.quality_check_step import (
@@ -645,3 +646,66 @@ def _create_transformer_and_transform_job(
         job_name=unique_name_from_base("test-transform"),
     )
     return transformer
+
+
+def test_transformer_and_monitoring_job(
+    pipeline_session,
+    sagemaker_session,
+    role,
+    pipeline_name,
+    check_job_config,
+    data_bias_check_config,
+):
+    xgb_model_data_s3 = pipeline_session.upload_data(
+        path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"),
+        key_prefix="integ-test-data/xgboost/model",
+    )
+    data_bias_supplied_baseline_constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(
+            DATA_DIR, "pipeline/clarify_check_step/data_bias/good_cases/analysis.json"
+        ),
+        sagemaker_session=sagemaker_session,
+    ).file_s3_uri
+
+    xgb_model = XGBoostModel(
+        model_data=xgb_model_data_s3,
+        framework_version="1.3-1",
+        role=role,
+        sagemaker_session=sagemaker_session,
+        entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"),
+        enable_network_isolation=True,
+    )
+
+    xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE)
+
+    transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform"
+    transformer = Transformer(
+        model_name=xgb_model.name,
+        strategy="SingleRecord",
+        instance_type="ml.m5.xlarge",
+        instance_count=1,
+        output_path=transform_output,
+        sagemaker_session=pipeline_session,
+    )
+
+    transform_input = pipeline_session.upload_data(
+        path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
+        key_prefix="integ-test-data/xgboost_abalone/abalone",
+    )
+
+    execution = transformer.transform_with_monitoring(
+        monitoring_config=data_bias_check_config,
+        monitoring_resource_config=check_job_config,
+        data=transform_input,
+        content_type="text/libsvm",
+        supplied_baseline_constraints=data_bias_supplied_baseline_constraints,
+        role=role,
+    )
+
+    execution_steps = execution.list_steps()
+    # The pipeline contains exactly two steps: the data bias (clarify) check step
+    # and the transform step.
+    assert len(execution_steps) == 2
+
+    for execution_step in execution_steps:
+        assert execution_step["StepStatus"] == "Succeeded"
+
+    xgb_model.delete_model()
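Note that transform_with_monitoring upserts a SageMaker Pipeline; when no pipeline_name is passed (as in the test above), the pipeline gets a timestamp-based name of the form TransformWithMonitoring<timestamp>. Passing an explicit pipeline_name makes later cleanup straightforward; a possible cleanup sketch (the pipeline name below is a placeholder and assumes it was also passed to transform_with_monitoring):

    from sagemaker.workflow.pipeline import Pipeline

    # Delete the pipeline that transform_with_monitoring created or updated.
    Pipeline(name="my-transform-monitoring-pipeline", sagemaker_session=sagemaker_session).delete()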
