Skip to content

Commit 1597bbb

Browse files
committed
fix: Support PipelineVariable ModelQualityCheckConfig attributes
1 parent 08c8a3a commit 1597bbb

File tree

2 files changed

+82
-23
lines changed

2 files changed

+82
-23
lines changed

src/sagemaker/workflow/quality_check_step.py

+35-17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16+
import logging
1617
from abc import ABC
1718
from typing import List, Union, Optional
1819
import os
@@ -24,7 +25,8 @@
2425
from sagemaker.processing import ProcessingOutput, ProcessingJob, Processor, ProcessingInput
2526
from sagemaker.workflow import is_pipeline_variable
2627

27-
from sagemaker.workflow.entities import RequestType, PipelineVariable
28+
from sagemaker.workflow.entities import RequestType, PipelineVariable, PrimitiveType
29+
from sagemaker.workflow.parameters import Parameter, ParameterString
2830
from sagemaker.workflow.properties import (
2931
Properties,
3032
)
@@ -47,6 +49,9 @@
4749
_DATA_QUALITY_TYPE = "DATA_QUALITY"
4850

4951

52+
logger = logging.getLogger(__name__)
53+
54+
5055
@attr.s
5156
class QualityCheckConfig(ABC):
5257
"""Quality Check Config.
@@ -407,25 +412,19 @@ def _generate_baseline_processor(
407412
post_processor_script_container_path=post_processor_script_container_path,
408413
)
409414
else:
410-
inference_attribute = (
411-
str(quality_check_cfg.inference_attribute)
412-
if quality_check_cfg.inference_attribute is not None
413-
else None
415+
inference_attribute = _format_env_variable_value(
416+
var_value=quality_check_cfg.inference_attribute, var_name="inference_attribute"
414417
)
415-
probability_attribute = (
416-
str(quality_check_cfg.probability_attribute)
417-
if quality_check_cfg.probability_attribute is not None
418-
else None
418+
probability_attribute = _format_env_variable_value(
419+
var_value=quality_check_cfg.probability_attribute, var_name="probability_attribute"
419420
)
420-
ground_truth_attribute = (
421-
str(quality_check_cfg.ground_truth_attribute)
422-
if quality_check_cfg.ground_truth_attribute is not None
423-
else None
421+
ground_truth_attribute = _format_env_variable_value(
422+
var_value=quality_check_cfg.ground_truth_attribute,
423+
var_name="ground_truth_attribute",
424424
)
425-
probability_threshold_attr = (
426-
str(quality_check_cfg.probability_threshold_attribute)
427-
if quality_check_cfg.probability_threshold_attribute is not None
428-
else None
425+
probability_threshold_attr = _format_env_variable_value(
426+
var_value=quality_check_cfg.probability_threshold_attribute,
427+
var_name="probability_threshold_attr",
429428
)
430429
normalized_env = ModelMonitor._generate_env_map(
431430
env=self._model_monitor.env,
@@ -458,3 +457,22 @@ def _generate_baseline_processor(
458457
tags=self._model_monitor.tags,
459458
network_config=self._model_monitor.network_config,
460459
)
460+
461+
462+
def _format_env_variable_value(
    var_value: Union[PrimitiveType, PipelineVariable], var_name: str
) -> Optional[Union[str, PipelineVariable]]:
    """Format a quality check config attribute for use as an environment variable.

    Args:
        var_value (PrimitiveType or PipelineVariable): The value of the variable.
        var_name (str): The name of the variable, used in log and error messages.

    Returns:
        None if ``var_value`` is None; ``var_value`` itself (unchanged) if it is a
        pipeline variable; otherwise ``str(var_value)``.

    Raises:
        ValueError: If ``var_value`` is a ``Parameter`` other than ``ParameterString``.
    """
    if var_value is None:
        return None

    # Plain primitives are known at definition time and can be stringified right away.
    if not is_pipeline_variable(var_value):
        return str(var_value)

    if isinstance(var_value, ParameterString):
        # A ParameterString is declared with the String type, so warning the user
        # about its runtime type would only be noise.
        return var_value

    if isinstance(var_value, Parameter):
        raise ValueError(f"{var_name} cannot be Parameter types other than ParameterString.")

    # Other pipeline variables are only resolved at execution time; their type
    # cannot be verified here, so remind the user of the requirement.
    logger.warning("%s's runtime value must be the string type.", var_name)
    return var_value

tests/unit/sagemaker/workflow/test_quality_check_step.py

+47-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import pytest
1717

1818
from sagemaker.model_monitor import DatasetFormat
19-
from sagemaker.workflow.parameters import ParameterString
19+
from sagemaker.workflow.execution_variables import ExecutionVariable
20+
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
2021
from sagemaker.workflow.pipeline import Pipeline
2122
from sagemaker.workflow.pipeline import PipelineDefinitionConfig
2223
from sagemaker.workflow.quality_check_step import (
@@ -178,8 +179,6 @@
178179
"dataset_source": "/opt/ml/processing/input/baseline_dataset_input",
179180
"analysis_type": "MODEL_QUALITY",
180181
"problem_type": "BinaryClassification",
181-
"probability_attribute": "0",
182-
"probability_threshold_attribute": "0.5",
183182
},
184183
"StoppingCondition": {"MaxRuntimeInSeconds": 1800},
185184
},
@@ -269,23 +268,54 @@ def test_data_quality_check_step(
269268
assert step_definition == _expected_data_quality_dsl
270269

271270

271+
@pytest.mark.parametrize(
272+
"quality_cfg_attr_value, expected_value_in_dsl",
273+
[
274+
(0, "0"),
275+
("attr", "attr"),
276+
(None, None),
277+
(ParameterString(name="ParamStringEnvVar"), {"Get": "Parameters.ParamStringEnvVar"}),
278+
(ExecutionVariable("PipelineArn"), {"Get": "Execution.PipelineArn"}),
279+
(ParameterInteger(name="ParamIntEnvVar"), "Error"),
280+
],
281+
)
272282
def test_model_quality_check_step(
273283
sagemaker_session,
274284
check_job_config,
275285
model_package_group_name,
276286
supplied_baseline_statistics_uri,
277287
supplied_baseline_constraints_uri,
288+
quality_cfg_attr_value,
289+
expected_value_in_dsl,
278290
):
279291
model_quality_check_config = ModelQualityCheckConfig(
280292
baseline_dataset="baseline_dataset_s3_url",
281293
dataset_format=DatasetFormat.csv(header=True),
282294
problem_type="BinaryClassification",
283-
probability_attribute=0, # the integer should be converted to str by SDK
284-
ground_truth_attribute=None,
285-
probability_threshold_attribute=0.5, # the float should be converted to str by SDK
295+
inference_attribute=quality_cfg_attr_value,
296+
probability_attribute=quality_cfg_attr_value,
297+
ground_truth_attribute=quality_cfg_attr_value,
298+
probability_threshold_attribute=quality_cfg_attr_value,
286299
post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
287300
output_s3_uri="",
288301
)
302+
303+
if expected_value_in_dsl == "Error":
304+
with pytest.raises(ValueError) as err:
305+
QualityCheckStep(
306+
name="ModelQualityCheckStep",
307+
register_new_baseline=False,
308+
skip_check=False,
309+
fail_on_violation=True,
310+
quality_check_config=model_quality_check_config,
311+
check_job_config=check_job_config,
312+
model_package_group_name=model_package_group_name,
313+
supplied_baseline_statistics=supplied_baseline_statistics_uri,
314+
supplied_baseline_constraints=supplied_baseline_constraints_uri,
315+
)
316+
assert "cannot be Parameter types other than ParameterString" in str(err)
317+
return
318+
289319
model_quality_check_step = QualityCheckStep(
290320
name="ModelQualityCheckStep",
291321
register_new_baseline=False,
@@ -297,6 +327,7 @@ def test_model_quality_check_step(
297327
supplied_baseline_statistics=supplied_baseline_statistics_uri,
298328
supplied_baseline_constraints=supplied_baseline_constraints_uri,
299329
)
330+
300331
pipeline = Pipeline(
301332
name="MyPipeline",
302333
parameters=[
@@ -310,6 +341,16 @@ def test_model_quality_check_step(
310341

311342
step_definition = _get_step_definition_for_test(pipeline)
312343

344+
step_def_env = step_definition["Arguments"]["Environment"]
345+
for var in [
346+
"inference_attribute",
347+
"probability_attribute",
348+
"ground_truth_attribute",
349+
"probability_threshold_attribute",
350+
]:
351+
env_var_dsl = step_def_env.pop(var, None)
352+
assert env_var_dsl == expected_value_in_dsl
353+
313354
assert step_definition == _expected_model_quality_dsl
314355

315356

0 commit comments

Comments
 (0)