Skip to content

Commit 72312ba

Browse files
authored
Replace parallel condition/result vectors with single CaseWhen vector in Expr::Case (#1733)
1 parent 7fc37a7 commit 72312ba

File tree

5 files changed

+160
-50
lines changed

5 files changed

+160
-50
lines changed

src/ast/mod.rs

+19-6
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,22 @@ pub enum CeilFloorKind {
600600
Scale(Value),
601601
}
602602

603+
/// A WHEN clause in a CASE expression containing both
604+
/// the condition and its corresponding result
605+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
606+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
607+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
608+
pub struct CaseWhen {
609+
pub condition: Expr,
610+
pub result: Expr,
611+
}
612+
613+
impl fmt::Display for CaseWhen {
614+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
615+
write!(f, "WHEN {} THEN {}", self.condition, self.result)
616+
}
617+
}
618+
603619
/// An SQL expression of any type.
604620
///
605621
/// # Semantics / Type Checking
@@ -918,8 +934,7 @@ pub enum Expr {
918934
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
919935
Case {
920936
operand: Option<Box<Expr>>,
921-
conditions: Vec<Expr>,
922-
results: Vec<Expr>,
937+
conditions: Vec<CaseWhen>,
923938
else_result: Option<Box<Expr>>,
924939
},
925940
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
@@ -1621,17 +1636,15 @@ impl fmt::Display for Expr {
16211636
Expr::Case {
16221637
operand,
16231638
conditions,
1624-
results,
16251639
else_result,
16261640
} => {
16271641
write!(f, "CASE")?;
16281642
if let Some(operand) = operand {
16291643
write!(f, " {operand}")?;
16301644
}
1631-
for (c, r) in conditions.iter().zip(results) {
1632-
write!(f, " WHEN {c} THEN {r}")?;
1645+
for when in conditions {
1646+
write!(f, " {when}")?;
16331647
}
1634-
16351648
if let Some(else_result) = else_result {
16361649
write!(f, " ELSE {else_result}")?;
16371650
}

src/ast/spans.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1450,15 +1450,15 @@ impl Spanned for Expr {
14501450
Expr::Case {
14511451
operand,
14521452
conditions,
1453-
results,
14541453
else_result,
14551454
} => union_spans(
14561455
operand
14571456
.as_ref()
14581457
.map(|i| i.span())
14591458
.into_iter()
1460-
.chain(conditions.iter().map(|i| i.span()))
1461-
.chain(results.iter().map(|i| i.span()))
1459+
.chain(conditions.iter().flat_map(|case_when| {
1460+
[case_when.condition.span(), case_when.result.span()]
1461+
}))
14621462
.chain(else_result.as_ref().map(|i| i.span())),
14631463
),
14641464
Expr::Exists { subquery, .. } => subquery.span(),

src/parser/mod.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -2065,11 +2065,11 @@ impl<'a> Parser<'a> {
20652065
self.expect_keyword_is(Keyword::WHEN)?;
20662066
}
20672067
let mut conditions = vec![];
2068-
let mut results = vec![];
20692068
loop {
2070-
conditions.push(self.parse_expr()?);
2069+
let condition = self.parse_expr()?;
20712070
self.expect_keyword_is(Keyword::THEN)?;
2072-
results.push(self.parse_expr()?);
2071+
let result = self.parse_expr()?;
2072+
conditions.push(CaseWhen { condition, result });
20732073
if !self.parse_keyword(Keyword::WHEN) {
20742074
break;
20752075
}
@@ -2083,7 +2083,6 @@ impl<'a> Parser<'a> {
20832083
Ok(Expr::Case {
20842084
operand,
20852085
conditions,
2086-
results,
20872086
else_result,
20882087
})
20892088
}

tests/sqlparser_common.rs

+70-37
Original file line numberDiff line numberDiff line change
@@ -6695,22 +6695,26 @@ fn parse_searched_case_expr() {
66956695
&Case {
66966696
operand: None,
66976697
conditions: vec![
6698-
IsNull(Box::new(Identifier(Ident::new("bar")))),
6699-
BinaryOp {
6700-
left: Box::new(Identifier(Ident::new("bar"))),
6701-
op: Eq,
6702-
right: Box::new(Expr::Value(number("0"))),
6698+
CaseWhen {
6699+
condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
6700+
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
67036701
},
6704-
BinaryOp {
6705-
left: Box::new(Identifier(Ident::new("bar"))),
6706-
op: GtEq,
6707-
right: Box::new(Expr::Value(number("0"))),
6702+
CaseWhen {
6703+
condition: BinaryOp {
6704+
left: Box::new(Identifier(Ident::new("bar"))),
6705+
op: Eq,
6706+
right: Box::new(Expr::Value(number("0"))),
6707+
},
6708+
result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
6709+
},
6710+
CaseWhen {
6711+
condition: BinaryOp {
6712+
left: Box::new(Identifier(Ident::new("bar"))),
6713+
op: GtEq,
6714+
right: Box::new(Expr::Value(number("0"))),
6715+
},
6716+
result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
67086717
},
6709-
],
6710-
results: vec![
6711-
Expr::Value(Value::SingleQuotedString("null".to_string())),
6712-
Expr::Value(Value::SingleQuotedString("=0".to_string())),
6713-
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
67146718
],
67156719
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
67166720
"<0".to_string()
@@ -6729,8 +6733,10 @@ fn parse_simple_case_expr() {
67296733
assert_eq!(
67306734
&Case {
67316735
operand: Some(Box::new(Identifier(Ident::new("foo")))),
6732-
conditions: vec![Expr::Value(number("1"))],
6733-
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))],
6736+
conditions: vec![CaseWhen {
6737+
condition: Expr::Value(number("1")),
6738+
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
6739+
}],
67346740
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
67356741
"N".to_string()
67366742
)))),
@@ -13902,6 +13908,31 @@ fn test_trailing_commas_in_from() {
1390213908
);
1390313909
}
1390413910

13911+
#[test]
13912+
#[cfg(feature = "visitor")]
13913+
fn test_visit_order() {
13914+
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
13915+
let stmt = verified_stmt(sql);
13916+
let mut visited = vec![];
13917+
sqlparser::ast::visit_expressions(&stmt, |expr| {
13918+
visited.push(expr.to_string());
13919+
core::ops::ControlFlow::<()>::Continue(())
13920+
});
13921+
13922+
assert_eq!(
13923+
visited,
13924+
[
13925+
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
13926+
"a",
13927+
"1",
13928+
"2",
13929+
"3",
13930+
"4",
13931+
"5"
13932+
]
13933+
);
13934+
}
13935+
1390513936
#[test]
1390613937
fn test_lambdas() {
1390713938
let dialects = all_dialects_where(|d| d.supports_lambda_functions());
@@ -13929,28 +13960,30 @@ fn test_lambdas() {
1392913960
body: Box::new(Expr::Case {
1393013961
operand: None,
1393113962
conditions: vec![
13932-
Expr::BinaryOp {
13933-
left: Box::new(Expr::Identifier(Ident::new("p1"))),
13934-
op: BinaryOperator::Eq,
13935-
right: Box::new(Expr::Identifier(Ident::new("p2")))
13963+
CaseWhen {
13964+
condition: Expr::BinaryOp {
13965+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
13966+
op: BinaryOperator::Eq,
13967+
right: Box::new(Expr::Identifier(Ident::new("p2")))
13968+
},
13969+
result: Expr::Value(number("0"))
1393613970
},
13937-
Expr::BinaryOp {
13938-
left: Box::new(call(
13939-
"reverse",
13940-
[Expr::Identifier(Ident::new("p1"))]
13941-
)),
13942-
op: BinaryOperator::Lt,
13943-
right: Box::new(call(
13944-
"reverse",
13945-
[Expr::Identifier(Ident::new("p2"))]
13946-
))
13947-
}
13948-
],
13949-
results: vec![
13950-
Expr::Value(number("0")),
13951-
Expr::UnaryOp {
13952-
op: UnaryOperator::Minus,
13953-
expr: Box::new(Expr::Value(number("1")))
13971+
CaseWhen {
13972+
condition: Expr::BinaryOp {
13973+
left: Box::new(call(
13974+
"reverse",
13975+
[Expr::Identifier(Ident::new("p1"))]
13976+
)),
13977+
op: BinaryOperator::Lt,
13978+
right: Box::new(call(
13979+
"reverse",
13980+
[Expr::Identifier(Ident::new("p2"))]
13981+
))
13982+
},
13983+
result: Expr::UnaryOp {
13984+
op: UnaryOperator::Minus,
13985+
expr: Box::new(Expr::Value(number("1")))
13986+
}
1395413987
}
1395513988
],
1395613989
else_result: Some(Box::new(Expr::Value(number("1"))))

tests/sqlparser_databricks.rs

+65
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,71 @@ fn test_databricks_exists() {
8383
);
8484
}
8585

86+
#[test]
87+
fn test_databricks_lambdas() {
88+
#[rustfmt::skip]
89+
let sql = concat!(
90+
"SELECT array_sort(array('Hello', 'World'), ",
91+
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
92+
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
93+
"ELSE 1 END)",
94+
);
95+
pretty_assertions::assert_eq!(
96+
SelectItem::UnnamedExpr(call(
97+
"array_sort",
98+
[
99+
call(
100+
"array",
101+
[
102+
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
103+
Expr::Value(Value::SingleQuotedString("World".to_owned()))
104+
]
105+
),
106+
Expr::Lambda(LambdaFunction {
107+
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
108+
body: Box::new(Expr::Case {
109+
operand: None,
110+
conditions: vec![
111+
CaseWhen {
112+
condition: Expr::BinaryOp {
113+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
114+
op: BinaryOperator::Eq,
115+
right: Box::new(Expr::Identifier(Ident::new("p2")))
116+
},
117+
result: Expr::Value(number("0"))
118+
},
119+
CaseWhen {
120+
condition: Expr::BinaryOp {
121+
left: Box::new(call(
122+
"reverse",
123+
[Expr::Identifier(Ident::new("p1"))]
124+
)),
125+
op: BinaryOperator::Lt,
126+
right: Box::new(call(
127+
"reverse",
128+
[Expr::Identifier(Ident::new("p2"))]
129+
)),
130+
},
131+
result: Expr::UnaryOp {
132+
op: UnaryOperator::Minus,
133+
expr: Box::new(Expr::Value(number("1")))
134+
}
135+
},
136+
],
137+
else_result: Some(Box::new(Expr::Value(number("1"))))
138+
})
139+
})
140+
]
141+
)),
142+
databricks().verified_only_select(sql).projection[0]
143+
);
144+
145+
databricks().verified_expr(
146+
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
147+
);
148+
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
149+
}
150+
86151
#[test]
87152
fn test_values_clause() {
88153
let values = Values {

0 commit comments

Comments
 (0)