Skip to content

Commit 1b6d635

Browse files
author
Keshav Chandak
committed
Fix: updating batch transform job in monitoring schedule
1 parent c4c4e83 commit 1b6d635

File tree

2 files changed: +207 additions, −0 deletions

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ def update_monitoring_schedule(
433433
network_config=None,
434434
role=None,
435435
image_uri=None,
436+
batch_transform_input=None,
436437
):
437438
"""Updates the existing monitoring schedule.
438439
@@ -475,6 +476,8 @@ def update_monitoring_schedule(
475476
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
476477
image_uri (str): The uri of the image to use for the jobs started by
477478
the Monitor.
479+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
480+
run the monitoring schedule on the batch transform (default: None)
478481
479482
"""
480483
monitoring_inputs = None
@@ -483,6 +486,9 @@ def update_monitoring_schedule(
483486
self._normalize_endpoint_input(endpoint_input=endpoint_input)._to_request_dict()
484487
]
485488

489+
elif batch_transform_input is not None:
490+
monitoring_inputs = [batch_transform_input._to_request_dict()]
491+
486492
monitoring_output_config = None
487493
if output is not None:
488494
normalized_monitoring_output = self._normalize_monitoring_output_fields(output=output)
@@ -1895,6 +1901,7 @@ def update_monitoring_schedule(
18951901
network_config=None,
18961902
enable_cloudwatch_metrics=None,
18971903
role=None,
1904+
batch_transform_input=None,
18981905
):
18991906
"""Updates the existing monitoring schedule.
19001907
@@ -1936,6 +1943,8 @@ def update_monitoring_schedule(
19361943
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
19371944
the baselining or monitoring jobs.
19381945
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
1946+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
1947+
run the monitoring schedule on the batch transform (default: None)
19391948
19401949
"""
19411950
# check if this schedule is in v2 format and update as per v2 format if it is
@@ -1958,13 +1967,17 @@ def update_monitoring_schedule(
19581967
network_config=network_config,
19591968
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
19601969
role=role,
1970+
batch_transform_input=batch_transform_input,
19611971
)
19621972
return
19631973

19641974
monitoring_inputs = None
19651975
if endpoint_input is not None:
19661976
monitoring_inputs = [self._normalize_endpoint_input(endpoint_input)._to_request_dict()]
19671977

1978+
elif batch_transform_input is not None:
1979+
monitoring_inputs = [batch_transform_input._to_request_dict()]
1980+
19681981
record_preprocessor_script_s3_uri = None
19691982
if record_preprocessor_script is not None:
19701983
record_preprocessor_script_s3_uri = self._s3_uri_from_local_path(

tests/integ/test_model_monitor.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,3 +1852,197 @@ def _verify_default_monitoring_schedule_with_batch_transform(
18521852
)
18531853
else:
18541854
assert network_config is None
1855+
1856+
1857+
@pytest.mark.release
def test_default_update_monitoring_batch_transform_schedule_name(
    sagemaker_session, output_kms_key, volume_kms_key
):
    """Create a DefaultModelMonitor schedule on a batch transform input, then
    update it to a fresh data-capture location and verify the resulting schedule.
    """
    monitor = DefaultModelMonitor(
        role=ROLE,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
        sagemaker_session=sagemaker_session,
        env=ENVIRONMENT,
        tags=TAGS,
        network_config=NETWORK_CONFIG,
    )

    # Baseline artifacts the scheduled monitoring jobs validate against.
    statistics = Statistics.from_file_path(
        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
        sagemaker_session=sagemaker_session,
    )
    constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
        sagemaker_session=sagemaker_session,
    )

    # Unique S3 prefixes so concurrent test runs never collide.
    output_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "integ-test-monitoring-output-bucket",
        str(uuid.uuid4()),
    )
    initial_capture_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "sagemaker-serving-batch-transform",
        str(uuid.uuid4()),
    )

    transform_input = BatchTransformInput(
        data_captured_destination_s3_uri=initial_capture_uri,
        destination="/opt/ml/processing/output",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    monitor.create_monitoring_schedule(
        batch_transform_input=transform_input,
        output_s3_uri=output_s3_uri,
        statistics=statistics,
        constraints=constraints,
        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
        enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
    )
    _wait_for_schedule_changes_to_apply(monitor=monitor)

    # Repoint the existing schedule at a brand-new capture location.
    updated_capture_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "sagemaker-tensorflow-serving-batch-transform",
        str(uuid.uuid4()),
    )
    monitor.update_monitoring_schedule(
        batch_transform_input=BatchTransformInput(
            data_captured_destination_s3_uri=updated_capture_uri,
            destination="/opt/ml/processing/output",
            dataset_format=MonitoringDatasetFormat.csv(header=False),
        ),
    )
    _wait_for_schedule_changes_to_apply(monitor=monitor)

    _verify_default_monitoring_schedule_with_batch_transform(
        sagemaker_session=sagemaker_session,
        schedule_description=monitor.describe_schedule(),
        cron_expression=HOURLY_CRON_EXPRESSION,
        statistics=statistics,
        constraints=constraints,
        output_kms_key=output_kms_key,
        volume_kms_key=volume_kms_key,
        network_config=NETWORK_CONFIG,
    )

    monitor.stop_monitoring_schedule()
    monitor.delete_monitoring_schedule()
