diff --git a/src/sagemaker/workflow/conditions.py b/src/sagemaker/workflow/conditions.py index 78d81abc08..b12afaad40 100644 --- a/src/sagemaker/workflow/conditions.py +++ b/src/sagemaker/workflow/conditions.py @@ -186,8 +186,8 @@ def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" return { "Type": self.condition_type.value, - "Value": self.value.expr, - "In": [primitive_or_expr(in_value) for in_value in self.in_values], + "QueryValue": self.value.expr, + "Values": [primitive_or_expr(in_value) for in_value in self.in_values], } diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 6064d67270..66f44836b0 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -38,7 +38,7 @@ from sagemaker.sklearn.estimator import SKLearn from sagemaker.sklearn.processing import SKLearnProcessor from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor -from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo +from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo, ConditionIn from sagemaker.workflow.condition_step import ConditionStep from sagemaker.wrangler.processing import DataWranglerProcessor from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition @@ -696,6 +696,7 @@ def test_conditional_pytorch_training_model_registration( instance_count = ParameterInteger(name="InstanceCount", default_value=1) instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1) + in_condition_input = ParameterString(name="Foo", default_value="Foo") pytorch_estimator = PyTorch( entry_point=entry_point, @@ -741,14 +742,17 @@ def test_conditional_pytorch_training_model_registration( step_cond = ConditionStep( name="cond-good-enough", - conditions=[ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1)], + conditions=[ + ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1), + ConditionIn(value=in_condition_input, in_values=["foo", "bar"]), + ], if_steps=[step_train, step_register], else_steps=[step_model], ) pipeline = Pipeline( name=pipeline_name, - parameters=[good_enough_input, instance_count, instance_type], + parameters=[in_condition_input, good_enough_input, instance_count, instance_type], steps=[step_cond], sagemaker_session=sagemaker_session, ) diff --git a/tests/unit/sagemaker/workflow/test_conditions.py b/tests/unit/sagemaker/workflow/test_conditions.py index bff85cc12b..42438b63c8 100644 --- a/tests/unit/sagemaker/workflow/test_conditions.py +++ b/tests/unit/sagemaker/workflow/test_conditions.py @@ -99,8 +99,8 @@ def test_condition_in(): cond_in = ConditionIn(value=param, in_values=["abc", "def"]) assert cond_in.to_request() == { "Type": "In", - "Value": {"Get": "Parameters.MyStr"}, - "In": ["abc", "def"], + "QueryValue": {"Get": "Parameters.MyStr"}, + "Values": ["abc", "def"], } @@ -111,8 +111,8 @@ def test_condition_in_mixed(): cond_in = ConditionIn(value=param, in_values=["abc", prop, var]) assert cond_in.to_request() == { "Type": "In", - "Value": {"Get": "Parameters.MyStr"}, - "In": ["abc", {"Get": "foo"}, {"Get": "Execution.StartDateTime"}], + "QueryValue": {"Get": "Parameters.MyStr"}, + "Values": ["abc", {"Get": "foo"}, {"Get": "Execution.StartDateTime"}], } @@ -138,8 +138,8 @@ def test_condition_not_in(): "Type": "Not", "Expression": { "Type": "In", - "Value": {"Get": "Parameters.MyStr"}, - "In": ["abc", "def"], + "QueryValue": {"Get": "Parameters.MyStr"}, + "Values": ["abc", "def"], }, } @@ -160,8 +160,8 @@ def test_condition_or(): }, { "Type": "In", - "Value": {"Get": "Parameters.MyStr"}, - "In": ["abc", "def"], + "QueryValue": {"Get": "Parameters.MyStr"}, + "Values": ["abc", "def"], }, ], }