Skip to content

fix in and not in condition bug #2347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def to_request(self) -> RequestType:
"""Get the request structure for workflow service calls."""
return {
"Type": self.condition_type.value,
"Value": self.value.expr,
"In": [primitive_or_expr(in_value) for in_value in self.in_values],
"QueryValue": self.value.expr,
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
}


Expand Down
10 changes: 7 additions & 3 deletions tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo, ConditionIn
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
Expand Down Expand Up @@ -696,6 +696,7 @@ def test_conditional_pytorch_training_model_registration(
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
in_condition_input = ParameterString(name="Foo", default_value="Foo")

pytorch_estimator = PyTorch(
entry_point=entry_point,
Expand Down Expand Up @@ -741,14 +742,17 @@ def test_conditional_pytorch_training_model_registration(

step_cond = ConditionStep(
name="cond-good-enough",
conditions=[ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1)],
conditions=[
ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1),
ConditionIn(value=in_condition_input, in_values=["foo", "bar"]),
],
if_steps=[step_train, step_register],
else_steps=[step_model],
)

pipeline = Pipeline(
name=pipeline_name,
parameters=[good_enough_input, instance_count, instance_type],
parameters=[in_condition_input, good_enough_input, instance_count, instance_type],
steps=[step_cond],
sagemaker_session=sagemaker_session,
)
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/sagemaker/workflow/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def test_condition_in():
cond_in = ConditionIn(value=param, in_values=["abc", "def"])
assert cond_in.to_request() == {
"Type": "In",
"Value": {"Get": "Parameters.MyStr"},
"In": ["abc", "def"],
"QueryValue": {"Get": "Parameters.MyStr"},
"Values": ["abc", "def"],
}


Expand All @@ -111,8 +111,8 @@ def test_condition_in_mixed():
cond_in = ConditionIn(value=param, in_values=["abc", prop, var])
assert cond_in.to_request() == {
"Type": "In",
"Value": {"Get": "Parameters.MyStr"},
"In": ["abc", {"Get": "foo"}, {"Get": "Execution.StartDateTime"}],
"QueryValue": {"Get": "Parameters.MyStr"},
"Values": ["abc", {"Get": "foo"}, {"Get": "Execution.StartDateTime"}],
}


Expand All @@ -138,8 +138,8 @@ def test_condition_not_in():
"Type": "Not",
"Expression": {
"Type": "In",
"Value": {"Get": "Parameters.MyStr"},
"In": ["abc", "def"],
"QueryValue": {"Get": "Parameters.MyStr"},
"Values": ["abc", "def"],
},
}

Expand All @@ -160,8 +160,8 @@ def test_condition_or():
},
{
"Type": "In",
"Value": {"Get": "Parameters.MyStr"},
"In": ["abc", "def"],
"QueryValue": {"Get": "Parameters.MyStr"},
"Values": ["abc", "def"],
},
],
}