1952+
@pytest.mark.release
def test_byoc_monitoring_schedule_name_update_batch(
    sagemaker_session, output_kms_key, volume_kms_key
):
    """Create a BYOC ModelMonitor schedule on a batch transform input, update it
    to a new data-capture location, and verify the described schedule picked up
    the new ``DataCapturedDestinationS3Uri``.
    """
    byoc_env = ENVIRONMENT.copy()
    byoc_env["dataset_format"] = json.dumps(DatasetFormat.csv(header=False))
    byoc_env["dataset_source"] = "/opt/ml/processing/input/baseline_dataset_input"
    byoc_env["output_path"] = os.path.join("/opt/ml/processing/output")
    byoc_env["publish_cloudwatch_metrics"] = "Disabled"

    my_byoc_monitor = ModelMonitor(
        role=ROLE,
        image_uri=DefaultModelMonitor._get_default_image_uri(
            sagemaker_session.boto_session.region_name
        ),
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
        sagemaker_session=sagemaker_session,
        env=byoc_env,
        tags=TAGS,
        network_config=NETWORK_CONFIG,
    )

    output_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "integ-test-monitoring-output-bucket",
        str(uuid.uuid4()),
    )

    statistics = Statistics.from_file_path(
        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
        sagemaker_session=sagemaker_session,
    )

    constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
        sagemaker_session=sagemaker_session,
    )

    data_captured_destination_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "sagemaker-serving-batch-transform",
        str(uuid.uuid4()),
    )

    batch_transform_input = BatchTransformInput(
        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
        destination="/opt/ml/processing/output",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    # FIX: the batch transform input was previously passed as ``endpoint_input``,
    # which routes it through endpoint-input normalization; use the dedicated
    # ``batch_transform_input`` parameter, matching the DefaultModelMonitor test.
    my_byoc_monitor.create_monitoring_schedule(
        batch_transform_input=batch_transform_input,
        output=MonitoringOutput(source="/opt/ml/processing/output", destination=output_s3_uri),
        statistics=statistics,
        constraints=constraints,
        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
    )

    _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)

    # Update the schedule to capture from a new S3 location.
    data_captured_destination_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "sagemaker-tensorflow-serving-batch-transform",
        str(uuid.uuid4()),
    )

    batch_transform_input = BatchTransformInput(
        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
        destination="/opt/ml/processing/output",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    my_byoc_monitor.update_monitoring_schedule(
        batch_transform_input=batch_transform_input,
    )

    _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)

    schedule_description = my_byoc_monitor.describe_schedule()

    # The described schedule must reflect the updated capture destination.
    assert (
        data_captured_destination_s3_uri
        == schedule_description["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
            "MonitoringInputs"
        ][0]["BatchTransformInput"]["DataCapturedDestinationS3Uri"]
    )

    my_byoc_monitor.stop_monitoring_schedule()
    my_byoc_monitor.delete_monitoring_schedule()

0 commit comments

Comments (0)