diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 2a687800a..2b7767ab1 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1532,6 +1532,14 @@ pub enum DeclareAssignment { /// DECLARE c1 CURSOR FOR res /// ``` For(Box), + + /// Expression via the `=` syntax. + /// + /// Example: + /// ```sql + /// DECLARE @variable AS INT = 100 + /// ``` + MsSqlAssignment(Box), } impl fmt::Display for DeclareAssignment { @@ -1546,6 +1554,9 @@ impl fmt::Display for DeclareAssignment { DeclareAssignment::DuckAssignment(expr) => { write!(f, ":= {expr}") } + DeclareAssignment::MsSqlAssignment(expr) => { + write!(f, "= {expr}") + } DeclareAssignment::For(expr) => { write!(f, "FOR {expr}") } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 7110402a4..9e7733013 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -4244,6 +4244,9 @@ impl<'a> Parser<'a> { if dialect_of!(self is SnowflakeDialect) { return self.parse_snowflake_declare(); } + if dialect_of!(self is MsSqlDialect) { + return self.parse_mssql_declare(); + } let name = self.parse_identifier(false)?; @@ -4457,6 +4460,69 @@ impl<'a> Parser<'a> { Ok(Statement::Declare { stmts }) } + /// Parse a [MsSql] `DECLARE` statement. + /// + /// Syntax: + /// ```text + /// DECLARE + // { + // { @local_variable [AS] data_type [ = value ] } + // | { @cursor_variable_name CURSOR } + // } [ ,...n ] + /// ``` + /// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/language-elements/declare-local-variable-transact-sql?view=sql-server-ver16 + pub fn parse_mssql_declare(&mut self) -> Result { + let mut stmts = vec![]; + + loop { + let name = { + let ident = self.parse_identifier(false)?; + if !ident.value.starts_with('@') { + Err(ParserError::TokenizerError( + "Invalid MsSql variable declaration.".to_string(), + )) + } else { + Ok(ident) + } + }?; + + let (declare_type, data_type) = match self.peek_token().token { + Token::Word(w) => match w.keyword { + Keyword::CURSOR => { + self.next_token(); + (Some(DeclareType::Cursor), None) + } + Keyword::AS => { + self.next_token(); + (None, Some(self.parse_data_type()?)) + } + _ => (None, Some(self.parse_data_type()?)), + }, + _ => (None, Some(self.parse_data_type()?)), + }; + + let assignment = self.parse_mssql_variable_declaration_expression()?; + + stmts.push(Declare { + names: vec![name], + data_type, + assignment, + declare_type, + binary: None, + sensitive: None, + scroll: None, + hold: None, + for_query: None, + }); + + if self.next_token() != Token::Comma { + break; + } + } + + Ok(Statement::Declare { stmts }) + } + /// Parses the assigned expression in a variable declaration. /// /// Syntax: @@ -4482,6 +4548,26 @@ impl<'a> Parser<'a> { }) } + /// Parses the assigned expression in a variable declaration. + /// + /// Syntax: + /// ```text + /// [ = ] + /// ``` + pub fn parse_mssql_variable_declaration_expression( + &mut self, + ) -> Result, ParserError> { + Ok(match self.peek_token().token { + Token::Eq => { + self.next_token(); // Skip `=` + Some(DeclareAssignment::MsSqlAssignment(Box::new( + self.parse_expr()?, + ))) + } + _ => None, + }) + } + // FETCH [ direction { FROM | IN } ] cursor INTO target; pub fn parse_fetch_statement(&mut self) -> Result { let direction = if self.parse_keyword(Keyword::NEXT) { diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 353fb85d5..baff72f10 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -16,11 +16,15 @@ #[macro_use] mod test_utils; + use test_utils::*; +use sqlparser::ast::DataType::{Int, Text}; +use sqlparser::ast::DeclareAssignment::MsSqlAssignment; +use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, MsSqlDialect}; -use sqlparser::parser::ParserError; +use sqlparser::parser::{Parser, ParserError}; #[test] fn parse_mssql_identifiers() { @@ -539,6 +543,64 @@ fn parse_substring_in_select() { } } +#[test] +fn parse_mssql_declare() { + let sql = "DECLARE @foo CURSOR, @bar INT, @baz AS TEXT = 'foobar';"; + let ast = Parser::parse_sql(&MsSqlDialect {}, sql).unwrap(); + + assert_eq!( + vec![Statement::Declare { + stmts: vec![ + Declare { + names: vec![Ident { + value: "@foo".to_string(), + quote_style: None + }], + data_type: None, + assignment: None, + declare_type: Some(DeclareType::Cursor), + binary: None, + sensitive: None, + scroll: None, + hold: None, + for_query: None + }, + Declare { + names: vec![Ident { + value: "@bar".to_string(), + quote_style: None + }], + data_type: Some(Int(None)), + assignment: None, + declare_type: None, + binary: None, + sensitive: None, + scroll: None, + hold: None, + for_query: None + }, + Declare { + names: vec![Ident { + value: "@baz".to_string(), + quote_style: None + }], + data_type: Some(Text), + assignment: Some(MsSqlAssignment(Box::new(Expr::Value(SingleQuotedString( + "foobar".to_string() + ))))), + declare_type: None, + binary: None, + sensitive: None, + scroll: None, + hold: None, + for_query: None + } + ] + }], + ast + ); +} + fn ms() -> TestedDialects { TestedDialects { dialects: vec![Box::new(MsSqlDialect {})],