Skip to content

fix: Support PipelineVariable for ModelQualityCheckConfig attributes #4353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions src/sagemaker/workflow/quality_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

import logging
from abc import ABC
from typing import List, Union, Optional
import os
Expand All @@ -24,7 +25,8 @@
from sagemaker.processing import ProcessingOutput, ProcessingJob, Processor, ProcessingInput
from sagemaker.workflow import is_pipeline_variable

from sagemaker.workflow.entities import RequestType, PipelineVariable
from sagemaker.workflow.entities import RequestType, PipelineVariable, PrimitiveType
from sagemaker.workflow.parameters import Parameter, ParameterString
from sagemaker.workflow.properties import (
Properties,
)
Expand All @@ -47,6 +49,9 @@
_DATA_QUALITY_TYPE = "DATA_QUALITY"


logger = logging.getLogger(__name__)


@attr.s
class QualityCheckConfig(ABC):
"""Quality Check Config.
Expand Down Expand Up @@ -407,25 +412,19 @@ def _generate_baseline_processor(
post_processor_script_container_path=post_processor_script_container_path,
)
else:
inference_attribute = (
str(quality_check_cfg.inference_attribute)
if quality_check_cfg.inference_attribute is not None
else None
inference_attribute = _format_env_variable_value(
var_value=quality_check_cfg.inference_attribute, var_name="inference_attribute"
)
probability_attribute = (
str(quality_check_cfg.probability_attribute)
if quality_check_cfg.probability_attribute is not None
else None
probability_attribute = _format_env_variable_value(
var_value=quality_check_cfg.probability_attribute, var_name="probability_attribute"
)
ground_truth_attribute = (
str(quality_check_cfg.ground_truth_attribute)
if quality_check_cfg.ground_truth_attribute is not None
else None
ground_truth_attribute = _format_env_variable_value(
var_value=quality_check_cfg.ground_truth_attribute,
var_name="ground_truth_attribute",
)
probability_threshold_attr = (
str(quality_check_cfg.probability_threshold_attribute)
if quality_check_cfg.probability_threshold_attribute is not None
else None
probability_threshold_attr = _format_env_variable_value(
var_value=quality_check_cfg.probability_threshold_attribute,
var_name="probability_threshold_attr",
)
normalized_env = ModelMonitor._generate_env_map(
env=self._model_monitor.env,
Expand Down Expand Up @@ -458,3 +457,22 @@ def _generate_baseline_processor(
tags=self._model_monitor.tags,
network_config=self._model_monitor.network_config,
)


def _format_env_variable_value(var_value: Union[PrimitiveType, PipelineVariable], var_name: str):
    """Format a variable value so it can be passed as an environment variable.

    ``None`` is passed through, pipeline variables are returned unchanged, and
    every other value is converted with ``str``.

    Args:
        var_value (PrimitiveType or PipelineVariable): The value of the variable.
        var_name (str): The name of the variable, used in the error and
            warning messages.

    Returns:
        ``None`` if ``var_value`` is ``None``, ``var_value`` itself if it is a
        pipeline variable, otherwise ``str(var_value)``.

    Raises:
        ValueError: If ``var_value`` is a ``Parameter`` other than a
            ``ParameterString``.
    """
    if var_value is None:
        return None

    # Plain (non-pipeline) values are simply stringified.
    if not is_pipeline_variable(var_value):
        return str(var_value)

    is_non_string_parameter = isinstance(var_value, Parameter) and not isinstance(
        var_value, ParameterString
    )
    if is_non_string_parameter:
        raise ValueError(f"{var_name} cannot be Parameter types other than ParameterString.")

    # The pipeline variable is returned as-is, so only warn about the type
    # its runtime value is expected to have.
    logger.warning("%s's runtime value must be the string type.", var_name)
    return var_value
53 changes: 47 additions & 6 deletions tests/unit/sagemaker/workflow/test_quality_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import pytest

from sagemaker.model_monitor import DatasetFormat
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.execution_variables import ExecutionVariable
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline import PipelineDefinitionConfig
from sagemaker.workflow.quality_check_step import (
Expand Down Expand Up @@ -178,8 +179,6 @@
"dataset_source": "/opt/ml/processing/input/baseline_dataset_input",
"analysis_type": "MODEL_QUALITY",
"problem_type": "BinaryClassification",
"probability_attribute": "0",
"probability_threshold_attribute": "0.5",
},
"StoppingCondition": {"MaxRuntimeInSeconds": 1800},
},
Expand Down Expand Up @@ -269,23 +268,54 @@ def test_data_quality_check_step(
assert step_definition == _expected_data_quality_dsl


@pytest.mark.parametrize(
"quality_cfg_attr_value, expected_value_in_dsl",
[
(0, "0"),
("attr", "attr"),
(None, None),
(ParameterString(name="ParamStringEnvVar"), {"Get": "Parameters.ParamStringEnvVar"}),
(ExecutionVariable("PipelineArn"), {"Get": "Execution.PipelineArn"}),
(ParameterInteger(name="ParamIntEnvVar"), "Error"),
],
)
def test_model_quality_check_step(
sagemaker_session,
check_job_config,
model_package_group_name,
supplied_baseline_statistics_uri,
supplied_baseline_constraints_uri,
quality_cfg_attr_value,
expected_value_in_dsl,
):
model_quality_check_config = ModelQualityCheckConfig(
baseline_dataset="baseline_dataset_s3_url",
dataset_format=DatasetFormat.csv(header=True),
problem_type="BinaryClassification",
probability_attribute=0, # the integer should be converted to str by SDK
ground_truth_attribute=None,
probability_threshold_attribute=0.5, # the float should be converted to str by SDK
inference_attribute=quality_cfg_attr_value,
probability_attribute=quality_cfg_attr_value,
ground_truth_attribute=quality_cfg_attr_value,
probability_threshold_attribute=quality_cfg_attr_value,
post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
output_s3_uri="",
)

if expected_value_in_dsl == "Error":
with pytest.raises(ValueError) as err:
QualityCheckStep(
name="ModelQualityCheckStep",
register_new_baseline=False,
skip_check=False,
fail_on_violation=True,
quality_check_config=model_quality_check_config,
check_job_config=check_job_config,
model_package_group_name=model_package_group_name,
supplied_baseline_statistics=supplied_baseline_statistics_uri,
supplied_baseline_constraints=supplied_baseline_constraints_uri,
)
assert "cannot be Parameter types other than ParameterString" in str(err)
return

model_quality_check_step = QualityCheckStep(
name="ModelQualityCheckStep",
register_new_baseline=False,
Expand All @@ -297,6 +327,7 @@ def test_model_quality_check_step(
supplied_baseline_statistics=supplied_baseline_statistics_uri,
supplied_baseline_constraints=supplied_baseline_constraints_uri,
)

pipeline = Pipeline(
name="MyPipeline",
parameters=[
Expand All @@ -310,6 +341,16 @@ def test_model_quality_check_step(

step_definition = _get_step_definition_for_test(pipeline)

step_def_env = step_definition["Arguments"]["Environment"]
for var in [
"inference_attribute",
"probability_attribute",
"ground_truth_attribute",
"probability_threshold_attribute",
]:
env_var_dsl = step_def_env.pop(var, None)
assert env_var_dsl == expected_value_in_dsl

assert step_definition == _expected_model_quality_dsl


Expand Down