@@ -14,14 +14,17 @@
 from __future__ import absolute_import

 from typing import Union, Optional, List, Dict

-from botocore import exceptions
+import logging
+import copy
+import time

+from botocore import exceptions
 from sagemaker.job import _Job
-from sagemaker.session import Session
+from sagemaker.session import Session, get_execution_role
 from sagemaker.inputs import BatchDataCaptureConfig
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.functions import Join
-from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.utils import base_name_from_image, name_from_base
@@ -266,6 +269,155 @@ def transform(
         if wait:
             self.latest_transform_job.wait(logs=logs)

+    def transform_with_monitoring(
+        self,
+        monitoring_config,
+        monitoring_resource_config,
+        data: str,
+        data_type: str = "S3Prefix",
+        content_type: str = None,
+        compression_type: str = None,
+        split_type: str = None,
+        input_filter: str = None,
+        output_filter: str = None,
+        join_source: str = None,
+        model_client_config: Dict[str, str] = None,
+        batch_data_capture_config: BatchDataCaptureConfig = None,
+        monitor_before_transform: bool = False,
+        supplied_baseline_statistics: str = None,
+        supplied_baseline_constraints: str = None,
+        wait: bool = True,
+        pipeline_name: str = None,
+        role: str = None,
+    ):
+        """Runs a transform job with a monitoring job attached.
+
+        Note that this function does not start a transform job immediately;
+        instead, it creates a SageMaker Pipeline and executes it.
+        If pipeline_name refers to an existing pipeline, that pipeline is
+        updated and executed; otherwise, each transform_with_monitoring call
+        creates and executes a new pipeline.
+
+        Args:
+            monitoring_config (Union[
+                `sagemaker.workflow.quality_check_step.QualityCheckConfig`,
+                `sagemaker.workflow.clarify_check_step.ClarifyCheckConfig`
+            ]): the monitoring configuration used to run model monitoring.
+            monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
+                the check job (processing job) cluster resource configuration.
+            data (str): Input data location in S3 for the transform job.
+            data_type (str): What the S3 location defines (default: 'S3Prefix').
+                Valid values:
+                * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
+                    will be used as inputs for the transform job.
+                * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
+                    object to use as an input for the transform job.
+            content_type (str): MIME type of the input data (default: None).
+            compression_type (str): Compression type of the input data, if
+                compressed (default: None). Valid values: 'Gzip', None.
+            split_type (str): The record delimiter for the input object
+                (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and
+                'TFRecord'.
+            input_filter (str): A JSONPath to select a portion of the input to
+                pass to the algorithm container for inference. If you omit the
+                field, it gets the value '$', representing the entire input.
+                For CSV data, each row is taken as a JSON array,
+                so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
+                CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
+                See `Supported JSONPath Operators
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
+                for a table of supported JSONPath operators.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
+                Some examples: "$[1:]", "$.features" (default: None).
+            output_filter (str): A JSONPath to select a portion of the
+                joined/original output to return as the output.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
+                Some examples: "$[1:]", "$.prediction" (default: None).
+            join_source (str): The source of data to be joined to the transform
+                output. It can be set to 'Input' meaning the entire input record
+                will be joined to the inference result. You can use OutputFilter
+                to select the useful portion before uploading to S3. (default:
+                None). Valid values: Input, None.
+            model_client_config (dict[str, str]): Model configuration.
+                Dictionary contains two optional keys,
+                'InvocationsTimeoutInSeconds' and 'InvocationsMaxRetries'
+                (default: ``None``).
+            batch_data_capture_config (BatchDataCaptureConfig): Configuration object which
+                specifies the configurations related to the batch data capture for the
+                transform job (default: ``None``).
+            monitor_before_transform (bool): For data quality or model
+                explainability monitoring types, whether to run the check step
+                before the transform job instead of after it (default: False).
+            supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
+                to the supplied statistics object representing the statistics JSON file
+                to be used for the drift check (default: None).
+            supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path
+                to the supplied constraints object representing the constraints JSON file
+                to be used for the drift check (default: None).
+            wait (bool): Whether to wait for the pipeline execution to complete
+                (default: True).
+            pipeline_name (str): The name of the pipeline that runs the monitoring
+                and transform steps (default: None, in which case a name is
+                generated from a timestamp).
+            role (str): The IAM role ARN used to create or update the pipeline
+                (default: None, in which case get_execution_role() is used).
+        """
+
+        transformer = self
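+        # A PipelineSession only captures the request when .transform() is
+        # called below, yielding step arguments for a pipeline step instead
+        # of starting a batch transform job right away.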
+        if not isinstance(self.sagemaker_session, PipelineSession):
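+            # Detach the live session so deepcopy does not attempt to copy it
+            # (boto3-backed sessions are not reliably deep-copyable), give the
+            # copy its own PipelineSession, then restore the original session.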
+            sagemaker_session = self.sagemaker_session
+            self.sagemaker_session = None
+            transformer = copy.deepcopy(self)
+            transformer.sagemaker_session = PipelineSession()
+            self.sagemaker_session = sagemaker_session
+
+        transform_step_args = transformer.transform(
+            data=data,
+            data_type=data_type,
+            content_type=content_type,
+            compression_type=compression_type,
+            split_type=split_type,
+            input_filter=input_filter,
+            output_filter=output_filter,
+            batch_data_capture_config=batch_data_capture_config,
+            join_source=join_source,
+            model_client_config=model_client_config,
+        )
+
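+        # Local import, presumably to avoid a circular import between the
+        # transformer and workflow modules.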
+        from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep
+
+        monitoring_batch_step = MonitorBatchTransformStep(
+            name="MonitorBatchTransformStep",
+            display_name="MonitorBatchTransformStep",
+            description="",
+            transform_step_args=transform_step_args,
+            monitor_configuration=monitoring_config,
+            check_job_configuration=monitoring_resource_config,
+            monitor_before_transform=monitor_before_transform,
+            supplied_baseline_constraints=supplied_baseline_constraints,
+            supplied_baseline_statistics=supplied_baseline_statistics,
+        )
+
+        pipeline_name = (
+            pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}"
+        )
+        # upsert creates the pipeline if it does not exist and updates it if it
+        # does, so an existing pipeline_name is reused rather than duplicated
+        from sagemaker.workflow.pipeline import Pipeline
+
+        pipeline = Pipeline(
+            name=pipeline_name,
+            steps=[monitoring_batch_step],
+            sagemaker_session=transformer.sagemaker_session,
+        )
+        pipeline.upsert(role_arn=role if role else get_execution_role())
+        execution = pipeline.start()
+        if wait:
+            logging.info("Waiting for transform with monitoring to execute ...")
+            execution.wait()
+        return execution
+
     def delete_model(self):
         """Delete the corresponding SageMaker model for this Transformer."""
         self.sagemaker_session.delete_model(self.model_name)
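For context, a minimal usage sketch of the new method (not part of the diff). All model names, S3 URIs, and role ARNs are placeholders; it assumes a SageMaker model named "my-model" already exists and uses a data quality check as the monitoring configuration:

    from sagemaker.transformer import Transformer
    from sagemaker.workflow.check_job_config import CheckJobConfig
    from sagemaker.workflow.quality_check_step import DataQualityCheckConfig
    from sagemaker.model_monitor import DatasetFormat

    role = "arn:aws:iam::123456789012:role/SageMakerRole"  # placeholder

    # Transformer for an existing SageMaker model (placeholder names/URIs).
    transformer = Transformer(
        model_name="my-model",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        output_path="s3://my-bucket/transform-output",
    )

    # Cluster resources for the monitoring (processing) job.
    job_config = CheckJobConfig(role=role)

    # Data quality check against a CSV baseline.
    quality_config = DataQualityCheckConfig(
        baseline_dataset="s3://my-bucket/baseline/baseline.csv",
        dataset_format=DatasetFormat.csv(header=False),
        output_s3_uri="s3://my-bucket/monitoring-output",
    )

    # Builds a pipeline containing a MonitorBatchTransformStep, upserts it,
    # and starts an execution; returns the pipeline execution handle.
    execution = transformer.transform_with_monitoring(
        monitoring_config=quality_config,
        monitoring_resource_config=job_config,
        data="s3://my-bucket/transform-input",
        content_type="text/csv",
        split_type="Line",
        monitor_before_transform=True,  # run the data quality check first
        role=role,
    )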