Skip to content

Add support for MSSQL IF/ELSE statements. #1791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 6, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 112 additions & 54 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ use serde::{Deserialize, Serialize};
#[cfg(feature = "visitor")]
use sqlparser_derive::{Visit, VisitMut};

use crate::tokenizer::Span;
use crate::keywords::Keyword;
use crate::tokenizer::{Span, Token};

pub use self::data_type::{
ArrayElemTypeDef, BinaryLength, CharLengthUnits, CharacterLength, DataType, EnumMember,
Expand Down Expand Up @@ -2118,20 +2119,23 @@ pub enum Password {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CaseStatement {
/// The `CASE` token that starts the statement.
pub case_token: AttachedToken,
pub match_expr: Option<Expr>,
pub when_blocks: Vec<ConditionalStatements>,
pub else_block: Option<Vec<Statement>>,
/// TRUE if the statement ends with `END CASE` (vs `END`).
pub has_end_case: bool,
pub when_blocks: Vec<ConditionalStatementBlock>,
pub else_block: Option<ConditionalStatementBlock>,
/// The last token of the statement (`END` or `CASE`).
pub end_case_token: AttachedToken,
}

impl fmt::Display for CaseStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let CaseStatement {
case_token: _,
match_expr,
when_blocks,
else_block,
has_end_case,
end_case_token: AttachedToken(end),
} = self;

write!(f, "CASE")?;
Expand All @@ -2145,13 +2149,15 @@ impl fmt::Display for CaseStatement {
}

if let Some(else_block) = else_block {
write!(f, " ELSE ")?;
format_statement_list(f, else_block)?;
write!(f, " {else_block}")?;
}

write!(f, " END")?;
if *has_end_case {
write!(f, " CASE")?;

if let Token::Word(w) = &end.token {
if w.keyword == Keyword::CASE {
write!(f, " CASE")?;
}
}

Ok(())
Expand All @@ -2160,7 +2166,7 @@ impl fmt::Display for CaseStatement {

/// An `IF` statement.
///
/// Examples:
/// Example (BigQuery or Snowflake):
/// ```sql
/// IF TRUE THEN
/// SELECT 1;
Expand All @@ -2171,16 +2177,22 @@ impl fmt::Display for CaseStatement {
/// SELECT 4;
/// END IF
/// ```
///
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
///
/// Example (MSSQL):
/// ```sql
/// IF 1=1 SELECT 1 ELSE SELECT 2
/// ```
/// [MSSQL](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/if-else-transact-sql?view=sql-server-ver16)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct IfStatement {
pub if_block: ConditionalStatements,
pub elseif_blocks: Vec<ConditionalStatements>,
pub else_block: Option<Vec<Statement>>,
pub if_block: ConditionalStatementBlock,
pub elseif_blocks: Vec<ConditionalStatementBlock>,
pub else_block: Option<ConditionalStatementBlock>,
pub end_token: Option<AttachedToken>,
}

impl fmt::Display for IfStatement {
Expand All @@ -2189,82 +2201,128 @@ impl fmt::Display for IfStatement {
if_block,
elseif_blocks,
else_block,
end_token,
} = self;

write!(f, "{if_block}")?;

if !elseif_blocks.is_empty() {
write!(f, " {}", display_separated(elseif_blocks, " "))?;
for elseif_block in elseif_blocks {
write!(f, " {elseif_block}")?;
}

if let Some(else_block) = else_block {
write!(f, " ELSE ")?;
format_statement_list(f, else_block)?;
write!(f, " {else_block}")?;
}

write!(f, " END IF")?;
if let Some(AttachedToken(end_token)) = end_token {
write!(f, " END {end_token}")?;
}

Ok(())
}
}

/// Represents a type of [ConditionalStatements]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ConditionalStatementKind {
/// `WHEN <condition> THEN <statements>`
When,
/// `IF <condition> THEN <statements>`
If,
/// `ELSEIF <condition> THEN <statements>`
ElseIf,
}

/// A block within a [Statement::Case] or [Statement::If]-like statement
///
/// Examples:
/// Example 1:
/// ```sql
/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
/// ```
///
/// Example 2:
/// ```sql
/// IF TRUE THEN SELECT 1; SELECT 2;
/// ```
///
/// Example 3:
/// ```sql
/// ELSE SELECT 1; SELECT 2;
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ConditionalStatements {
/// The condition expression.
pub condition: Expr,
/// Statement list of the `THEN` clause.
pub statements: Vec<Statement>,
pub kind: ConditionalStatementKind,
pub struct ConditionalStatementBlock {
pub start_token: AttachedToken,
pub condition: Option<Expr>,
pub then_token: Option<AttachedToken>,
pub conditional_statements: ConditionalStatements,
}

impl fmt::Display for ConditionalStatements {
impl ConditionalStatementBlock {
pub fn statements(&self) -> &Vec<Statement> {
self.conditional_statements.statements()
}
}

impl fmt::Display for ConditionalStatementBlock {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let ConditionalStatements {
condition: expr,
statements,
kind,
let ConditionalStatementBlock {
start_token: AttachedToken(start_token),
condition,
then_token,
conditional_statements,
} = self;

let kind = match kind {
ConditionalStatementKind::When => "WHEN",
ConditionalStatementKind::If => "IF",
ConditionalStatementKind::ElseIf => "ELSEIF",
};
write!(f, "{start_token}")?;

if let Some(condition) = condition {
write!(f, " {condition}")?;
}

write!(f, "{kind} {expr} THEN")?;
if then_token.is_some() {
write!(f, " THEN")?;
}

if !statements.is_empty() {
write!(f, " ")?;
format_statement_list(f, statements)?;
if !conditional_statements.statements().is_empty() {
write!(f, " {conditional_statements}")?;
}

Ok(())
}
}

/// A list of statements in a [ConditionalStatementBlock].
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ConditionalStatements {
/// SELECT 1; SELECT 2; SELECT 3; ...
Sequence { statements: Vec<Statement> },
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
BeginEnd {
begin_token: AttachedToken,
statements: Vec<Statement>,
end_token: AttachedToken,
},
}

impl ConditionalStatements {
pub fn statements(&self) -> &Vec<Statement> {
match self {
ConditionalStatements::Sequence { statements } => statements,
ConditionalStatements::BeginEnd { statements, .. } => statements,
}
}
}

impl fmt::Display for ConditionalStatements {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ConditionalStatements::Sequence { statements } => {
if !statements.is_empty() {
format_statement_list(f, statements)?;
}
Ok(())
}
ConditionalStatements::BeginEnd { statements, .. } => {
write!(f, "BEGIN ")?;
format_statement_list(f, statements)?;
write!(f, " END")
}
}
}
}

/// A `RAISE` statement.
///
/// Examples:
Expand Down
79 changes: 48 additions & 31 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@ use crate::tokenizer::Span;

use super::{
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement,
CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements,
ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr,
FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound,
IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint,
JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure,
NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction,
OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect,
Query, RaiseStatement, RaiseStatementValue, ReferentialAction, RenameSelectItem,
ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption,
Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint,
TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use,
Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken,
CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef,
ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
FunctionArguments, GroupByExpr, HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate,
InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
PivotValueSource, ProjectionSelect, Query, RaiseStatement, RaiseStatementValue,
ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
WildcardAdditionalOptions, With, WithFill,
};

/// Given an iterator of spans, return the [Span::union] of all spans.
Expand Down Expand Up @@ -739,19 +740,12 @@ impl Spanned for CreateIndex {
impl Spanned for CaseStatement {
fn span(&self) -> Span {
let CaseStatement {
match_expr,
when_blocks,
else_block,
has_end_case: _,
case_token: AttachedToken(start),
end_case_token: AttachedToken(end),
..
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this and ConditionalStatement Spanned impl could we explicitly list the remaining fields so that the let match is exhaustive? So that we're forced to revisit this part of the code if new fields are added/modified

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I did that! Thanks again for the review.

} = self;

union_spans(
match_expr
.iter()
.map(|e| e.span())
.chain(when_blocks.iter().map(|b| b.span()))
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
)
union_spans([start.span, end.span].into_iter())
}
}

Expand All @@ -761,25 +755,48 @@ impl Spanned for IfStatement {
if_block,
elseif_blocks,
else_block,
end_token,
} = self;

union_spans(
iter::once(if_block.span())
.chain(elseif_blocks.iter().map(|b| b.span()))
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
.chain(else_block.as_ref().map(|b| b.span()))
.chain(end_token.as_ref().map(|AttachedToken(t)| t.span)),
)
}
}

impl Spanned for ConditionalStatements {
fn span(&self) -> Span {
let ConditionalStatements {
match self {
ConditionalStatements::Sequence { statements } => {
union_spans(statements.iter().map(|s| s.span()))
}
ConditionalStatements::BeginEnd {
begin_token: AttachedToken(start),
end_token: AttachedToken(end),
..
} => union_spans([start.span, end.span].into_iter()),
}
}
}

impl Spanned for ConditionalStatementBlock {
fn span(&self) -> Span {
let ConditionalStatementBlock {
start_token: AttachedToken(start_token),
condition,
statements,
kind: _,
then_token,
conditional_statements,
} = self;

union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span())))
union_spans(
iter::once(start_token.span)
.chain(condition.as_ref().map(|c| c.span()))
.chain(then_token.as_ref().map(|AttachedToken(t)| t.span))
.chain(iter::once(conditional_statements.span())),
)
}
}

Expand Down
Loading