Skip to content

Commit cda59d7

Browse files
committed
Add CASE and IF statement support
Add support for scripting statements ```sql CASE product_id WHEN 1 THEN SELECT 1; WHEN 2 THEN SELECT 2; ELSE SELECT 3; END CASE; ``` ```sql IF EXISTS(SELECT 1) THEN SELECT 1; ELSEIF EXISTS(SELECT 2) THEN SELECT 2; ELSE SELECT 3; END IF; ``` [BigQuery CASE](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#case) [BigQuery IF](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if) [Snowflake CASE](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/case) [Snowflake IF](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
1 parent aab12ad commit cda59d7

File tree

5 files changed

+472
-21
lines changed

5 files changed

+472
-21
lines changed

src/ast/mod.rs

+190-8
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,15 @@ where
149149
DisplaySeparated { slice, sep: ", " }
150150
}
151151

152+
/// Writes the given statements to the formatter, each ending with
153+
/// a semicolon and space separated.
154+
fn format_statement_list(f: &mut fmt::Formatter, statements: &[Statement]) -> fmt::Result {
155+
write!(f, "{}", display_separated(statements, "; "))?;
156+
// We manually insert semicolon for the last statement,
157+
// since display_separated doesn't handle that case.
158+
write!(f, ";")
159+
}
160+
152161
/// An identifier, decomposed into its value or character data and the quote style.
153162
#[derive(Debug, Clone, PartialOrd, Ord)]
154163
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -2070,6 +2079,173 @@ pub enum Password {
20702079
NullPassword,
20712080
}
20722081

