Skip to content

Commit a1395e0

Browse files
devanbenzJichaoS
authored andcommitted
feat: add DECLARE parsing for mssql (apache#1235)
1 parent 7a791f3 commit a1395e0

File tree

3 files changed

+160
-1
lines changed

3 files changed

+160
-1
lines changed

src/ast/mod.rs

+11
Original file line numberDiff line numberDiff line change
@@ -1566,6 +1566,14 @@ pub enum DeclareAssignment {
15661566
/// DECLARE c1 CURSOR FOR res
15671567
/// ```
15681568
For(Box<Expr>),
1569+
1570+
/// Expression via the `=` syntax.
1571+
///
1572+
/// Example:
1573+
/// ```sql
1574+
/// DECLARE @variable AS INT = 100
1575+
/// ```
1576+
MsSqlAssignment(Box<Expr>),
15691577
}
15701578

15711579
impl fmt::Display for DeclareAssignment {
@@ -1580,6 +1588,9 @@ impl fmt::Display for DeclareAssignment {
15801588
DeclareAssignment::DuckAssignment(expr) => {
15811589
write!(f, ":= {expr}")
15821590
}
1591+
DeclareAssignment::MsSqlAssignment(expr) => {
1592+
write!(f, "= {expr}")
1593+
}
15831594
DeclareAssignment::For(expr) => {
15841595
write!(f, "FOR {expr}")
15851596
}

src/parser/mod.rs

+86
Original file line numberDiff line numberDiff line change
@@ -4277,6 +4277,9 @@ impl<'a> Parser<'a> {
42774277
if dialect_of!(self is SnowflakeDialect) {
42784278
return self.parse_snowflake_declare();
42794279
}
4280+
if dialect_of!(self is MsSqlDialect) {
4281+
return self.parse_mssql_declare();
4282+
}
42804283

42814284
let name = self.parse_identifier(false)?;
42824285

@@ -4490,6 +4493,69 @@ impl<'a> Parser<'a> {
44904493
Ok(Statement::Declare { stmts })
44914494
}
44924495

4496+
/// Parse a [MsSql] `DECLARE` statement.
4497+
///
4498+
/// Syntax:
4499+
/// ```text
4500+
/// DECLARE
4501+
// {
4502+
// { @local_variable [AS] data_type [ = value ] }
4503+
// | { @cursor_variable_name CURSOR }
4504+
// } [ ,...n ]
4505+
/// ```
4506+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/language-elements/declare-local-variable-transact-sql?view=sql-server-ver16
4507+
pub fn parse_mssql_declare(&mut self) -> Result<Statement, ParserError> {
4508+
let mut stmts = vec![];
4509+
4510+
loop {
4511+
let name = {
4512+
let ident = self.parse_identifier(false)?;
4513+
if !ident.value.starts_with('@') {
4514+
Err(ParserError::TokenizerError(
4515+
"Invalid MsSql variable declaration.".to_string(),
4516+
))
4517+
} else {
4518+
Ok(ident)
4519+
}
4520+
}?;
4521+
4522+
let (declare_type, data_type) = match self.peek_token().token {
4523+
Token::Word(w) => match w.keyword {
4524+
Keyword::CURSOR => {
4525+
self.next_token();
4526+
(Some(DeclareType::Cursor), None)
4527+
}
4528+
Keyword::AS => {
4529+
self.next_token();
4530+
(None, Some(self.parse_data_type()?))
4531+
}
4532+
_ => (None, Some(self.parse_data_type()?)),
4533+
},
4534+
_ => (None, Some(self.parse_data_type()?)),
4535+
};
4536+
4537+
let assignment = self.parse_mssql_variable_declaration_expression()?;
4538+
4539+
stmts.push(Declare {
4540+
names: vec![name],
4541+
data_type,
4542+
assignment,
4543+
declare_type,
4544+
binary: None,
4545+
sensitive: None,
4546+
scroll: None,
4547+
hold: None,
4548+
for_query: None,
4549+
});
4550+
4551+
if self.next_token() != Token::Comma {
4552+
break;
4553+
}
4554+
}
4555+
4556+
Ok(Statement::Declare { stmts })
4557+
}
4558+
44934559
/// Parses the assigned expression in a variable declaration.
44944560
///
44954561
/// Syntax:
@@ -4515,6 +4581,26 @@ impl<'a> Parser<'a> {
45154581
})
45164582
}
45174583

4584+
/// Parses the assigned expression in a variable declaration.
4585+
///
4586+
/// Syntax:
4587+
/// ```text
4588+
/// [ = <expression>]
4589+
/// ```
4590+
pub fn parse_mssql_variable_declaration_expression(
4591+
&mut self,
4592+
) -> Result<Option<DeclareAssignment>, ParserError> {
4593+
Ok(match self.peek_token().token {
4594+
Token::Eq => {
4595+
self.next_token(); // Skip `=`
4596+
Some(DeclareAssignment::MsSqlAssignment(Box::new(
4597+
self.parse_expr()?,
4598+
)))
4599+
}
4600+
_ => None,
4601+
})
4602+
}
4603+
45184604
// FETCH [ direction { FROM | IN } ] cursor INTO target;
45194605
pub fn parse_fetch_statement(&mut self) -> Result<Statement, ParserError> {
45204606
let direction = if self.parse_keyword(Keyword::NEXT) {

tests/sqlparser_mssql.rs

+63-1
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
1717
#[macro_use]
1818
mod test_utils;
19+
1920
use test_utils::*;
2021

22+
use sqlparser::ast::DataType::{Int, Text};
23+
use sqlparser::ast::DeclareAssignment::MsSqlAssignment;
24+
use sqlparser::ast::Value::SingleQuotedString;
2125
use sqlparser::ast::*;
2226
use sqlparser::dialect::{GenericDialect, MsSqlDialect};
23-
use sqlparser::parser::ParserError;
27+
use sqlparser::parser::{Parser, ParserError};
2428

2529
#[test]
2630
fn parse_mssql_identifiers() {
@@ -539,6 +543,64 @@ fn parse_substring_in_select() {
539543
}
540544
}
541545

546+
#[test]
547+
fn parse_mssql_declare() {
548+
let sql = "DECLARE @foo CURSOR, @bar INT, @baz AS TEXT = 'foobar';";
549+
let ast = Parser::parse_sql(&MsSqlDialect {}, sql).unwrap();
550+
551+
assert_eq!(
552+
vec![Statement::Declare {
553+
stmts: vec![
554+
Declare {
555+
names: vec![Ident {
556+
value: "@foo".to_string(),
557+
quote_style: None
558+
}],
559+
data_type: None,
560+
assignment: None,
561+
declare_type: Some(DeclareType::Cursor),
562+
binary: None,
563+
sensitive: None,
564+
scroll: None,
565+
hold: None,
566+
for_query: None
567+
},
568+
Declare {
569+
names: vec![Ident {
570+
value: "@bar".to_string(),
571+
quote_style: None
572+
}],
573+
data_type: Some(Int(None)),
574+
assignment: None,
575+
declare_type: None,
576+
binary: None,
577+
sensitive: None,
578+
scroll: None,
579+
hold: None,
580+
for_query: None
581+
},
582+
Declare {
583+
names: vec![Ident {
584+
value: "@baz".to_string(),
585+
quote_style: None
586+
}],
587+
data_type: Some(Text),
588+
assignment: Some(MsSqlAssignment(Box::new(Expr::Value(SingleQuotedString(
589+
"foobar".to_string()
590+
))))),
591+
declare_type: None,
592+
binary: None,
593+
sensitive: None,
594+
scroll: None,
595+
hold: None,
596+
for_query: None
597+
}
598+
]
599+
}],
600+
ast
601+
);
602+
}
603+
542604
fn ms() -> TestedDialects {
543605
TestedDialects {
544606
dialects: vec![Box::new(MsSqlDialect {})],

0 commit comments

Comments
 (0)