From f3af75ffdb75e7e4653a378ef928a9bc3b55d0d9 Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Thu, 10 Apr 2025 13:25:35 -0400 Subject: [PATCH 01/16] Add basic `CREATE FUNCTION` support for SQL Server - in this dialect, functions can have statement(s) bodies like stored procedures (including `BEGIN`..`END`) - functions must end with `RETURN`, so a corresponding statement type is also introduced --- src/ast/ddl.rs | 6 ++ src/ast/mod.rs | 60 +++++++++++++++++ src/ast/spans.rs | 1 + src/parser/mod.rs | 83 +++++++++++++++++++++-- tests/sqlparser_hive.rs | 4 +- tests/sqlparser_mssql.rs | 139 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 285 insertions(+), 8 deletions(-) diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 000ab3a4f..458d5ff97 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -2272,6 +2272,12 @@ impl fmt::Display for CreateFunction { if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body { write!(f, " AS {function_body}")?; } + if let Some(CreateFunctionBody::MultiStatement(statements)) = &self.function_body { + write!(f, " AS")?; + write!(f, " BEGIN")?; + write!(f, " {}", display_separated(statements, "; "))?; + write!(f, " END")?; + } Ok(()) } } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index ab3be35c1..c4f8f00ab 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3614,6 +3614,7 @@ pub enum Statement { /// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction) /// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html) /// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement) + /// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql) CreateFunction(CreateFunction), /// CREATE TRIGGER /// @@ -4060,6 +4061,12 @@ pub enum Statement { /// /// See: Print(PrintStatement), + /// ```sql + /// RETURN [ expression ] + /// ``` + /// + /// See [ReturnStatement] + Return(ReturnStatement), } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -5752,6 +5759,7 @@ impl fmt::Display for Statement { Ok(()) } Statement::Print(s) => write!(f, "{s}"), + Statement::Return(r) => write!(f, "{r}"), Statement::List(command) => write!(f, "LIST {command}"), Statement::Remove(command) => write!(f, "REMOVE {command}"), } @@ -8354,6 +8362,7 @@ impl fmt::Display for FunctionDeterminismSpecifier { /// /// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11 /// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html +/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] @@ -8382,6 +8391,22 @@ pub enum CreateFunctionBody { /// /// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11 AsAfterOptions(Expr), + /// Function body with statements before the `RETURN` keyword. + /// + /// Example: + /// ```sql + /// CREATE FUNCTION my_scalar_udf(a INT, b INT) + /// RETURNS INT + /// AS + /// BEGIN + /// DECLARE c INT; + /// SET c = a + b; + /// RETURN c; + /// END + /// ``` + /// + /// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql + MultiStatement(Vec), /// Function body expression using the 'RETURN' keyword. /// /// Example: @@ -9230,6 +9255,41 @@ impl fmt::Display for PrintStatement { } } +/// Return (MsSql) +/// +/// for Functions: +/// RETURN scalar_expression +/// +/// See +/// +/// for Triggers: +/// RETURN +/// +/// See +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct ReturnStatement { + pub value: Option, +} + +impl fmt::Display for ReturnStatement { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.value { + Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr), + None => write!(f, "RETURN"), + } + } +} + +/// Variants of a `RETURN` statement +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum ReturnStatementValue { + Expr(Expr), +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/ast/spans.rs b/src/ast/spans.rs index a241fdf4d..7b2652fa4 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -520,6 +520,7 @@ impl Spanned for Statement { Statement::RenameTable { .. } => Span::empty(), Statement::RaisError { .. } => Span::empty(), Statement::Print { .. } => Span::empty(), + Statement::Return { .. } => Span::empty(), Statement::List(..) | Statement::Remove(..) => Span::empty(), } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index a9ddd1837..53744d70f 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -577,13 +577,7 @@ impl<'a> Parser<'a> { Keyword::GRANT => self.parse_grant(), Keyword::REVOKE => self.parse_revoke(), Keyword::START => self.parse_start_transaction(), - // `BEGIN` is a nonstandard but common alias for the - // standard `START TRANSACTION` statement. It is supported - // by at least PostgreSQL and MySQL. Keyword::BEGIN => self.parse_begin(), - // `END` is a nonstandard but common alias for the - // standard `COMMIT TRANSACTION` statement. It is supported - // by PostgreSQL. Keyword::END => self.parse_end(), Keyword::SAVEPOINT => self.parse_savepoint(), Keyword::RELEASE => self.parse_release(), @@ -618,6 +612,7 @@ impl<'a> Parser<'a> { // `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(), Keyword::PRINT => self.parse_print(), + Keyword::RETURN => self.parse_return(), _ => self.expected("an SQL statement", next_token), }, Token::LParen => { @@ -4880,6 +4875,8 @@ impl<'a> Parser<'a> { self.parse_create_macro(or_replace, temporary) } else if dialect_of!(self is BigQueryDialect) { self.parse_bigquery_create_function(or_replace, temporary) + } else if dialect_of!(self is MsSqlDialect) { + self.parse_mssql_create_function(or_replace, temporary) } else { self.prev_token(); self.expected("an object type after CREATE", self.peek_token()) @@ -5134,6 +5131,72 @@ impl<'a> Parser<'a> { })) } + /// Parse `CREATE FUNCTION` for [MsSql] + /// + /// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql + fn parse_mssql_create_function( + &mut self, + or_replace: bool, + temporary: bool, + ) -> Result { + let name = self.parse_object_name(false)?; + + let parse_function_param = + |parser: &mut Parser| -> Result { + let name = parser.parse_identifier()?; + let data_type = parser.parse_data_type()?; + Ok(OperateFunctionArg { + mode: None, + name: Some(name), + data_type, + default_expr: None, + }) + }; + self.expect_token(&Token::LParen)?; + let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?; + self.expect_token(&Token::RParen)?; + + let return_type = if self.parse_keyword(Keyword::RETURNS) { + Some(self.parse_data_type()?) + } else { + return parser_err!("Expected RETURNS keyword", self.peek_token().span.start); + }; + + self.expect_keyword_is(Keyword::AS)?; + self.expect_keyword_is(Keyword::BEGIN)?; + let mut result = self.parse_statements()?; + // note: `parse_statements` will consume the `END` token & produce a Commit statement... + if let Some(Statement::Commit { + chain, + end, + modifier, + }) = result.last() + { + if *chain == false && *end == true && *modifier == None { + result = result[..result.len() - 1].to_vec(); + } + } + let function_body = Some(CreateFunctionBody::MultiStatement(result)); + + Ok(Statement::CreateFunction(CreateFunction { + or_replace, + temporary, + if_not_exists: false, + name, + args: Some(args), + return_type, + function_body, + language: None, + determinism_specifier: None, + options: None, + remote_connection: None, + using: None, + behavior: None, + called_on_null: None, + parallel: None, + })) + } + fn parse_function_arg(&mut self) -> Result { let mode = if self.parse_keyword(Keyword::IN) { Some(ArgMode::In) @@ -15064,6 +15127,14 @@ impl<'a> Parser<'a> { })) } + /// Parse [Statement::Return] + fn parse_return(&mut self) -> Result { + let expr = self.parse_expr()?; + Ok(Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(expr)), + })) + } + /// Consume the parser and return its underlying token buffer pub fn into_tokens(self) -> Vec { self.tokens diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 2af93db7d..9b0430947 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -25,7 +25,7 @@ use sqlparser::ast::{ Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr, OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value, }; -use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect}; +use sqlparser::dialect::{AnsiDialect, GenericDialect, HiveDialect}; use sqlparser::parser::ParserError; use sqlparser::test_utils::*; @@ -423,7 +423,7 @@ fn parse_create_function() { } // Test error in dialect that doesn't support parsing CREATE FUNCTION - let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]); + let unsupported_dialects = TestedDialects::new(vec![Box::new(AnsiDialect {})]); assert_eq!( unsupported_dialects.parse_sql_statements(sql).unwrap_err(), diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 2786384b3..2b148e8d6 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -187,6 +187,145 @@ fn parse_mssql_create_procedure() { let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN UPDATE bar SET col = 'test'; SELECT [foo] FROM BAR WHERE [FOO] > 10 END"); } +#[test] +fn parse_create_function() { + let return_expression_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN RETURN 1 END"; + assert_eq!( + ms().verified_stmt(return_expression_function), + sqlparser::ast::Statement::CreateFunction(CreateFunction { + or_replace: false, + temporary: false, + if_not_exists: false, + name: ObjectName::from(vec![Ident { + value: "some_scalar_udf".into(), + quote_style: None, + span: Span::empty(), + }]), + args: Some(vec![ + OperateFunctionArg { + mode: None, + name: Some(Ident { + value: "@foo".into(), + quote_style: None, + span: Span::empty(), + }), + data_type: DataType::Int(None), + default_expr: None, + }, + OperateFunctionArg { + mode: None, + name: Some(Ident { + value: "@bar".into(), + quote_style: None, + span: Span::empty(), + }), + data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { + length: 256, + unit: None + })), + default_expr: None, + }, + ]), + return_type: Some(DataType::Int(None)), + function_body: Some(CreateFunctionBody::MultiStatement(vec![ + Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Value((number("1")).with_empty_span()))), + }), + ])), + behavior: None, + called_on_null: None, + parallel: None, + using: None, + language: None, + determinism_specifier: None, + options: None, + remote_connection: None, + }), + ); + + let multi_statement_function = "\ + CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \ + RETURNS INT \ + AS \ + BEGIN \ + SET @foo = @foo + 1; \ + RETURN @foo \ + END\ + "; + assert_eq!( + ms().verified_stmt(multi_statement_function), + sqlparser::ast::Statement::CreateFunction(CreateFunction { + or_replace: false, + temporary: false, + if_not_exists: false, + name: ObjectName::from(vec![Ident { + value: "some_scalar_udf".into(), + quote_style: None, + span: Span::empty(), + }]), + args: Some(vec![ + OperateFunctionArg { + mode: None, + name: Some(Ident { + value: "@foo".into(), + quote_style: None, + span: Span::empty(), + }), + data_type: DataType::Int(None), + default_expr: None, + }, + OperateFunctionArg { + mode: None, + name: Some(Ident { + value: "@bar".into(), + quote_style: None, + span: Span::empty(), + }), + data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { + length: 256, + unit: None + })), + default_expr: None, + }, + ]), + return_type: Some(DataType::Int(None)), + function_body: Some(CreateFunctionBody::MultiStatement(vec![ + Statement::Set(Set::SingleAssignment { + scope: None, + hivevar: false, + variable: ObjectName::from(vec!["@foo".into()]), + values: vec![sqlparser::ast::Expr::BinaryOp { + left: Box::new(sqlparser::ast::Expr::Identifier(Ident { + value: "@foo".to_string(), + quote_style: None, + span: Span::empty(), + })), + op: sqlparser::ast::BinaryOperator::Plus, + right: Box::new(Expr::Value( + (Value::Number("1".into(), false)).with_empty_span() + )), + }], + }), + Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident { + value: "@foo".into(), + quote_style: None, + span: Span::empty(), + }))), + }), + ])), + behavior: None, + called_on_null: None, + parallel: None, + using: None, + language: None, + determinism_specifier: None, + options: None, + remote_connection: None, + }), + ); +} + #[test] fn parse_mssql_apply_join() { let _ = ms_and_generic().verified_only_select( From 02ecfe53ac20dd5ea6c14a302105c6793f3d0f8b Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Thu, 10 Apr 2025 15:46:45 -0400 Subject: [PATCH 02/16] Add `OR ALTER` support for `CREATE FUNCTION` --- src/ast/ddl.rs | 7 ++- src/parser/mod.rs | 10 ++++- tests/sqlparser_bigquery.rs | 1 + tests/sqlparser_mssql.rs | 88 +++++++++++++++++++++++++++++++++++++ tests/sqlparser_postgres.rs | 2 + 5 files changed, 105 insertions(+), 3 deletions(-) diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 458d5ff97..dcb6b4d67 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -2157,6 +2157,10 @@ impl fmt::Display for ClusteredBy { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct CreateFunction { + /// True if this is a `CREATE OR ALTER FUNCTION` statement + /// + /// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16#or-alter) + pub or_alter: bool, pub or_replace: bool, pub temporary: bool, pub if_not_exists: bool, @@ -2219,9 +2223,10 @@ impl fmt::Display for CreateFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "CREATE {or_replace}{temp}FUNCTION {if_not_exists}{name}", + "CREATE {or_alter}{or_replace}{temp}FUNCTION {if_not_exists}{name}", name = self.name, temp = if self.temporary { "TEMPORARY " } else { "" }, + or_alter = if self.or_alter { "OR ALTER " } else { "" }, or_replace = if self.or_replace { "OR REPLACE " } else { "" }, if_not_exists = if self.if_not_exists { "IF NOT EXISTS " diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 53744d70f..62e4f6f59 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -4555,7 +4555,7 @@ impl<'a> Parser<'a> { } else if self.parse_keyword(Keyword::EXTERNAL) { self.parse_create_external_table(or_replace) } else if self.parse_keyword(Keyword::FUNCTION) { - self.parse_create_function(or_replace, temporary) + self.parse_create_function(or_alter, or_replace, temporary) } else if self.parse_keyword(Keyword::TRIGGER) { self.parse_create_trigger(or_replace, false) } else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) { @@ -4864,6 +4864,7 @@ impl<'a> Parser<'a> { pub fn parse_create_function( &mut self, + or_alter: bool, or_replace: bool, temporary: bool, ) -> Result { @@ -4876,7 +4877,7 @@ impl<'a> Parser<'a> { } else if dialect_of!(self is BigQueryDialect) { self.parse_bigquery_create_function(or_replace, temporary) } else if dialect_of!(self is MsSqlDialect) { - self.parse_mssql_create_function(or_replace, temporary) + self.parse_mssql_create_function(or_alter, or_replace, temporary) } else { self.prev_token(); self.expected("an object type after CREATE", self.peek_token()) @@ -4991,6 +4992,7 @@ impl<'a> Parser<'a> { } Ok(Statement::CreateFunction(CreateFunction { + or_alter: false, or_replace, temporary, name, @@ -5024,6 +5026,7 @@ impl<'a> Parser<'a> { let using = self.parse_optional_create_function_using()?; Ok(Statement::CreateFunction(CreateFunction { + or_alter: false, or_replace, temporary, name, @@ -5113,6 +5116,7 @@ impl<'a> Parser<'a> { }; Ok(Statement::CreateFunction(CreateFunction { + or_alter: false, or_replace, temporary, if_not_exists, @@ -5136,6 +5140,7 @@ impl<'a> Parser<'a> { /// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql fn parse_mssql_create_function( &mut self, + or_alter: bool, or_replace: bool, temporary: bool, ) -> Result { @@ -5179,6 +5184,7 @@ impl<'a> Parser<'a> { let function_body = Some(CreateFunctionBody::MultiStatement(result)); Ok(Statement::CreateFunction(CreateFunction { + or_alter, or_replace, temporary, if_not_exists: false, diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 5eb30d15c..416d2e435 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -2134,6 +2134,7 @@ fn test_bigquery_create_function() { assert_eq!( stmt, Statement::CreateFunction(CreateFunction { + or_alter: false, or_replace: true, temporary: true, if_not_exists: false, diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 2b148e8d6..01a858c40 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -193,6 +193,7 @@ fn parse_create_function() { assert_eq!( ms().verified_stmt(return_expression_function), sqlparser::ast::Statement::CreateFunction(CreateFunction { + or_alter: false, or_replace: false, temporary: false, if_not_exists: false, @@ -255,6 +256,93 @@ fn parse_create_function() { assert_eq!( ms().verified_stmt(multi_statement_function), sqlparser::ast::Statement::CreateFunction(CreateFunction { + or_alter: false, + or_replace: false, + temporary: false, + if_not_exists: false, + name: ObjectName::from(vec![Ident { + value: "some_scalar_udf".into(), + quote_style: None, + span: Span::empty(), + }]), + args: Some(vec![ + OperateFunctionArg { + mode: None, + name: Some(Ident { + value: "@foo".into(), + quote_style: None, + span: Span::empty(), + }), + data_type: DataType::Int(None), + default_expr: None, + }, + OperateFunctionArg { + mode: None, + name: Some(Ident { + value: "@bar".into(), + quote_style: None, + span: Span::empty(), + }), + data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { + length: 256, + unit: None + })), + default_expr: None, + }, + ]), + return_type: Some(DataType::Int(None)), + function_body: Some(CreateFunctionBody::MultiStatement(vec![ + Statement::Set(Set::SingleAssignment { + scope: None, + hivevar: false, + variable: ObjectName::from(vec!["@foo".into()]), + values: vec![sqlparser::ast::Expr::BinaryOp { + left: Box::new(sqlparser::ast::Expr::Identifier(Ident { + value: "@foo".to_string(), + quote_style: None, + span: Span::empty(), + })), + op: sqlparser::ast::BinaryOperator::Plus, + right: Box::new(Expr::Value( + (Value::Number("1".into(), false)).with_empty_span() + )), + }], + }), + Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident { + value: "@foo".into(), + quote_style: None, + span: Span::empty(), + }))), + }), + ])), + behavior: None, + called_on_null: None, + parallel: None, + using: None, + language: None, + determinism_specifier: None, + options: None, + remote_connection: None, + }), + ); +} + +#[test] +fn parse_mssql_create_function() { + let create_or_alter_function = "\ + CREATE OR ALTER FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \ + RETURNS INT \ + AS \ + BEGIN \ + SET @foo = @foo + 1; \ + RETURN @foo \ + END\ + "; + assert_eq!( + ms().verified_stmt(create_or_alter_function), + sqlparser::ast::Statement::CreateFunction(CreateFunction { + or_alter: true, or_replace: false, temporary: false, if_not_exists: false, diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 098d4b1c4..27fc7fa17 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -4104,6 +4104,7 @@ fn parse_create_function() { assert_eq!( pg_and_generic().verified_stmt(sql), Statement::CreateFunction(CreateFunction { + or_alter: false, or_replace: false, temporary: false, name: ObjectName::from(vec![Ident::new("add")]), @@ -5485,6 +5486,7 @@ fn parse_trigger_related_functions() { assert_eq!( create_function, Statement::CreateFunction(CreateFunction { + or_alter: false, or_replace: false, temporary: false, if_not_exists: false, From dff760ffb9a58e4d39e72107675c3db0a77e392d Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Mon, 14 Apr 2025 15:41:41 -0400 Subject: [PATCH 03/16] Add test for conditionals in multi statement functions --- tests/sqlparser_mssql.rs | 86 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 01a858c40..65450977f 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -23,7 +23,8 @@ mod test_utils; use helpers::attached_token::AttachedToken; -use sqlparser::tokenizer::{Location, Span}; +use sqlparser::keywords::Keyword; +use sqlparser::tokenizer::{Location, Span, TokenWithSpan}; use test_utils::*; use sqlparser::ast::DataType::{Int, Text, Varbinary}; @@ -326,6 +327,89 @@ fn parse_create_function() { remote_connection: None, }), ); + + let create_function_with_conditional = r#" + CREATE FUNCTION some_scalar_udf() + RETURNS INT + AS + BEGIN + IF 1=2 + BEGIN + RETURN 1; + END + + RETURN 0; + END + "#; + let create_stmt = ms().one_statement_parses_to(create_function_with_conditional, ""); + assert_eq!( + create_stmt, + Statement::CreateFunction(CreateFunction { + or_alter: false, + or_replace: false, + temporary: false, + if_not_exists: false, + name: ObjectName::from(vec![Ident { + value: "some_scalar_udf".into(), + quote_style: None, + span: Span::empty(), + }]), + args: Some(vec![]), + return_type: Some(DataType::Int(None)), + function_body: Some(CreateFunctionBody::MultiStatement(vec![ + Statement::If(IfStatement { + if_block: ConditionalStatementBlock { + start_token: AttachedToken(TokenWithSpan::wrap( + sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word { + value: "IF".to_string(), + quote_style: None, + keyword: Keyword::IF + }) + )), + condition: Some(Expr::BinaryOp { + left: Box::new(Expr::Value( + Value::Number("1".to_string(), false).with_empty_span() + )), + op: sqlparser::ast::BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("2".to_string(), false).with_empty_span())), + }), + then_token: None, + conditional_statements: ConditionalStatements::BeginEnd { + begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word { + value: "BEGIN".to_string(), + quote_style: None, + keyword: Keyword::BEGIN + }))), + statements: vec![Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Value((number("1")).with_empty_span()))), + })], + end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word { + value: "END".to_string(), + quote_style: None, + keyword: Keyword::END + }))), + }, + }, + elseif_blocks: vec![], + else_block: None, + end_token: None, + }), + Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Value( + (number("0")).with_empty_span() + ))), + }), + ])), + behavior: None, + called_on_null: None, + parallel: None, + using: None, + language: None, + determinism_specifier: None, + options: None, + remote_connection: None, + }) + ); } #[test] From 7c3e79b6dd3ec2cbd90242fe0aabdb641159dcc2 Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Tue, 15 Apr 2025 15:44:08 -0400 Subject: [PATCH 04/16] Refactor `BeginEndStatements` into a reusable struct, then use for functions - this lets us disacard the former (unfortunate, bespoke) multi statement parsing because we can just use `parse_statement_list` - however, `parse_statement_list` also needed a small change to allow subsequent statements to come after the final `END` --- src/ast/ddl.rs | 6 +- src/ast/mod.rs | 42 +++++-- src/ast/spans.rs | 8 +- src/dialect/mssql.rs | 12 +- src/parser/mod.rs | 36 +++--- tests/sqlparser_mssql.rs | 258 +++++++++++++++++++++++++-------------- 6 files changed, 226 insertions(+), 136 deletions(-) diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index dcb6b4d67..757c5a1d2 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -2277,11 +2277,9 @@ impl fmt::Display for CreateFunction { if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body { write!(f, " AS {function_body}")?; } - if let Some(CreateFunctionBody::MultiStatement(statements)) = &self.function_body { + if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body { write!(f, " AS")?; - write!(f, " BEGIN")?; - write!(f, " {}", display_separated(statements, "; "))?; - write!(f, " END")?; + write!(f, " {}", bes)?; } Ok(()) } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index c4f8f00ab..7454e8ded 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2292,18 +2292,14 @@ pub enum ConditionalStatements { /// SELECT 1; SELECT 2; SELECT 3; ... Sequence { statements: Vec }, /// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END - BeginEnd { - begin_token: AttachedToken, - statements: Vec, - end_token: AttachedToken, - }, + BeginEnd(BeginEndStatements), } impl ConditionalStatements { pub fn statements(&self) -> &Vec { match self { ConditionalStatements::Sequence { statements } => statements, - ConditionalStatements::BeginEnd { statements, .. } => statements, + ConditionalStatements::BeginEnd(bes) => &bes.statements, } } } @@ -2317,12 +2313,34 @@ impl fmt::Display for ConditionalStatements { } Ok(()) } - ConditionalStatements::BeginEnd { statements, .. } => { - write!(f, "BEGIN ")?; - format_statement_list(f, statements)?; - write!(f, " END") - } + ConditionalStatements::BeginEnd(bes) => write!(f, "{}", bes), + } + } +} + +/// A shared representation of `BEGIN`, multiple statements, and `END` tokens. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct BeginEndStatements { + pub begin_token: AttachedToken, + pub statements: Vec, + pub end_token: AttachedToken, +} + +impl fmt::Display for BeginEndStatements { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let BeginEndStatements { + begin_token: AttachedToken(begin_token), + statements, + end_token: AttachedToken(end_token), + } = self; + + write!(f, "{begin_token} ")?; + if !statements.is_empty() { + format_statement_list(f, statements)?; } + write!(f, " {end_token}") } } @@ -8406,7 +8424,7 @@ pub enum CreateFunctionBody { /// ``` /// /// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql - MultiStatement(Vec), + AsBeginEnd(BeginEndStatements), /// Function body expression using the 'RETURN' keyword. /// /// Example: diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 7b2652fa4..88943c0dc 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -779,11 +779,9 @@ impl Spanned for ConditionalStatements { ConditionalStatements::Sequence { statements } => { union_spans(statements.iter().map(|s| s.span())) } - ConditionalStatements::BeginEnd { - begin_token: AttachedToken(start), - statements: _, - end_token: AttachedToken(end), - } => union_spans([start.span, end.span].into_iter()), + ConditionalStatements::BeginEnd(bes) => { + union_spans([bes.begin_token.0.span, bes.end_token.0.span].into_iter()) + } } } } diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index d86d68a20..1f3e353db 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -16,7 +16,9 @@ // under the License. use crate::ast::helpers::attached_token::AttachedToken; -use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement}; +use crate::ast::{ + BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement, +}; use crate::dialect::Dialect; use crate::keywords::{self, Keyword}; use crate::parser::{Parser, ParserError}; @@ -149,11 +151,11 @@ impl MsSqlDialect { start_token: AttachedToken(if_token), condition: Some(condition), then_token: None, - conditional_statements: ConditionalStatements::BeginEnd { + conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements { begin_token: AttachedToken(begin_token), statements, end_token: AttachedToken(end_token), - }, + }), } } else { let stmt = parser.parse_statement()?; @@ -182,11 +184,11 @@ impl MsSqlDialect { start_token: AttachedToken(else_token), condition: None, then_token: None, - conditional_statements: ConditionalStatements::BeginEnd { + conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements { begin_token: AttachedToken(begin_token), statements, end_token: AttachedToken(end_token), - }, + }), }); } else { let stmt = parser.parse_statement()?; diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 62e4f6f59..d3a25a368 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -4453,9 +4453,17 @@ impl<'a> Parser<'a> { break; } } - values.push(self.parse_statement()?); - self.expect_token(&Token::SemiColon)?; + + let semi_colon_expected = match values.last() { + Some(Statement::If(if_statement)) => if_statement.end_token.is_some(), + Some(_) => true, + None => false, + }; + + if semi_colon_expected { + self.expect_token(&Token::SemiColon)?; + } } Ok(values) } @@ -5168,20 +5176,16 @@ impl<'a> Parser<'a> { }; self.expect_keyword_is(Keyword::AS)?; - self.expect_keyword_is(Keyword::BEGIN)?; - let mut result = self.parse_statements()?; - // note: `parse_statements` will consume the `END` token & produce a Commit statement... - if let Some(Statement::Commit { - chain, - end, - modifier, - }) = result.last() - { - if *chain == false && *end == true && *modifier == None { - result = result[..result.len() - 1].to_vec(); - } - } - let function_body = Some(CreateFunctionBody::MultiStatement(result)); + + let begin_token = self.expect_keyword(Keyword::BEGIN)?; + let statements = self.parse_statement_list(&[Keyword::END])?; + let end_token = self.expect_keyword(Keyword::END)?; + + let function_body = Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { + begin_token: AttachedToken(begin_token), + statements, + end_token: AttachedToken(end_token), + })); Ok(Statement::CreateFunction(CreateFunction { or_alter, diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 65450977f..a09a830d2 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -190,7 +190,7 @@ fn parse_mssql_create_procedure() { #[test] fn parse_create_function() { - let return_expression_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN RETURN 1 END"; + let return_expression_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN RETURN 1; END"; assert_eq!( ms().verified_stmt(return_expression_function), sqlparser::ast::Statement::CreateFunction(CreateFunction { @@ -229,11 +229,27 @@ fn parse_create_function() { }, ]), return_type: Some(DataType::Int(None)), - function_body: Some(CreateFunctionBody::MultiStatement(vec![ - Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Value((number("1")).with_empty_span()))), - }), - ])), + function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { + begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "BEGIN".to_string(), + quote_style: None, + keyword: Keyword::BEGIN + } + ))), + statements: vec![Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Value( + (number("1")).with_empty_span() + ))), + }),], + end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "END".to_string(), + quote_style: None, + keyword: Keyword::END + } + ))), + })), behavior: None, called_on_null: None, parallel: None, @@ -251,7 +267,7 @@ fn parse_create_function() { AS \ BEGIN \ SET @foo = @foo + 1; \ - RETURN @foo \ + RETURN @foo; \ END\ "; assert_eq!( @@ -292,31 +308,45 @@ fn parse_create_function() { }, ]), return_type: Some(DataType::Int(None)), - function_body: Some(CreateFunctionBody::MultiStatement(vec![ - Statement::Set(Set::SingleAssignment { - scope: None, - hivevar: false, - variable: ObjectName::from(vec!["@foo".into()]), - values: vec![sqlparser::ast::Expr::BinaryOp { - left: Box::new(sqlparser::ast::Expr::Identifier(Ident { - value: "@foo".to_string(), + function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { + begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "BEGIN".to_string(), + quote_style: None, + keyword: Keyword::BEGIN + } + ))), + statements: vec![ + Statement::Set(Set::SingleAssignment { + scope: None, + hivevar: false, + variable: ObjectName::from(vec!["@foo".into()]), + values: vec![sqlparser::ast::Expr::BinaryOp { + left: Box::new(sqlparser::ast::Expr::Identifier(Ident { + value: "@foo".to_string(), + quote_style: None, + span: Span::empty(), + })), + op: sqlparser::ast::BinaryOperator::Plus, + right: Box::new(Expr::Value(number("1").with_empty_span())), + }], + }), + Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident { + value: "@foo".into(), quote_style: None, span: Span::empty(), - })), - op: sqlparser::ast::BinaryOperator::Plus, - right: Box::new(Expr::Value( - (Value::Number("1".into(), false)).with_empty_span() - )), - }], - }), - Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident { - value: "@foo".into(), + }))), + }), + ], + end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "END".to_string(), quote_style: None, - span: Span::empty(), - }))), - }), - ])), + keyword: Keyword::END + } + ))), + })), behavior: None, called_on_null: None, parallel: None, @@ -356,50 +386,76 @@ fn parse_create_function() { }]), args: Some(vec![]), return_type: Some(DataType::Int(None)), - function_body: Some(CreateFunctionBody::MultiStatement(vec![ - Statement::If(IfStatement { - if_block: ConditionalStatementBlock { - start_token: AttachedToken(TokenWithSpan::wrap( - sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word { - value: "IF".to_string(), - quote_style: None, - keyword: Keyword::IF - }) - )), - condition: Some(Expr::BinaryOp { - left: Box::new(Expr::Value( - Value::Number("1".to_string(), false).with_empty_span() + function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { + begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "BEGIN".to_string(), + quote_style: None, + keyword: Keyword::BEGIN + } + ))), + statements: vec![ + Statement::If(IfStatement { + if_block: ConditionalStatementBlock { + start_token: AttachedToken(TokenWithSpan::wrap( + sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word { + value: "IF".to_string(), + quote_style: None, + keyword: Keyword::IF + }) )), - op: sqlparser::ast::BinaryOperator::Eq, - right: Box::new(Expr::Value(Value::Number("2".to_string(), false).with_empty_span())), - }), - then_token: None, - conditional_statements: ConditionalStatements::BeginEnd { - begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word { - value: "BEGIN".to_string(), - quote_style: None, - keyword: Keyword::BEGIN - }))), - statements: vec![Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Value((number("1")).with_empty_span()))), - })], - end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word { - value: "END".to_string(), - quote_style: None, - keyword: Keyword::END - }))), + condition: Some(Expr::BinaryOp { + left: Box::new(Expr::Value(number("1").with_empty_span())), + op: sqlparser::ast::BinaryOperator::Eq, + right: Box::new(Expr::Value(number("2").with_empty_span())), + }), + then_token: None, + conditional_statements: ConditionalStatements::BeginEnd( + BeginEndStatements { + begin_token: AttachedToken(TokenWithSpan::wrap( + sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "BEGIN".to_string(), + quote_style: None, + keyword: Keyword::BEGIN + } + ) + )), + statements: vec![Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Value( + (number("1")).with_empty_span() + ))), + })], + end_token: AttachedToken(TokenWithSpan::wrap( + sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "END".to_string(), + quote_style: None, + keyword: Keyword::END + } + ) + )), + } + ), }, - }, - elseif_blocks: vec![], - else_block: None, - end_token: None, - }), - Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Value( - (number("0")).with_empty_span() - ))), - }), - ])), + elseif_blocks: vec![], + else_block: None, + end_token: None, + }), + Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Value( + (number("0")).with_empty_span() + ))), + }), + ], + end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "END".to_string(), + quote_style: None, + keyword: Keyword::END + } + ))), + })), behavior: None, called_on_null: None, parallel: None, @@ -420,7 +476,7 @@ fn parse_mssql_create_function() { AS \ BEGIN \ SET @foo = @foo + 1; \ - RETURN @foo \ + RETURN @foo; \ END\ "; assert_eq!( @@ -461,31 +517,45 @@ fn parse_mssql_create_function() { }, ]), return_type: Some(DataType::Int(None)), - function_body: Some(CreateFunctionBody::MultiStatement(vec![ - Statement::Set(Set::SingleAssignment { - scope: None, - hivevar: false, - variable: ObjectName::from(vec!["@foo".into()]), - values: vec![sqlparser::ast::Expr::BinaryOp { - left: Box::new(sqlparser::ast::Expr::Identifier(Ident { - value: "@foo".to_string(), + function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { + begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "BEGIN".to_string(), + quote_style: None, + keyword: Keyword::BEGIN + } + ))), + statements: vec![ + Statement::Set(Set::SingleAssignment { + scope: None, + hivevar: false, + variable: ObjectName::from(vec!["@foo".into()]), + values: vec![sqlparser::ast::Expr::BinaryOp { + left: Box::new(sqlparser::ast::Expr::Identifier(Ident { + value: "@foo".to_string(), + quote_style: None, + span: Span::empty(), + })), + op: sqlparser::ast::BinaryOperator::Plus, + right: Box::new(Expr::Value(number("1").with_empty_span())), + }], + }), + Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident { + value: "@foo".into(), quote_style: None, span: Span::empty(), - })), - op: sqlparser::ast::BinaryOperator::Plus, - right: Box::new(Expr::Value( - (Value::Number("1".into(), false)).with_empty_span() - )), - }], - }), - Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident { - value: "@foo".into(), + }))), + }), + ], + end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( + sqlparser::tokenizer::Word { + value: "END".to_string(), quote_style: None, - span: Span::empty(), - }))), - }), - ])), + keyword: Keyword::END + } + ))), + })), behavior: None, called_on_null: None, parallel: None, From 157aa604276a8540202a3acbad7eb630a3ab91cd Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Tue, 15 Apr 2025 16:44:28 -0400 Subject: [PATCH 05/16] Extract common implementation for function name & parameters --- src/parser/mod.rs | 55 ++++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index d3a25a368..568618ccc 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5062,22 +5062,7 @@ impl<'a> Parser<'a> { temporary: bool, ) -> Result { let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); - let name = self.parse_object_name(false)?; - - let parse_function_param = - |parser: &mut Parser| -> Result { - let name = parser.parse_identifier()?; - let data_type = parser.parse_data_type()?; - Ok(OperateFunctionArg { - mode: None, - name: Some(name), - data_type, - default_expr: None, - }) - }; - self.expect_token(&Token::LParen)?; - let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?; - self.expect_token(&Token::RParen)?; + let (name, args) = self.parse_create_function_name_and_params()?; let return_type = if self.parse_keyword(Keyword::RETURNS) { Some(self.parse_data_type()?) @@ -5152,22 +5137,7 @@ impl<'a> Parser<'a> { or_replace: bool, temporary: bool, ) -> Result { - let name = self.parse_object_name(false)?; - - let parse_function_param = - |parser: &mut Parser| -> Result { - let name = parser.parse_identifier()?; - let data_type = parser.parse_data_type()?; - Ok(OperateFunctionArg { - mode: None, - name: Some(name), - data_type, - default_expr: None, - }) - }; - self.expect_token(&Token::LParen)?; - let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?; - self.expect_token(&Token::RParen)?; + let (name, args) = self.parse_create_function_name_and_params()?; let return_type = if self.parse_keyword(Keyword::RETURNS) { Some(self.parse_data_type()?) @@ -5207,6 +5177,27 @@ impl<'a> Parser<'a> { })) } + fn parse_create_function_name_and_params( + &mut self, + ) -> Result<(ObjectName, Vec), ParserError> { + let name = self.parse_object_name(false)?; + let parse_function_param = + |parser: &mut Parser| -> Result { + let name = parser.parse_identifier()?; + let data_type = parser.parse_data_type()?; + Ok(OperateFunctionArg { + mode: None, + name: Some(name), + data_type, + default_expr: None, + }) + }; + self.expect_token(&Token::LParen)?; + let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?; + self.expect_token(&Token::RParen)?; + Ok((name, args)) + } + fn parse_function_arg(&mut self) -> Result { let mode = if self.parse_keyword(Keyword::IN) { Some(ArgMode::In) From 930eaa86a57519ae95e67c2791ea326fce1dc5cd Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Wed, 16 Apr 2025 16:52:59 -0400 Subject: [PATCH 06/16] Support bare `RETURN` without expression & add common test --- src/parser/mod.rs | 14 ++++++++++---- tests/sqlparser_common.rs | 5 +++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 568618ccc..314c784d5 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -15130,10 +15130,16 @@ impl<'a> Parser<'a> { /// Parse [Statement::Return] fn parse_return(&mut self) -> Result { - let expr = self.parse_expr()?; - Ok(Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(expr)), - })) + let current_index = self.index; + match self.parse_expr() { + Ok(expr) => Ok(Statement::Return(ReturnStatement { + value: Some(ReturnStatementValue::Expr(expr)), + })), + Err(_) => { + self.index = current_index; + Ok(Statement::Return(ReturnStatement { value: None })) + } + } } /// Consume the parser and return its underlying token buffer diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index be848a603..e5779dc45 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -15029,3 +15029,8 @@ fn parse_set_time_zone_alias() { _ => unreachable!(), } } + +#[test] +fn parse_return() { + all_dialects().verified_stmt("RETURN"); +} From e5116bc6685051ddd9b2a1156aa9a42a1f23f1fa Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Fri, 18 Apr 2025 15:34:20 -0400 Subject: [PATCH 07/16] Add `RETURN` test with value --- tests/sqlparser_common.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index e5779dc45..51241bbb0 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -15032,5 +15032,8 @@ fn parse_set_time_zone_alias() { #[test] fn parse_return() { - all_dialects().verified_stmt("RETURN"); + let stmt = all_dialects().verified_stmt("RETURN"); + assert_eq!(stmt, Statement::Return(ReturnStatement { value: None })); + + let _ = all_dialects().verified_stmt("RETURN 1"); } From 878046261fe54821f2e8b9624d4ad6d5cab6a7f5 Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Wed, 16 Apr 2025 17:01:52 -0400 Subject: [PATCH 08/16] Simplify test fixtures --- tests/sqlparser_mssql.rs | 152 +++++++-------------------------------- 1 file changed, 26 insertions(+), 126 deletions(-) diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index a09a830d2..9b01f72a3 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -198,29 +198,17 @@ fn parse_create_function() { or_replace: false, temporary: false, if_not_exists: false, - name: ObjectName::from(vec![Ident { - value: "some_scalar_udf".into(), - quote_style: None, - span: Span::empty(), - }]), + name: ObjectName::from(vec![Ident::new("some_scalar_udf")]), args: Some(vec![ OperateFunctionArg { mode: None, - name: Some(Ident { - value: "@foo".into(), - quote_style: None, - span: Span::empty(), - }), + name: Some(Ident::new("@foo")), data_type: DataType::Int(None), default_expr: None, }, OperateFunctionArg { mode: None, - name: Some(Ident { - value: "@bar".into(), - quote_style: None, - span: Span::empty(), - }), + name: Some(Ident::new("@bar")), data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { length: 256, unit: None @@ -230,25 +218,13 @@ fn parse_create_function() { ]), return_type: Some(DataType::Int(None)), function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { - begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "BEGIN".to_string(), - quote_style: None, - keyword: Keyword::BEGIN - } - ))), + begin_token: AttachedToken::empty(), statements: vec![Statement::Return(ReturnStatement { value: Some(ReturnStatementValue::Expr(Expr::Value( (number("1")).with_empty_span() ))), }),], - end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "END".to_string(), - quote_style: None, - keyword: Keyword::END - } - ))), + end_token: AttachedToken::empty(), })), behavior: None, called_on_null: None, @@ -277,29 +253,17 @@ fn parse_create_function() { or_replace: false, temporary: false, if_not_exists: false, - name: ObjectName::from(vec![Ident { - value: "some_scalar_udf".into(), - quote_style: None, - span: Span::empty(), - }]), + name: ObjectName::from(vec![Ident::new("some_scalar_udf")]), args: Some(vec![ OperateFunctionArg { mode: None, - name: Some(Ident { - value: "@foo".into(), - quote_style: None, - span: Span::empty(), - }), + name: Some(Ident::new("@foo")), data_type: DataType::Int(None), default_expr: None, }, OperateFunctionArg { mode: None, - name: Some(Ident { - value: "@bar".into(), - quote_style: None, - span: Span::empty(), - }), + name: Some(Ident::new("@bar")), data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { length: 256, unit: None @@ -309,43 +273,25 @@ fn parse_create_function() { ]), return_type: Some(DataType::Int(None)), function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { - begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "BEGIN".to_string(), - quote_style: None, - keyword: Keyword::BEGIN - } - ))), + begin_token: AttachedToken::empty(), statements: vec![ Statement::Set(Set::SingleAssignment { scope: None, hivevar: false, variable: ObjectName::from(vec!["@foo".into()]), values: vec![sqlparser::ast::Expr::BinaryOp { - left: Box::new(sqlparser::ast::Expr::Identifier(Ident { - value: "@foo".to_string(), - quote_style: None, - span: Span::empty(), - })), + left: Box::new(sqlparser::ast::Expr::Identifier(Ident::new("@foo"))), op: sqlparser::ast::BinaryOperator::Plus, right: Box::new(Expr::Value(number("1").with_empty_span())), }], }), Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident { - value: "@foo".into(), - quote_style: None, - span: Span::empty(), - }))), + value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident::new( + "@foo" + )))), }), ], - end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "END".to_string(), - quote_style: None, - keyword: Keyword::END - } - ))), + end_token: AttachedToken::empty(), })), behavior: None, called_on_null: None, @@ -379,21 +325,11 @@ fn parse_create_function() { or_replace: false, temporary: false, if_not_exists: false, - name: ObjectName::from(vec![Ident { - value: "some_scalar_udf".into(), - quote_style: None, - span: Span::empty(), - }]), + name: ObjectName::from(vec![Ident::new("some_scalar_udf")]), args: Some(vec![]), return_type: Some(DataType::Int(None)), function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { - begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "BEGIN".to_string(), - quote_style: None, - keyword: Keyword::BEGIN - } - ))), + begin_token: AttachedToken::empty(), statements: vec![ Statement::If(IfStatement { if_block: ConditionalStatementBlock { @@ -448,13 +384,7 @@ fn parse_create_function() { ))), }), ], - end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "END".to_string(), - quote_style: None, - keyword: Keyword::END - } - ))), + end_token: AttachedToken::empty(), })), behavior: None, called_on_null: None, @@ -486,29 +416,17 @@ fn parse_mssql_create_function() { or_replace: false, temporary: false, if_not_exists: false, - name: ObjectName::from(vec![Ident { - value: "some_scalar_udf".into(), - quote_style: None, - span: Span::empty(), - }]), + name: ObjectName::from(vec![Ident::new("some_scalar_udf")]), args: Some(vec![ OperateFunctionArg { mode: None, - name: Some(Ident { - value: "@foo".into(), - quote_style: None, - span: Span::empty(), - }), + name: Some(Ident::new("@foo")), data_type: DataType::Int(None), default_expr: None, }, OperateFunctionArg { mode: None, - name: Some(Ident { - value: "@bar".into(), - quote_style: None, - span: Span::empty(), - }), + name: Some(Ident::new("@bar")), data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { length: 256, unit: None @@ -518,43 +436,25 @@ fn parse_mssql_create_function() { ]), return_type: Some(DataType::Int(None)), function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { - begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "BEGIN".to_string(), - quote_style: None, - keyword: Keyword::BEGIN - } - ))), + begin_token: AttachedToken::empty(), statements: vec![ Statement::Set(Set::SingleAssignment { scope: None, hivevar: false, variable: ObjectName::from(vec!["@foo".into()]), values: vec![sqlparser::ast::Expr::BinaryOp { - left: Box::new(sqlparser::ast::Expr::Identifier(Ident { - value: "@foo".to_string(), - quote_style: None, - span: Span::empty(), - })), + left: Box::new(sqlparser::ast::Expr::Identifier(Ident::new("@foo"))), op: sqlparser::ast::BinaryOperator::Plus, right: Box::new(Expr::Value(number("1").with_empty_span())), }], }), Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident { - value: "@foo".into(), - quote_style: None, - span: Span::empty(), - }))), + value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident::new( + "@foo" + )))), }), ], - end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "END".to_string(), - quote_style: None, - keyword: Keyword::END - } - ))), + end_token: AttachedToken::empty(), })), behavior: None, called_on_null: None, From fd104aad2a794c27632211beb297a9b91b1ddc6f Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Wed, 16 Apr 2025 17:03:08 -0400 Subject: [PATCH 09/16] Reword documentation comments to be more generalized for all dialects --- src/ast/mod.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 7454e8ded..29b52fe29 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2318,7 +2318,14 @@ impl fmt::Display for ConditionalStatements { } } -/// A shared representation of `BEGIN`, multiple statements, and `END` tokens. +/// Represents a list of statements enclosed within `BEGIN` and `END` keywords. +/// Example: +/// ```sql +/// BEGIN +/// SELECT 1; +/// SELECT 2; +/// END +/// ``` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] @@ -9273,17 +9280,10 @@ impl fmt::Display for PrintStatement { } } -/// Return (MsSql) -/// -/// for Functions: -/// RETURN scalar_expression -/// -/// See -/// -/// for Triggers: -/// RETURN +/// Represents a `Return` statement. /// -/// See +/// [MsSql triggers](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql) +/// [MsSql functions](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql) #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] From a2c3bdc124ca6b00be5d2c46d2d62546c709f434 Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Wed, 16 Apr 2025 17:03:28 -0400 Subject: [PATCH 10/16] Simplify writing `AsBeginEnd` body --- src/ast/ddl.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 757c5a1d2..c1c113b32 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -2278,8 +2278,7 @@ impl fmt::Display for CreateFunction { write!(f, " AS {function_body}")?; } if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body { - write!(f, " AS")?; - write!(f, " {}", bes)?; + write!(f, " AS {bes}")?; } Ok(()) } From b44befa8a59caa6dae5de0e8d7c9bd7d96792fdd Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Wed, 16 Apr 2025 17:19:00 -0400 Subject: [PATCH 11/16] Further simplify test cases to use `verified_stmt` --- tests/sqlparser_mssql.rs | 227 +++------------------------------------ 1 file changed, 16 insertions(+), 211 deletions(-) diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 9b01f72a3..b86e1a7d4 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -23,8 +23,7 @@ mod test_utils; use helpers::attached_token::AttachedToken; -use sqlparser::keywords::Keyword; -use sqlparser::tokenizer::{Location, Span, TokenWithSpan}; +use sqlparser::tokenizer::{Location, Span}; use test_utils::*; use sqlparser::ast::DataType::{Int, Text, Varbinary}; @@ -246,160 +245,22 @@ fn parse_create_function() { RETURN @foo; \ END\ "; - assert_eq!( - ms().verified_stmt(multi_statement_function), - sqlparser::ast::Statement::CreateFunction(CreateFunction { - or_alter: false, - or_replace: false, - temporary: false, - if_not_exists: false, - name: ObjectName::from(vec![Ident::new("some_scalar_udf")]), - args: Some(vec![ - OperateFunctionArg { - mode: None, - name: Some(Ident::new("@foo")), - data_type: DataType::Int(None), - default_expr: None, - }, - OperateFunctionArg { - mode: None, - name: Some(Ident::new("@bar")), - data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { - length: 256, - unit: None - })), - default_expr: None, - }, - ]), - return_type: Some(DataType::Int(None)), - function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { - begin_token: AttachedToken::empty(), - statements: vec![ - Statement::Set(Set::SingleAssignment { - scope: None, - hivevar: false, - variable: ObjectName::from(vec!["@foo".into()]), - values: vec![sqlparser::ast::Expr::BinaryOp { - left: Box::new(sqlparser::ast::Expr::Identifier(Ident::new("@foo"))), - op: sqlparser::ast::BinaryOperator::Plus, - right: Box::new(Expr::Value(number("1").with_empty_span())), - }], - }), - Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident::new( - "@foo" - )))), - }), - ], - end_token: AttachedToken::empty(), - })), - behavior: None, - called_on_null: None, - parallel: None, - using: None, - language: None, - determinism_specifier: None, - options: None, - remote_connection: None, - }), - ); + let _ = ms().verified_stmt(multi_statement_function); - let create_function_with_conditional = r#" - CREATE FUNCTION some_scalar_udf() - RETURNS INT - AS - BEGIN - IF 1=2 - BEGIN - RETURN 1; - END - - RETURN 0; - END - "#; - let create_stmt = ms().one_statement_parses_to(create_function_with_conditional, ""); - assert_eq!( - create_stmt, - Statement::CreateFunction(CreateFunction { - or_alter: false, - or_replace: false, - temporary: false, - if_not_exists: false, - name: ObjectName::from(vec![Ident::new("some_scalar_udf")]), - args: Some(vec![]), - return_type: Some(DataType::Int(None)), - function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { - begin_token: AttachedToken::empty(), - statements: vec![ - Statement::If(IfStatement { - if_block: ConditionalStatementBlock { - start_token: AttachedToken(TokenWithSpan::wrap( - sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word { - value: "IF".to_string(), - quote_style: None, - keyword: Keyword::IF - }) - )), - condition: Some(Expr::BinaryOp { - left: Box::new(Expr::Value(number("1").with_empty_span())), - op: sqlparser::ast::BinaryOperator::Eq, - right: Box::new(Expr::Value(number("2").with_empty_span())), - }), - then_token: None, - conditional_statements: ConditionalStatements::BeginEnd( - BeginEndStatements { - begin_token: AttachedToken(TokenWithSpan::wrap( - sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "BEGIN".to_string(), - quote_style: None, - keyword: Keyword::BEGIN - } - ) - )), - statements: vec![Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Value( - (number("1")).with_empty_span() - ))), - })], - end_token: AttachedToken(TokenWithSpan::wrap( - sqlparser::tokenizer::Token::Word( - sqlparser::tokenizer::Word { - value: "END".to_string(), - quote_style: None, - keyword: Keyword::END - } - ) - )), - } - ), - }, - elseif_blocks: vec![], - else_block: None, - end_token: None, - }), - Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Value( - (number("0")).with_empty_span() - ))), - }), - ], - end_token: AttachedToken::empty(), - })), - behavior: None, - called_on_null: None, - parallel: None, - using: None, - language: None, - determinism_specifier: None, - options: None, - remote_connection: None, - }) - ); -} + let create_function_with_conditional = "\ + CREATE FUNCTION some_scalar_udf() \ + RETURNS INT \ + AS \ + BEGIN \ + IF 1 = 2 \ + BEGIN \ + RETURN 1; \ + END; \ + RETURN 0; \ + END\ + "; + let _ = ms().verified_stmt(create_function_with_conditional); -#[test] -fn parse_mssql_create_function() { let create_or_alter_function = "\ CREATE OR ALTER FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \ RETURNS INT \ @@ -409,63 +270,7 @@ fn parse_mssql_create_function() { RETURN @foo; \ END\ "; - assert_eq!( - ms().verified_stmt(create_or_alter_function), - sqlparser::ast::Statement::CreateFunction(CreateFunction { - or_alter: true, - or_replace: false, - temporary: false, - if_not_exists: false, - name: ObjectName::from(vec![Ident::new("some_scalar_udf")]), - args: Some(vec![ - OperateFunctionArg { - mode: None, - name: Some(Ident::new("@foo")), - data_type: DataType::Int(None), - default_expr: None, - }, - OperateFunctionArg { - mode: None, - name: Some(Ident::new("@bar")), - data_type: DataType::Varchar(Some(CharacterLength::IntegerLength { - length: 256, - unit: None - })), - default_expr: None, - }, - ]), - return_type: Some(DataType::Int(None)), - function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { - begin_token: AttachedToken::empty(), - statements: vec![ - Statement::Set(Set::SingleAssignment { - scope: None, - hivevar: false, - variable: ObjectName::from(vec!["@foo".into()]), - values: vec![sqlparser::ast::Expr::BinaryOp { - left: Box::new(sqlparser::ast::Expr::Identifier(Ident::new("@foo"))), - op: sqlparser::ast::BinaryOperator::Plus, - right: Box::new(Expr::Value(number("1").with_empty_span())), - }], - }), - Statement::Return(ReturnStatement { - value: Some(ReturnStatementValue::Expr(Expr::Identifier(Ident::new( - "@foo" - )))), - }), - ], - end_token: AttachedToken::empty(), - })), - behavior: None, - called_on_null: None, - parallel: None, - using: None, - language: None, - determinism_specifier: None, - options: None, - remote_connection: None, - }), - ); + let _ = ms().verified_stmt(create_or_alter_function); } #[test] From f7993f34bf2f08df0f46b38171f5e0dacca70469 Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Wed, 16 Apr 2025 17:27:03 -0400 Subject: [PATCH 12/16] Simplify parsing `RETURN` statements --- src/parser/mod.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 314c784d5..fb6095737 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -15130,15 +15130,11 @@ impl<'a> Parser<'a> { /// Parse [Statement::Return] fn parse_return(&mut self) -> Result { - let current_index = self.index; - match self.parse_expr() { - Ok(expr) => Ok(Statement::Return(ReturnStatement { + match self.maybe_parse(|p| p.parse_expr())? { + Some(expr) => Ok(Statement::Return(ReturnStatement { value: Some(ReturnStatementValue::Expr(expr)), })), - Err(_) => { - self.index = current_index; - Ok(Statement::Return(ReturnStatement { value: None })) - } + None => Ok(Statement::Return(ReturnStatement { value: None })), } } From dbe1cc58e864cd4cc34cd9d4d72f4170e725f22c Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Wed, 16 Apr 2025 18:28:44 -0400 Subject: [PATCH 13/16] Implement `Spanned` trait for `BeginEndStatements` --- src/ast/spans.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 88943c0dc..45c1970bc 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -23,8 +23,8 @@ use crate::tokenizer::Span; use super::{ dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken, - CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, - ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy, + BeginEndStatements, CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, + ColumnOptionDef, ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, @@ -779,9 +779,7 @@ impl Spanned for ConditionalStatements { ConditionalStatements::Sequence { statements } => { union_spans(statements.iter().map(|s| s.span())) } - ConditionalStatements::BeginEnd(bes) => { - union_spans([bes.begin_token.0.span, bes.end_token.0.span].into_iter()) - } + ConditionalStatements::BeginEnd(bes) => bes.span(), } } } @@ -2280,6 +2278,12 @@ impl Spanned for TableObject { } } +impl Spanned for BeginEndStatements { + fn span(&self) -> Span { + union_spans([self.begin_token.0.span, self.end_token.0.span].into_iter()) + } +} + #[cfg(test)] pub mod tests { use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect}; From 05d611ca5b9c29ae5110cdf10e792a0e0d55c8da Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Fri, 18 Apr 2025 15:31:06 -0400 Subject: [PATCH 14/16] Create span from exhaustive match on struct fields --- src/ast/spans.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 45c1970bc..cf671f004 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -2280,7 +2280,16 @@ impl Spanned for TableObject { impl Spanned for BeginEndStatements { fn span(&self) -> Span { - union_spans([self.begin_token.0.span, self.end_token.0.span].into_iter()) + let BeginEndStatements { + begin_token, + statements, + end_token, + } = self; + union_spans( + core::iter::once(begin_token.0.span) + .chain(statements.iter().map(|i| i.span())) + .chain(core::iter::once(end_token.0.span)), + ) } } From 86781b64d6b3bf4b768807270184ecb97ad8bc87 Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Mon, 21 Apr 2025 14:43:19 -0400 Subject: [PATCH 15/16] Move `IF` statement semi-colon logic into `parse_if_stmt` - no need for special IF logic to bleed over to another function --- src/dialect/mssql.rs | 4 ++++ src/parser/mod.rs | 11 +---------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 1f3e353db..31e324f06 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -169,8 +169,10 @@ impl MsSqlDialect { } }; + let mut prior_statement_ended_with_semi_colon = false; while let Token::SemiColon = parser.peek_token_ref().token { parser.advance_token(); + prior_statement_ended_with_semi_colon = true; } let mut else_block = None; @@ -201,6 +203,8 @@ impl MsSqlDialect { }, }); } + } else if prior_statement_ended_with_semi_colon { + parser.prev_token(); } Ok(Statement::If(IfStatement { diff --git a/src/parser/mod.rs b/src/parser/mod.rs index fb6095737..fabd887bb 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -4454,16 +4454,7 @@ impl<'a> Parser<'a> { } } values.push(self.parse_statement()?); - - let semi_colon_expected = match values.last() { - Some(Statement::If(if_statement)) => if_statement.end_token.is_some(), - Some(_) => true, - None => false, - }; - - if semi_colon_expected { - self.expect_token(&Token::SemiColon)?; - } + self.expect_token(&Token::SemiColon)?; } Ok(values) } From ca284133e816e58b85185f77395e1fb702f4a66b Mon Sep 17 00:00:00 2001 From: Andrew Harper Date: Mon, 21 Apr 2025 14:48:51 -0400 Subject: [PATCH 16/16] Simplify parsing `RETURNS` --- src/parser/mod.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index fabd887bb..4ce6bbcd8 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5130,11 +5130,8 @@ impl<'a> Parser<'a> { ) -> Result { let (name, args) = self.parse_create_function_name_and_params()?; - let return_type = if self.parse_keyword(Keyword::RETURNS) { - Some(self.parse_data_type()?) - } else { - return parser_err!("Expected RETURNS keyword", self.peek_token().span.start); - }; + self.expect_keyword(Keyword::RETURNS)?; + let return_type = Some(self.parse_data_type()?); self.expect_keyword_is(Keyword::AS)?;