Skip to content

Commit 4d48fde

Browse files
romanbRoman Borschel
authored andcommitted
Add support for MSSQL IF/ELSE statements. (apache#1791)
Co-authored-by: Roman Borschel <[email protected]>
1 parent c7bdd76 commit 4d48fde

File tree

6 files changed

+525
-129
lines changed

6 files changed

+525
-129
lines changed

src/ast/mod.rs

+112-54
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ use serde::{Deserialize, Serialize};
3737
#[cfg(feature = "visitor")]
3838
use sqlparser_derive::{Visit, VisitMut};
3939

40-
use crate::tokenizer::Span;
40+
use crate::keywords::Keyword;
41+
use crate::tokenizer::{Span, Token};
4142

4243
pub use self::data_type::{
4344
ArrayElemTypeDef, BinaryLength, CharLengthUnits, CharacterLength, DataType, EnumMember,
@@ -2136,20 +2137,23 @@ pub enum Password {
21362137
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
21372138
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
21382139
pub struct CaseStatement {
2140+
/// The `CASE` token that starts the statement.
2141+
pub case_token: AttachedToken,
21392142
pub match_expr: Option<Expr>,
2140-
pub when_blocks: Vec<ConditionalStatements>,
2141-
pub else_block: Option<Vec<Statement>>,
2142-
/// TRUE if the statement ends with `END CASE` (vs `END`).
2143-
pub has_end_case: bool,
2143+
pub when_blocks: Vec<ConditionalStatementBlock>,
2144+
pub else_block: Option<ConditionalStatementBlock>,
2145+
/// The last token of the statement (`END` or `CASE`).
2146+
pub end_case_token: AttachedToken,
21442147
}
21452148

21462149
impl fmt::Display for CaseStatement {
21472150
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
21482151
let CaseStatement {
2152+
case_token: _,
21492153
match_expr,
21502154
when_blocks,
21512155
else_block,
2152-
has_end_case,
2156+
end_case_token: AttachedToken(end),
21532157
} = self;
21542158

21552159
write!(f, "CASE")?;
@@ -2163,13 +2167,15 @@ impl fmt::Display for CaseStatement {
21632167
}
21642168

21652169
if let Some(else_block) = else_block {
2166-
write!(f, " ELSE ")?;
2167-
format_statement_list(f, else_block)?;
2170+
write!(f, " {else_block}")?;
21682171
}
21692172

21702173
write!(f, " END")?;
2171-
if *has_end_case {
2172-
write!(f, " CASE")?;
2174+
2175+
if let Token::Word(w) = &end.token {
2176+
if w.keyword == Keyword::CASE {
2177+
write!(f, " CASE")?;
2178+
}
21732179
}
21742180

21752181
Ok(())
@@ -2178,7 +2184,7 @@ impl fmt::Display for CaseStatement {
21782184

21792185
/// An `IF` statement.
21802186
///
2181-
/// Examples:
2187+
/// Example (BigQuery or Snowflake):
21822188
/// ```sql
21832189
/// IF TRUE THEN
21842190
/// SELECT 1;
@@ -2189,16 +2195,22 @@ impl fmt::Display for CaseStatement {
21892195
/// SELECT 4;
21902196
/// END IF
21912197
/// ```
2192-
///
21932198
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
21942199
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
2200+
///
2201+
/// Example (MSSQL):
2202+
/// ```sql
2203+
/// IF 1=1 SELECT 1 ELSE SELECT 2
2204+
/// ```
2205+
/// [MSSQL](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/if-else-transact-sql?view=sql-server-ver16)
21952206
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
21962207
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
21972208
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
21982209
pub struct IfStatement {
2199-
pub if_block: ConditionalStatements,
2200-
pub elseif_blocks: Vec<ConditionalStatements>,
2201-
pub else_block: Option<Vec<Statement>>,
2210+
pub if_block: ConditionalStatementBlock,
2211+
pub elseif_blocks: Vec<ConditionalStatementBlock>,
2212+
pub else_block: Option<ConditionalStatementBlock>,
2213+
pub end_token: Option<AttachedToken>,
22022214
}
22032215

22042216
impl fmt::Display for IfStatement {
@@ -2207,82 +2219,128 @@ impl fmt::Display for IfStatement {
22072219
if_block,
22082220
elseif_blocks,
22092221
else_block,
2222+
end_token,
22102223
} = self;
22112224

22122225
write!(f, "{if_block}")?;
22132226

2214-
if !elseif_blocks.is_empty() {
2215-
write!(f, " {}", display_separated(elseif_blocks, " "))?;
2227+
for elseif_block in elseif_blocks {
2228+
write!(f, " {elseif_block}")?;
22162229
}
22172230

22182231
if let Some(else_block) = else_block {
2219-
write!(f, " ELSE ")?;
2220-
format_statement_list(f, else_block)?;
2232+
write!(f, " {else_block}")?;
22212233
}
22222234

2223-
write!(f, " END IF")?;
2235+
if let Some(AttachedToken(end_token)) = end_token {
2236+
write!(f, " END {end_token}")?;
2237+
}
22242238

22252239
Ok(())
22262240
}
22272241
}
22282242

2229-
/// Represents a type of [ConditionalStatements]
2230-
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2231-
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2232-
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2233-
pub enum ConditionalStatementKind {
2234-
/// `WHEN <condition> THEN <statements>`
2235-
When,
2236-
/// `IF <condition> THEN <statements>`
2237-
If,
2238-
/// `ELSEIF <condition> THEN <statements>`
2239-
ElseIf,
2240-
}
2241-
22422243
/// A block within a [Statement::Case] or [Statement::If]-like statement
22432244
///
2244-
/// Examples:
2245+
/// Example 1:
22452246
/// ```sql
22462247
/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
2248+
/// ```
22472249
///
2250+
/// Example 2:
2251+
/// ```sql
22482252
/// IF TRUE THEN SELECT 1; SELECT 2;
22492253
/// ```
2254+
///
2255+
/// Example 3:
2256+
/// ```sql
2257+
/// ELSE SELECT 1; SELECT 2;
2258+
/// ```
22502259
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
22512260
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22522261
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2253-
pub struct ConditionalStatements {
2254-
/// The condition expression.
2255-
pub condition: Expr,
2256-
/// Statement list of the `THEN` clause.
2257-
pub statements: Vec<Statement>,
2258-
pub kind: ConditionalStatementKind,
2262+
pub struct ConditionalStatementBlock {
2263+
pub start_token: AttachedToken,
2264+
pub condition: Option<Expr>,
2265+
pub then_token: Option<AttachedToken>,
2266+
pub conditional_statements: ConditionalStatements,
22592267
}
22602268

2261-
impl fmt::Display for ConditionalStatements {
2269+
impl ConditionalStatementBlock {
2270+
pub fn statements(&self) -> &Vec<Statement> {
2271+
self.conditional_statements.statements()
2272+
}
2273+
}
2274+
2275+
impl fmt::Display for ConditionalStatementBlock {
22622276
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2263-
let ConditionalStatements {
2264-
condition: expr,
2265-
statements,
2266-
kind,
2277+
let ConditionalStatementBlock {
2278+
start_token: AttachedToken(start_token),
2279+
condition,
2280+
then_token,
2281+
conditional_statements,
22672282
} = self;
22682283

2269-
let kind = match kind {
2270-
ConditionalStatementKind::When => "WHEN",
2271-
ConditionalStatementKind::If => "IF",
2272-
ConditionalStatementKind::ElseIf => "ELSEIF",
2273-
};
2284+
write!(f, "{start_token}")?;
2285+
2286+
if let Some(condition) = condition {
2287+
write!(f, " {condition}")?;
2288+
}
22742289

2275-
write!(f, "{kind} {expr} THEN")?;
2290+
if then_token.is_some() {
2291+
write!(f, " THEN")?;
2292+
}
22762293

2277-
if !statements.is_empty() {
2278-
write!(f, " ")?;
2279-
format_statement_list(f, statements)?;
2294+
if !conditional_statements.statements().is_empty() {
2295+
write!(f, " {conditional_statements}")?;
22802296
}
22812297

22822298
Ok(())
22832299
}
22842300
}
22852301

2302+
/// A list of statements in a [ConditionalStatementBlock].
2303+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2304+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2305+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2306+
pub enum ConditionalStatements {
2307+
/// SELECT 1; SELECT 2; SELECT 3; ...
2308+
Sequence { statements: Vec<Statement> },
2309+
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
2310+
BeginEnd {
2311+
begin_token: AttachedToken,
2312+
statements: Vec<Statement>,
2313+
end_token: AttachedToken,
2314+
},
2315+
}
2316+
2317+
impl ConditionalStatements {
2318+
pub fn statements(&self) -> &Vec<Statement> {
2319+
match self {
2320+
ConditionalStatements::Sequence { statements } => statements,
2321+
ConditionalStatements::BeginEnd { statements, .. } => statements,
2322+
}
2323+
}
2324+
}
2325+
2326+
impl fmt::Display for ConditionalStatements {
2327+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2328+
match self {
2329+
ConditionalStatements::Sequence { statements } => {
2330+
if !statements.is_empty() {
2331+
format_statement_list(f, statements)?;
2332+
}
2333+
Ok(())
2334+
}
2335+
ConditionalStatements::BeginEnd { statements, .. } => {
2336+
write!(f, "BEGIN ")?;
2337+
format_statement_list(f, statements)?;
2338+
write!(f, " END")
2339+
}
2340+
}
2341+
}
2342+
}
2343+
22862344
/// A `RAISE` statement.
22872345
///
22882346
/// Examples:

src/ast/spans.rs

+50-31
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,22 @@ use crate::tokenizer::Span;
2222

2323
use super::{
2424
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
25-
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement,
26-
CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements,
27-
ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
28-
CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
29-
ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr,
30-
FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound,
31-
IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint,
32-
JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure,
33-
NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction,
34-
OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect,
35-
Query, RaiseStatement, RaiseStatementValue, ReferentialAction, RenameSelectItem,
36-
ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption,
37-
Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint,
38-
TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use,
39-
Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
25+
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken,
26+
CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef,
27+
ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
28+
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
29+
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
30+
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
31+
FunctionArguments, GroupByExpr, HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate,
32+
InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
33+
LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
34+
Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
35+
PivotValueSource, ProjectionSelect, Query, RaiseStatement, RaiseStatementValue,
36+
ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
37+
SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
38+
TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
39+
TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
40+
WildcardAdditionalOptions, With, WithFill,
4041
};
4142

4243
/// Given an iterator of spans, return the [Span::union] of all spans.
@@ -739,19 +740,14 @@ impl Spanned for CreateIndex {
739740
impl Spanned for CaseStatement {
740741
fn span(&self) -> Span {
741742
let CaseStatement {
742-
match_expr,
743-
when_blocks,
744-
else_block,
745-
has_end_case: _,
743+
case_token: AttachedToken(start),
744+
match_expr: _,
745+
when_blocks: _,
746+
else_block: _,
747+
end_case_token: AttachedToken(end),
746748
} = self;
747749

748-
union_spans(
749-
match_expr
750-
.iter()
751-
.map(|e| e.span())
752-
.chain(when_blocks.iter().map(|b| b.span()))
753-
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
754-
)
750+
union_spans([start.span, end.span].into_iter())
755751
}
756752
}
757753

@@ -761,25 +757,48 @@ impl Spanned for IfStatement {
761757
if_block,
762758
elseif_blocks,
763759
else_block,
760+
end_token,
764761
} = self;
765762

766763
union_spans(
767764
iter::once(if_block.span())
768765
.chain(elseif_blocks.iter().map(|b| b.span()))
769-
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
766+
.chain(else_block.as_ref().map(|b| b.span()))
767+
.chain(end_token.as_ref().map(|AttachedToken(t)| t.span)),
770768
)
771769
}
772770
}
773771

774772
impl Spanned for ConditionalStatements {
775773
fn span(&self) -> Span {
776-
let ConditionalStatements {
774+
match self {
775+
ConditionalStatements::Sequence { statements } => {
776+
union_spans(statements.iter().map(|s| s.span()))
777+
}
778+
ConditionalStatements::BeginEnd {
779+
begin_token: AttachedToken(start),
780+
statements: _,
781+
end_token: AttachedToken(end),
782+
} => union_spans([start.span, end.span].into_iter()),
783+
}
784+
}
785+
}
786+
787+
impl Spanned for ConditionalStatementBlock {
788+
fn span(&self) -> Span {
789+
let ConditionalStatementBlock {
790+
start_token: AttachedToken(start_token),
777791
condition,
778-
statements,
779-
kind: _,
792+
then_token,
793+
conditional_statements,
780794
} = self;
781795

782-
union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span())))
796+
union_spans(
797+
iter::once(start_token.span)
798+
.chain(condition.as_ref().map(|c| c.span()))
799+
.chain(then_token.as_ref().map(|AttachedToken(t)| t.span))
800+
.chain(iter::once(conditional_statements.span())),
801+
)
783802
}
784803
}
785804

0 commit comments

Comments
 (0)