2082+
/// A `CASE` statement.
2083+
///
2084+
/// Examples:
2085+
/// ```sql
2086+
/// CASE
2087+
/// WHEN EXISTS(SELECT 1)
2088+
/// THEN SELECT 1 FROM T;
2089+
/// WHEN EXISTS(SELECT 2)
2090+
/// THEN SELECT 1 FROM U;
2091+
/// ELSE
2092+
/// SELECT 1 FROM V;
2093+
/// END CASE;
2094+
/// ```
2095+
///
2096+
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#case_search_expression)
2097+
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/case)
2098+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2099+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2100+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2101+
pub struct CaseStatement {
2102+
pub match_expr: Option<Expr>,
2103+
pub when_blocks: Vec<ConditionalStatements>,
2104+
pub else_block: Option<Vec<Statement>>,
2105+
/// TRUE if the statement ends with `END CASE` (vs `END`).
2106+
pub has_end_case: bool,
2107+
}
2108+
2109+
impl fmt::Display for CaseStatement {
2110+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2111+
let CaseStatement {
2112+
match_expr,
2113+
when_blocks,
2114+
else_block,
2115+
has_end_case,
2116+
} = self;
2117+
2118+
write!(f, "CASE")?;
2119+
2120+
if let Some(expr) = match_expr {
2121+
write!(f, " {expr}")?;
2122+
}
2123+
2124+
if !when_blocks.is_empty() {
2125+
write!(f, " {}", display_separated(when_blocks, " "))?;
2126+
}
2127+
2128+
if let Some(else_block) = else_block {
2129+
write!(f, " ELSE ")?;
2130+
format_statement_list(f, else_block)?;
2131+
}
2132+
2133+
write!(f, " END")?;
2134+
if *has_end_case {
2135+
write!(f, " CASE")?;
2136+
}
2137+
2138+
Ok(())
2139+
}
2140+
}
2141+
2142+
/// An `IF` statement.
2143+
///
2144+
/// Examples:
2145+
/// ```sql
2146+
/// IF TRUE THEN
2147+
/// SELECT 1;
2148+
/// SELECT 2;
2149+
/// ELSEIF TRUE THEN
2150+
/// SELECT 3;
2151+
/// ELSE
2152+
/// SELECT 4;
2153+
/// END IF
2154+
/// ```
2155+
///
2156+
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
2157+
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
2158+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2159+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2160+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2161+
pub struct IfStatement {
2162+
pub if_block: ConditionalStatements,
2163+
pub elseif_blocks: Vec<ConditionalStatements>,
2164+
pub else_block: Option<Vec<Statement>>,
2165+
}
2166+
2167+
impl fmt::Display for IfStatement {
2168+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2169+
let IfStatement {
2170+
if_block,
2171+
elseif_blocks,
2172+
else_block,
2173+
} = self;
2174+
2175+
write!(f, "{if_block}")?;
2176+
2177+
if !elseif_blocks.is_empty() {
2178+
write!(f, " {}", display_separated(elseif_blocks, " "))?;
2179+
}
2180+
2181+
if let Some(else_block) = else_block {
2182+
write!(f, " ELSE ")?;
2183+
format_statement_list(f, else_block)?;
2184+
}
2185+
2186+
write!(f, " END IF")?;
2187+
2188+
Ok(())
2189+
}
2190+
}
2191+
2192+
/// Represents a type of [ConditionalStatements]
2193+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2194+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2195+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2196+
pub enum ConditionalStatementKind {
2197+
/// `WHEN <condition> THEN <statements>`
2198+
When,
2199+
/// `IF <condition> THEN <statements>`
2200+
If,
2201+
/// `ELSEIF <condition> THEN <statements>`
2202+
ElseIf,
2203+
}
2204+
2205+
/// A block within a [Statement::Case] or [Statement::If]-like statement
2206+
///
2207+
/// Examples:
2208+
/// ```sql
2209+
/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
2210+
///
2211+
/// IF TRUE THEN SELECT 1; SELECT 2;
2212+
/// ```
2213+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2214+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2215+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2216+
pub struct ConditionalStatements {
2217+
/// The condition expression.
2218+
pub condition: Expr,
2219+
/// Statement list of the `THEN` clause.
2220+
pub statements: Vec<Statement>,
2221+
pub kind: ConditionalStatementKind,
2222+
}
2223+
2224+
impl fmt::Display for ConditionalStatements {
2225+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2226+
let ConditionalStatements {
2227+
condition: expr,
2228+
statements,
2229+
kind,
2230+
} = self;
2231+
2232+
let kind = match kind {
2233+
ConditionalStatementKind::When => "WHEN",
2234+
ConditionalStatementKind::If => "IF",
2235+
ConditionalStatementKind::ElseIf => "ELSEIF",
2236+
};
2237+
2238+
write!(f, "{kind} {expr} THEN")?;
2239+
2240+
if !statements.is_empty() {
2241+
write!(f, " ")?;
2242+
format_statement_list(f, statements)?;
2243+
}
2244+
2245+
Ok(())
2246+
}
2247+
}
2248+
20732249
/// Represents an expression assignment within a variable `DECLARE` statement.
20742250
///
20752251
/// Examples:
@@ -2474,6 +2650,10 @@ pub enum Statement {
24742650
file_format: Option<FileFormat>,
24752651
source: Box<Query>,
24762652
},
2653+
/// A `CASE` statement.
2654+
Case(CaseStatement),
2655+
/// An `IF` statement.
2656+
If(IfStatement),
24772657
/// ```sql
24782658
/// CALL <function>
24792659
/// ```
@@ -3805,6 +3985,12 @@ impl fmt::Display for Statement {
38053985
}
38063986
Ok(())
38073987
}
3988+
Statement::Case(stmt) => {
3989+
write!(f, "{stmt}")
3990+
}
3991+
Statement::If(stmt) => {
3992+
write!(f, "{stmt}")
3993+
}
38083994
Statement::AttachDatabase {
38093995
schema_name,
38103996
database_file_name,
@@ -4857,18 +5043,14 @@ impl fmt::Display for Statement {
48575043
write!(f, " {}", display_comma_separated(modes))?;
48585044
}
48595045
if !statements.is_empty() {
4860-
write!(f, " {}", display_separated(statements, "; "))?;
4861-
// We manually insert semicolon for the last statement,
4862-
// since display_separated doesn't handle that case.
4863-
write!(f, ";")?;
5046+
write!(f, " ")?;
5047+
format_statement_list(f, statements)?;
48645048
}
48655049
if let Some(exception_statements) = exception_statements {
48665050
write!(f, " EXCEPTION WHEN ERROR THEN")?;
48675051
if !exception_statements.is_empty() {
4868-
write!(f, " {}", display_separated(exception_statements, "; "))?;
4869-
// We manually insert semicolon for the last statement,
4870-
// since display_separated doesn't handle that case.
4871-
write!(f, ";")?;
5052+
write!(f, " ")?;
5053+
format_statement_list(f, exception_statements)?;
48725054
}
48735055
}
48745056
if *has_end_keyword {

src/ast/spans.rs

+63-13
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,20 @@ use crate::tokenizer::Span;
2222

2323
use super::{
2424
dcl::SecondaryRoles, AccessExpr, AlterColumnOperation, AlterIndexOperation,
25-
AlterTableOperation, Array, Assignment, AssignmentTarget, CloseCursor, ClusteredIndex,
26-
ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy, ConstraintCharacteristics,
27-
CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte, Delete, DoUpdate,
28-
ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, Function,
29-
FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments,
30-
GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join,
31-
JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, MatchRecognizePattern,
32-
Measure, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict,
33-
OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource,
34-
ProjectionSelect, Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement,
35-
ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript,
36-
SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject,
37-
TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
25+
AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement, CloseCursor,
26+
ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements,
27+
ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
28+
CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
29+
ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr,
30+
FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound,
31+
IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint,
32+
JoinOperator, JsonPath, JsonPathElem, LateralView, MatchRecognizePattern, Measure,
33+
NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction,
34+
OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect,
35+
Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
36+
SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
37+
TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
38+
TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
3839
WildcardAdditionalOptions, With, WithFill,
3940
};
4041

@@ -323,6 +324,8 @@ impl Spanned for Statement {
323324
file_format: _,
324325
source,
325326
} => source.span(),
327+
Statement::Case(stmt) => stmt.span(),
328+
Statement::If(stmt) => stmt.span(),
326329
Statement::Call(function) => function.span(),
327330
Statement::Copy {
328331
source,
@@ -728,6 +731,53 @@ impl Spanned for CreateIndex {
728731
}
729732
}
730733

734+
impl Spanned for CaseStatement {
735+
fn span(&self) -> Span {
736+
let CaseStatement {
737+
match_expr,
738+
when_blocks,
739+
else_block,
740+
has_end_case: _,
741+
} = self;
742+
743+
union_spans(
744+
match_expr
745+
.iter()
746+
.map(|e| e.span())
747+
.chain(when_blocks.iter().map(|b| b.span()))
748+
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
749+
)
750+
}
751+
}
752+
753+
impl Spanned for IfStatement {
754+
fn span(&self) -> Span {
755+
let IfStatement {
756+
if_block,
757+
elseif_blocks,
758+
else_block,
759+
} = self;
760+
761+
union_spans(
762+
iter::once(if_block.span())
763+
.chain(elseif_blocks.iter().map(|b| b.span()))
764+
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
765+
)
766+
}
767+
}
768+
769+
impl Spanned for ConditionalStatements {
770+
fn span(&self) -> Span {
771+
let ConditionalStatements {
772+
condition,
773+
statements,
774+
kind: _,
775+
} = self;
776+
777+
union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span())))
778+
}
779+
}
780+
731781
/// # partial span
732782
///
733783
/// Missing spans:

src/keywords.rs

+1
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ define_keywords!(
294294
ELEMENT,
295295
ELEMENTS,
296296
ELSE,
297+
ELSEIF,
297298
EMPTY,
298299
ENABLE,
299300
ENABLE_SCHEMA_EVOLUTION,

0 commit comments

Comments
 (0)