From e342f331837e628f9f5501c7522eee121e10b7ab Mon Sep 17 00:00:00 2001 From: qidewenwhen Date: Wed, 3 Jan 2024 16:26:50 -0800 Subject: [PATCH] fix: Support PipelineVariable ModelQualityCheckConfig attributes --- src/sagemaker/workflow/quality_check_step.py | 52 ++++++++++++------ .../workflow/test_quality_check_step.py | 53 ++++++++++++++++--- 2 files changed, 82 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/workflow/quality_check_step.py b/src/sagemaker/workflow/quality_check_step.py index 2cae687770..8257ed3844 100644 --- a/src/sagemaker/workflow/quality_check_step.py +++ b/src/sagemaker/workflow/quality_check_step.py @@ -13,6 +13,7 @@ """The step definitions for workflow.""" from __future__ import absolute_import +import logging from abc import ABC from typing import List, Union, Optional import os @@ -24,7 +25,8 @@ from sagemaker.processing import ProcessingOutput, ProcessingJob, Processor, ProcessingInput from sagemaker.workflow import is_pipeline_variable -from sagemaker.workflow.entities import RequestType, PipelineVariable +from sagemaker.workflow.entities import RequestType, PipelineVariable, PrimitiveType +from sagemaker.workflow.parameters import Parameter, ParameterString from sagemaker.workflow.properties import ( Properties, ) @@ -47,6 +49,9 @@ _DATA_QUALITY_TYPE = "DATA_QUALITY" +logger = logging.getLogger(__name__) + + @attr.s class QualityCheckConfig(ABC): """Quality Check Config. @@ -407,25 +412,19 @@ def _generate_baseline_processor( post_processor_script_container_path=post_processor_script_container_path, ) else: - inference_attribute = ( - str(quality_check_cfg.inference_attribute) - if quality_check_cfg.inference_attribute is not None - else None + inference_attribute = _format_env_variable_value( + var_value=quality_check_cfg.inference_attribute, var_name="inference_attribute" ) - probability_attribute = ( - str(quality_check_cfg.probability_attribute) - if quality_check_cfg.probability_attribute is not None - else None + probability_attribute = _format_env_variable_value( + var_value=quality_check_cfg.probability_attribute, var_name="probability_attribute" ) - ground_truth_attribute = ( - str(quality_check_cfg.ground_truth_attribute) - if quality_check_cfg.ground_truth_attribute is not None - else None + ground_truth_attribute = _format_env_variable_value( + var_value=quality_check_cfg.ground_truth_attribute, + var_name="ground_truth_attribute", ) - probability_threshold_attr = ( - str(quality_check_cfg.probability_threshold_attribute) - if quality_check_cfg.probability_threshold_attribute is not None - else None + probability_threshold_attr = _format_env_variable_value( + var_value=quality_check_cfg.probability_threshold_attribute, + var_name="probability_threshold_attr", ) normalized_env = ModelMonitor._generate_env_map( env=self._model_monitor.env, @@ -458,3 +457,22 @@ def _generate_baseline_processor( tags=self._model_monitor.tags, network_config=self._model_monitor.network_config, ) + + +def _format_env_variable_value(var_value: Union[PrimitiveType, PipelineVariable], var_name: str): + """Helper function to format the variable values passed to env var + + Args: + var_value (PrimitiveType or PipelineVariable): The value of the variable. + var_name (str): The name of the variable. + """ + if var_value is None: + return None + + if is_pipeline_variable(var_value): + if isinstance(var_value, Parameter) and not isinstance(var_value, ParameterString): + raise ValueError(f"{var_name} cannot be Parameter types other than ParameterString.") + logger.warning("%s's runtime value must be the string type.", var_name) + return var_value + + return str(var_value) diff --git a/tests/unit/sagemaker/workflow/test_quality_check_step.py b/tests/unit/sagemaker/workflow/test_quality_check_step.py index 07dc37bafd..88125b714c 100644 --- a/tests/unit/sagemaker/workflow/test_quality_check_step.py +++ b/tests/unit/sagemaker/workflow/test_quality_check_step.py @@ -16,7 +16,8 @@ import pytest from sagemaker.model_monitor import DatasetFormat -from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.execution_variables import ExecutionVariable +from sagemaker.workflow.parameters import ParameterString, ParameterInteger from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.pipeline import PipelineDefinitionConfig from sagemaker.workflow.quality_check_step import ( @@ -178,8 +179,6 @@ "dataset_source": "/opt/ml/processing/input/baseline_dataset_input", "analysis_type": "MODEL_QUALITY", "problem_type": "BinaryClassification", - "probability_attribute": "0", - "probability_threshold_attribute": "0.5", }, "StoppingCondition": {"MaxRuntimeInSeconds": 1800}, }, @@ -269,23 +268,54 @@ def test_data_quality_check_step( assert step_definition == _expected_data_quality_dsl +@pytest.mark.parametrize( + "quality_cfg_attr_value, expected_value_in_dsl", + [ + (0, "0"), + ("attr", "attr"), + (None, None), + (ParameterString(name="ParamStringEnvVar"), {"Get": "Parameters.ParamStringEnvVar"}), + (ExecutionVariable("PipelineArn"), {"Get": "Execution.PipelineArn"}), + (ParameterInteger(name="ParamIntEnvVar"), "Error"), + ], +) def test_model_quality_check_step( sagemaker_session, check_job_config, model_package_group_name, supplied_baseline_statistics_uri, supplied_baseline_constraints_uri, + quality_cfg_attr_value, + expected_value_in_dsl, ): model_quality_check_config = ModelQualityCheckConfig( baseline_dataset="baseline_dataset_s3_url", dataset_format=DatasetFormat.csv(header=True), problem_type="BinaryClassification", - probability_attribute=0, # the integer should be converted to str by SDK - ground_truth_attribute=None, - probability_threshold_attribute=0.5, # the float should be converted to str by SDK + inference_attribute=quality_cfg_attr_value, + probability_attribute=quality_cfg_attr_value, + ground_truth_attribute=quality_cfg_attr_value, + probability_threshold_attribute=quality_cfg_attr_value, post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py", output_s3_uri="", ) + + if expected_value_in_dsl == "Error": + with pytest.raises(ValueError) as err: + QualityCheckStep( + name="ModelQualityCheckStep", + register_new_baseline=False, + skip_check=False, + fail_on_violation=True, + quality_check_config=model_quality_check_config, + check_job_config=check_job_config, + model_package_group_name=model_package_group_name, + supplied_baseline_statistics=supplied_baseline_statistics_uri, + supplied_baseline_constraints=supplied_baseline_constraints_uri, + ) + assert "cannot be Parameter types other than ParameterString" in str(err) + return + model_quality_check_step = QualityCheckStep( name="ModelQualityCheckStep", register_new_baseline=False, @@ -297,6 +327,7 @@ def test_model_quality_check_step( supplied_baseline_statistics=supplied_baseline_statistics_uri, supplied_baseline_constraints=supplied_baseline_constraints_uri, ) + pipeline = Pipeline( name="MyPipeline", parameters=[ @@ -310,6 +341,16 @@ def test_model_quality_check_step( step_definition = _get_step_definition_for_test(pipeline) + step_def_env = step_definition["Arguments"]["Environment"] + for var in [ + "inference_attribute", + "probability_attribute", + "ground_truth_attribute", + "probability_threshold_attribute", + ]: + env_var_dsl = step_def_env.pop(var, None) + assert env_var_dsl == expected_value_in_dsl + assert step_definition == _expected_model_quality_dsl