|
29 | 29 | BatchTransformInput,
|
30 | 30 | ModelQualityMonitor,
|
31 | 31 | Statistics,
|
| 32 | + MonitoringOutput, |
32 | 33 | )
|
33 | 34 | from sagemaker.model_monitor.monitoring_alert import (
|
34 | 35 | MonitoringAlertSummary,
|
|
132 | 133 | POSTPROCESSOR_URI = "s3://my_bucket/postprocessor.py"
|
133 | 134 | DATA_CAPTURED_S3_URI = "s3://my-bucket/batch-fraud-detection/on-schedule-monitoring/in/"
|
134 | 135 | DATASET_FORMAT = MonitoringDatasetFormat.csv(header=False)
|
| 136 | +ARGUMENTS = ["test-argument"] |
135 | 137 | JOB_OUTPUT_CONFIG = {
|
136 | 138 | "MonitoringOutputs": [
|
137 | 139 | {
|
|
372 | 374 | },
|
373 | 375 | "GroundTruthS3Input": {"S3Uri": NEW_GROUND_TRUTH_S3_URI},
|
374 | 376 | }
|
| 377 | +MODEL_QUALITY_MONITOR_JOB_INPUT = { |
| 378 | + "EndpointInput": { |
| 379 | + "EndpointName": ENDPOINT_NAME, |
| 380 | + "LocalPath": ENDPOINT_INPUT_LOCAL_PATH, |
| 381 | + "S3InputMode": S3_INPUT_MODE, |
| 382 | + "S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE, |
| 383 | + "StartTimeOffset": START_TIME_OFFSET, |
| 384 | + "EndTimeOffset": END_TIME_OFFSET, |
| 385 | + "FeaturesAttribute": FEATURES_ATTRIBUTE, |
| 386 | + "InferenceAttribute": INFERENCE_ATTRIBUTE, |
| 387 | + "ProbabilityAttribute": PROBABILITY_ATTRIBUTE, |
| 388 | + "ProbabilityThresholdAttribute": PROBABILITY_THRESHOLD_ATTRIBUTE, |
| 389 | + }, |
| 390 | +} |
375 | 391 | NEW_MODEL_QUALITY_JOB_DEFINITION = {
|
376 | 392 | "ModelQualityAppSpecification": NEW_MODEL_QUALITY_APP_SPECIFICATION,
|
377 | 393 | "ModelQualityJobInput": NEW_MODEL_QUALITY_JOB_INPUT,
|
@@ -483,6 +499,25 @@ def data_quality_monitor(sagemaker_session):
|
483 | 499 | )
|
484 | 500 |
|
485 | 501 |
|
| 502 | +@pytest.fixture() |
| 503 | +def model_monitor_arguments(sagemaker_session): |
| 504 | + return ModelMonitor( |
| 505 | + role=ROLE, |
| 506 | + instance_count=INSTANCE_COUNT, |
| 507 | + instance_type=INSTANCE_TYPE, |
| 508 | + volume_size_in_gb=VOLUME_SIZE_IN_GB, |
| 509 | + volume_kms_key=VOLUME_KMS_KEY, |
| 510 | + output_kms_key=OUTPUT_KMS_KEY, |
| 511 | + max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS, |
| 512 | + base_job_name=BASE_JOB_NAME, |
| 513 | + sagemaker_session=sagemaker_session, |
| 514 | + env=ENVIRONMENT, |
| 515 | + tags=TAGS, |
| 516 | + network_config=NETWORK_CONFIG, |
| 517 | + image_uri=DefaultModelMonitor._get_default_image_uri(REGION), |
| 518 | + ) |
| 519 | + |
| 520 | + |
486 | 521 | def test_default_model_monitor_suggest_baseline(sagemaker_session):
|
487 | 522 | my_default_monitor = DefaultModelMonitor(
|
488 | 523 | role=ROLE,
|
@@ -1691,3 +1726,54 @@ def test_batch_transform_and_endpoint_input_simultaneous_failure(
|
1691 | 1726 | )
|
1692 | 1727 | except Exception as e:
|
1693 | 1728 | assert "Need to have either batch_transform_input or endpoint_input" in str(e)
|
| 1729 | + |
| 1730 | + |
| 1731 | +def test_model_monitor_with_arguments( |
| 1732 | + model_monitor_arguments, |
| 1733 | + sagemaker_session, |
| 1734 | + constraints=None, |
| 1735 | + statistics=None, |
| 1736 | + endpoint_input=EndpointInput( |
| 1737 | + endpoint_name=ENDPOINT_NAME, |
| 1738 | + destination=ENDPOINT_INPUT_LOCAL_PATH, |
| 1739 | + start_time_offset=START_TIME_OFFSET, |
| 1740 | + end_time_offset=END_TIME_OFFSET, |
| 1741 | + features_attribute=FEATURES_ATTRIBUTE, |
| 1742 | + inference_attribute=INFERENCE_ATTRIBUTE, |
| 1743 | + probability_attribute=PROBABILITY_ATTRIBUTE, |
| 1744 | + probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE, |
| 1745 | + ), |
| 1746 | +): |
| 1747 | + # for batch transform input |
| 1748 | + model_monitor_arguments.create_monitoring_schedule( |
| 1749 | + constraints=constraints, |
| 1750 | + statistics=statistics, |
| 1751 | + monitor_schedule_name=SCHEDULE_NAME, |
| 1752 | + schedule_cron_expression=CRON_HOURLY, |
| 1753 | + endpoint_input=endpoint_input, |
| 1754 | + arguments=ARGUMENTS, |
| 1755 | + output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI), |
| 1756 | + ) |
| 1757 | + |
| 1758 | + sagemaker_session.create_monitoring_schedule.assert_called_with( |
| 1759 | + monitoring_schedule_name="schedule", |
| 1760 | + schedule_expression="cron(0 * ? * * *)", |
| 1761 | + statistics_s3_uri=None, |
| 1762 | + constraints_s3_uri=None, |
| 1763 | + monitoring_inputs=[MODEL_QUALITY_MONITOR_JOB_INPUT], |
| 1764 | + monitoring_output_config=JOB_OUTPUT_CONFIG, |
| 1765 | + instance_count=INSTANCE_COUNT, |
| 1766 | + instance_type=INSTANCE_TYPE, |
| 1767 | + volume_size_in_gb=VOLUME_SIZE_IN_GB, |
| 1768 | + volume_kms_key=VOLUME_KMS_KEY, |
| 1769 | + image_uri=DEFAULT_IMAGE_URI, |
| 1770 | + entrypoint=None, |
| 1771 | + arguments=ARGUMENTS, |
| 1772 | + record_preprocessor_source_uri=None, |
| 1773 | + post_analytics_processor_source_uri=None, |
| 1774 | + max_runtime_in_seconds=3, |
| 1775 | + environment=ENVIRONMENT, |
| 1776 | + network_config=NETWORK_CONFIG._to_request_dict(), |
| 1777 | + role_arn=ROLE, |
| 1778 | + tags=TAGS, |
| 1779 | + ) |
0 commit comments