Skip to content

Commit abe8399

Browse files
fix: fix in and not in condition bug (#2347)
Co-authored-by: icywang86rui <[email protected]>
1 parent 496fe92 commit abe8399

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

src/sagemaker/workflow/conditions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ def to_request(self) -> RequestType:
186186
"""Get the request structure for workflow service calls."""
187187
return {
188188
"Type": self.condition_type.value,
189-
"Value": self.value.expr,
190-
"In": [primitive_or_expr(in_value) for in_value in self.in_values],
189+
"QueryValue": self.value.expr,
190+
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
191191
}
192192

193193

tests/integ/test_workflow.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from sagemaker.sklearn.estimator import SKLearn
3939
from sagemaker.sklearn.processing import SKLearnProcessor
4040
from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
41-
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
41+
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo, ConditionIn
4242
from sagemaker.workflow.condition_step import ConditionStep
4343
from sagemaker.wrangler.processing import DataWranglerProcessor
4444
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
@@ -696,6 +696,7 @@ def test_conditional_pytorch_training_model_registration(
696696
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
697697
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
698698
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
699+
in_condition_input = ParameterString(name="Foo", default_value="Foo")
699700

700701
pytorch_estimator = PyTorch(
701702
entry_point=entry_point,
@@ -741,14 +742,17 @@ def test_conditional_pytorch_training_model_registration(
741742

742743
step_cond = ConditionStep(
743744
name="cond-good-enough",
744-
conditions=[ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1)],
745+
conditions=[
746+
ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1),
747+
ConditionIn(value=in_condition_input, in_values=["foo", "bar"]),
748+
],
745749
if_steps=[step_train, step_register],
746750
else_steps=[step_model],
747751
)
748752

749753
pipeline = Pipeline(
750754
name=pipeline_name,
751-
parameters=[good_enough_input, instance_count, instance_type],
755+
parameters=[in_condition_input, good_enough_input, instance_count, instance_type],
752756
steps=[step_cond],
753757
sagemaker_session=sagemaker_session,
754758
)

tests/unit/sagemaker/workflow/test_conditions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def test_condition_in():
9999
cond_in = ConditionIn(value=param, in_values=["abc", "def"])
100100
assert cond_in.to_request() == {
101101
"Type": "In",
102-
"Value": {"Get": "Parameters.MyStr"},
103-
"In": ["abc", "def"],
102+
"QueryValue": {"Get": "Parameters.MyStr"},
103+
"Values": ["abc", "def"],
104104
}
105105

106106

@@ -111,8 +111,8 @@ def test_condition_in_mixed():
111111
cond_in = ConditionIn(value=param, in_values=["abc", prop, var])
112112
assert cond_in.to_request() == {
113113
"Type": "In",
114-
"Value": {"Get": "Parameters.MyStr"},
115-
"In": ["abc", {"Get": "foo"}, {"Get": "Execution.StartDateTime"}],
114+
"QueryValue": {"Get": "Parameters.MyStr"},
115+
"Values": ["abc", {"Get": "foo"}, {"Get": "Execution.StartDateTime"}],
116116
}
117117

118118

@@ -138,8 +138,8 @@ def test_condition_not_in():
138138
"Type": "Not",
139139
"Expression": {
140140
"Type": "In",
141-
"Value": {"Get": "Parameters.MyStr"},
142-
"In": ["abc", "def"],
141+
"QueryValue": {"Get": "Parameters.MyStr"},
142+
"Values": ["abc", "def"],
143143
},
144144
}
145145

@@ -160,8 +160,8 @@ def test_condition_or():
160160
},
161161
{
162162
"Type": "In",
163-
"Value": {"Get": "Parameters.MyStr"},
164-
"In": ["abc", "def"],
163+
"QueryValue": {"Get": "Parameters.MyStr"},
164+
"Values": ["abc", "def"],
165165
},
166166
],
167167
}

0 commit comments

Comments
 (0)