diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index b60ade78b..29020550e 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -4097,6 +4097,12 @@ pub enum Statement {
     ///
     /// See [ReturnStatement]
     Return(ReturnStatement),
+    /// Go (MsSql)
+    ///
+    /// GO is not a Transact-SQL statement; it is a command recognized by various tools as a batch delimiter
+    ///
+    /// See: <https://learn.microsoft.com/en-us/sql/t-sql/language-elements/sql-server-utilities-statements-go>
+    Go(GoStatement),
 }

 #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@@ -5791,6 +5797,7 @@ impl fmt::Display for Statement {
                 Ok(())
             }
             Statement::Print(s) => write!(f, "{s}"),
+            Statement::Go(s) => write!(f, "{s}"),
             Statement::Return(r) => write!(f, "{r}"),
             Statement::List(command) => write!(f, "LIST {command}"),
             Statement::Remove(command) => write!(f, "REMOVE {command}"),
@@ -9315,6 +9322,27 @@ pub enum ReturnStatementValue {
     Expr(Expr),
 }

+/// Represents a `GO` statement.
+///
+/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/sql-server-utilities-statements-go)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct GoStatement {
+    /// Optional integer count written after `GO` (e.g. `GO 5`); `None` for a bare `GO`
+    pub count: Option<u64>,
+}
+
+impl Display for GoStatement {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if let Some(count) = self.count {
+            write!(f, "GO {count}")
+        } else {
+            write!(f, "GO")
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/ast/spans.rs b/src/ast/spans.rs
index 93de5fff2..69c177b29 100644
--- a/src/ast/spans.rs
+++ b/src/ast/spans.rs
@@ -522,6 +522,7 @@ impl Spanned for Statement {
             Statement::RaisError { .. } => Span::empty(),
             Statement::Print { .. } => Span::empty(),
             Statement::Return { .. } => Span::empty(),
+            Statement::Go { .. } => Span::empty(),
             Statement::List(..) | Statement::Remove(..)
=> Span::empty(),
         }
     }
diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs
index 31e324f06..3a0861f21 100644
--- a/src/dialect/mssql.rs
+++ b/src/dialect/mssql.rs
@@ -118,7 +118,13 @@ impl Dialect for MsSqlDialect {
         true
     }

-    fn is_column_alias(&self, kw: &Keyword, _parser: &mut Parser) -> bool {
+    fn is_column_alias(&self, kw: &Keyword, parser: &mut Parser) -> bool {
+        // looking backward: if only whitespace separates `GO` from the previous newline, `GO` is NOT a column alias;
+        // if any other token precedes it on the same line, we assume that `GO` IS a column alias
+        if kw == &Keyword::GO && parser.prev_only_whitespace_until_newline() {
+            return false;
+        }
+
         !keywords::RESERVED_FOR_COLUMN_ALIAS.contains(kw) && !RESERVED_FOR_COLUMN_ALIAS.contains(kw)
     }

diff --git a/src/keywords.rs b/src/keywords.rs
index 4eaad7ed2..bbe4fd68c 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -393,6 +393,7 @@ define_keywords!(
     GIN,
     GIST,
     GLOBAL,
+    GO,
     GRANT,
     GRANTED,
     GRANTS,
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index fe81b5999..e33dfa2bf 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -475,6 +475,10 @@ impl<'a> Parser<'a> {
                     if expecting_statement_delimiter && word.keyword == Keyword::END {
                         break;
                     }
+
+                    if expecting_statement_delimiter && word.keyword == Keyword::GO {
+                        expecting_statement_delimiter = false;
+                    }
                 }
                 _ => {}
             }
@@ -484,8 +488,9 @@ impl<'a> Parser<'a> {
             }

             let statement = self.parse_statement()?;
+            // A batch delimiter (`GO`) itself ends the statement, so no further statement delimiter is expected after it
+            expecting_statement_delimiter = !matches!(statement, Statement::Go(_));
             stmts.push(statement);
-            expecting_statement_delimiter = true;
         }
         Ok(stmts)
     }
@@ -613,6 +618,10 @@ impl<'a> Parser<'a> {
                 Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(),
                 Keyword::PRINT => self.parse_print(),
                 Keyword::RETURN => self.parse_return(),
+                Keyword::GO => {
+                    self.prev_token();
+                    self.parse_go()
+                }
                 _ => self.expected("an SQL statement", next_token),
             },
             Token::LParen =>
{
@@ -3934,6 +3943,20 @@ impl<'a> Parser<'a> {
         })
     }

+    /// Return the nth previous token, possibly whitespace, where `n == 0` is the
+    /// current token and `n == 1` is the token before it
+    /// (or [`Token::EOF`] when before the beginning of the stream).
+    pub(crate) fn peek_prev_nth_token_no_skip_ref(&self, n: usize) -> &TokenWithSpan {
+        // relative to `self.index`: 0 = next token, -1 = current token, -2 = previous token.
+        // Checked (not saturating) subtraction keeps index 0 addressable, so the first
+        // token of the stream is returned as itself instead of being reported as EOF.
+        self.index
+            .checked_sub(1)
+            .and_then(|current| current.checked_sub(n))
+            .and_then(|peek_index| self.tokens.get(peek_index))
+            .unwrap_or(&EOF_TOKEN)
+    }
+
     /// Return true if the next tokens exactly `expected`
     ///
     /// Does not advance the current token.
@@ -4050,6 +4073,30 @@ impl<'a> Parser<'a> {
         )
     }

+    /// Look backwards in the token stream and report whether only whitespace tokens
+    /// sit between the current token and the previous newline (or the beginning of the stream).
+    pub(crate) fn prev_only_whitespace_until_newline(&self) -> bool {
+        let mut look_back_count = 1;
+        loop {
+            let prev_token = self.peek_prev_nth_token_no_skip_ref(look_back_count);
+            match prev_token.token {
+                Token::EOF => break true,
+                Token::Whitespace(ref w) => match w {
+                    Whitespace::Newline => break true,
+                    // special consideration required for single line comments since that string includes the newline
+                    Whitespace::SingleLineComment { comment, prefix: _ } => {
+                        if comment.ends_with('\n') {
+                            break true;
+                        }
+                        look_back_count += 1;
+                    }
+                    _ => look_back_count += 1,
+                },
+                _ => break false,
+            };
+        }
+    }
+
     /// If the current token is the `expected` keyword, consume it and returns
     /// true. Otherwise, no tokens are consumed and returns false.
     #[must_use]
@@ -15225,6 +15272,81 @@ impl<'a> Parser<'a> {
         }
     }

+    /// Parse [Statement::Go]
+    ///
+    /// Expects `GO`, optionally followed by an integer count, alone on its line.
+    fn parse_go(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword_is(Keyword::GO)?;
+
+        // disambiguate between GO as batch delimiter & GO as identifier (etc)
+        // compare:
+        // ```sql
+        // select 1 go
+        // ```
+        // vs
+        // ```sql
+        // select 1
+        // go
+        // ```
+        if !self.prev_only_whitespace_until_newline() {
+            parser_err!(
+                "GO may only be preceded by whitespace on a line",
+                self.peek_token().span.start
+            )?;
+        }
+
+        let count = loop {
+            // using this peek function because we want to halt this statement parsing upon newline
+            let next_token = self.peek_token_no_skip();
+            match next_token.token {
+                Token::EOF => break None::<u64>,
+                Token::Whitespace(ref w) => match w {
+                    Whitespace::Newline => break None,
+                    // a single line comment includes the newline that ends the line, so
+                    // nothing after it can be the count; leave it for the loop below
+                    Whitespace::SingleLineComment { comment, prefix: _ } => {
+                        if comment.ends_with('\n') {
+                            break None;
+                        }
+                        _ = self.next_token_no_skip();
+                    }
+                    _ => _ = self.next_token_no_skip(),
+                },
+                Token::Number(s, _) => {
+                    let value = Some(Self::parse::<u64>(s, next_token.span.start)?);
+                    self.advance_token();
+                    break value;
+                }
+                _ => self.expected("literal int or newline", next_token)?,
+            };
+        };
+
+        loop {
+            let next_token = self.peek_token_no_skip();
+            match next_token.token {
+                Token::EOF => break,
+                Token::Whitespace(ref w) => match w {
+                    Whitespace::Newline => break,
+                    Whitespace::SingleLineComment { comment, prefix: _ } => {
+                        if comment.ends_with('\n') {
+                            break;
+                        }
+                        _ = self.next_token_no_skip();
+                    }
+                    _ => _ = self.next_token_no_skip(),
+                },
+                _ => {
+                    parser_err!(
+                        "GO must be followed by a newline or EOF",
+                        self.peek_token().span.start
+                    )?;
+                }
+            };
+        }
+
+        Ok(Statement::Go(GoStatement { count }))
+    }
+
     /// Consume the parser and return its underlying token buffer
     pub fn into_tokens(self) -> Vec<TokenWithSpan> {
         self.tokens
@@ -15455,6 +15577,52 @@ mod tests {
         })
     }

+    #[test]
+    fn test_peek_prev_nth_token_no_skip_ref() {
+        all_dialects().run_parser_method(
+            "SELECT 1;\n-- a comment\nRAISERROR('test', 16, 0);",
+            |parser| {
+                // nothing consumed yet, so there is no current token to look back at
+                parser.index = 0;
+                assert_eq!(parser.peek_prev_nth_token_no_skip_ref(0), &Token::EOF);
+                // the first token of the stream is reachable; looking past it is EOF
+                parser.index = 1;
+                assert_eq!(
+                    parser.peek_prev_nth_token_no_skip_ref(0),
+                    &Token::Word(Word {
+                        value: "SELECT".to_string(),
+                        quote_style: None,
+                        keyword: Keyword::SELECT,
+                    })
+                );
+                assert_eq!(parser.peek_prev_nth_token_no_skip_ref(1), &Token::EOF);
+                assert_eq!(parser.index, 1);
+                parser.index = 7;
+                assert_eq!(
+                    parser.token_at(parser.index - 1).token,
+                    Token::Word(Word {
+                        value: "RAISERROR".to_string(),
+                        quote_style: None,
+                        keyword: Keyword::RAISERROR,
+                    })
+                );
+                assert_eq!(
+                    parser.peek_prev_nth_token_no_skip_ref(2),
+                    &Token::Whitespace(Whitespace::Newline)
+                );
+            },
+        );
+    }
+
+    #[test]
+    fn test_go_followed_by_single_line_comment() {
+        // the comment ends the `GO` line, so the next line must parse as a new statement
+        let sql = "SELECT 1\nGO -- end of batch\nSELECT 2";
+        let statements = Parser::parse_sql(&crate::dialect::MsSqlDialect {}, sql).unwrap();
+        assert_eq!(statements.len(), 3);
+        assert_eq!(statements[1], Statement::Go(GoStatement { count: None }));
+    }
+
     #[cfg(test)]
     mod test_parse_data_type {
         use crate::ast::{
diff --git a/src/test_utils.rs b/src/test_utils.rs
index 6270ac42b..bfd6d3d25 100644
--- a/src/test_utils.rs
+++ b/src/test_utils.rs
@@ -151,6 +151,8 @@ impl TestedDialects {
     ///
     /// 2. re-serializing the result of parsing `sql` produces the same
     /// `canonical` sql string
+    ///
+    /// For multiple statements, use [`statements_parse_to`].
     pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
         let mut statements = self.parse_sql_statements(sql).expect(sql);
         assert_eq!(statements.len(), 1);
@@ -166,6 +168,32 @@ impl TestedDialects {
         only_statement
     }

+    /// The same as [`one_statement_parses_to`], but for SQL containing `statement_count` statements
+    pub fn statements_parse_to(
+        &self,
+        sql: &str,
+        statement_count: usize,
+        canonical: &str,
+    ) -> Vec<Statement> {
+        let statements = self.parse_sql_statements(sql).expect(sql);
+        assert_eq!(statements.len(), statement_count);
+
+        if !canonical.is_empty() && sql != canonical {
+            assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
+        } else {
+            assert_eq!(
+                sql,
+                statements
+                    .iter()
+                    .map(|s| s.to_string())
+                    .collect::<Vec<_>>()
+                    .join("; ")
+            );
+        }
+
+        statements
+    }
+
     /// Ensures that `sql` parses as an [`Expr`], and that
     /// re-serializing the parse result produces canonical
     pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs
index b86e1a7d4..7730695b8 100644
--- a/tests/sqlparser_mssql.rs
+++ b/tests/sqlparser_mssql.rs
@@ -23,7 +23,8 @@ mod test_utils;

 use helpers::attached_token::AttachedToken;
-use
sqlparser::tokenizer::{Location, Span};
+use sqlparser::keywords::Keyword;
+use sqlparser::tokenizer::{Location, Span, TokenWithSpan};
 use test_utils::*;

 use sqlparser::ast::DataType::{Int, Text, Varbinary};
@@ -2156,3 +2157,182 @@ fn parse_print() {
     let _ = ms().verified_stmt("PRINT N'Hello, ⛄️!'");
     let _ = ms().verified_stmt("PRINT @my_variable");
 }
+
+#[test]
+fn parse_mssql_go_keyword() {
+    let single_go_keyword = "USE some_database;\nGO"; // GO on its own line after a `;`
+    let stmts = ms().statements_parse_to(single_go_keyword, 2, "USE some_database\nGO");
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: None }));
+
+    let go_with_count = "SELECT 1;\nGO 5"; // GO followed by an integer count
+    let stmts = ms().statements_parse_to(go_with_count, 2, "SELECT 1\nGO 5");
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: Some(5) }));
+
+    let go_statement_delimiter = "SELECT 1\nGO"; // no `;` between the statement and GO
+    let stmts = ms().statements_parse_to(go_statement_delimiter, 2, "SELECT 1; \nGO");
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: None }));
+
+    let bare_go = "GO"; // GO is the only token in the input
+    let stmt = ms().one_statement_parses_to(bare_go, "GO");
+    assert_eq!(stmt, Statement::Go(GoStatement { count: None }));
+
+    let go_then_statements = "/* whitespace */ GO\nRAISERROR('This is a test', 16, 1);"; // comment before GO
+    let stmts = ms().statements_parse_to(
+        go_then_statements,
+        2,
+        "GO\nRAISERROR('This is a test', 16, 1)",
+    );
+    assert_eq!(stmts[0], Statement::Go(GoStatement { count: None }));
+    assert_eq!(
+        stmts[1],
+        Statement::RaisError {
+            message: Box::new(Expr::Value(
+                (Value::SingleQuotedString("This is a test".to_string())).with_empty_span()
+            )),
+            severity: Box::new(Expr::Value(number("16").with_empty_span())),
+            state: Box::new(Expr::Value(number("1").with_empty_span())),
+            arguments: vec![],
+            options: vec![],
+        }
+    );
+
+    let multiple_gos = "SELECT 1;\nGO 5\nSELECT 2;\n GO"; // two GOs, the last one indented
+    let stmts = ms().statements_parse_to(multiple_gos, 4, "SELECT 1\nGO 5\nSELECT 2\nGO");
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: Some(5) }));
+    assert_eq!(stmts[3], Statement::Go(GoStatement { count: None }));
+
+    let single_line_comment_preceding_go = "USE some_database; -- okay\nGO"; // comment ends the previous line
+    let stmts =
+        ms().statements_parse_to(single_line_comment_preceding_go, 2, "USE some_database\nGO");
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: None }));
+
+    let multi_line_comment_preceding_go = "USE some_database; /* okay */\nGO"; // block comment, then newline
+    let stmts =
+        ms().statements_parse_to(multi_line_comment_preceding_go, 2, "USE some_database\nGO");
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: None }));
+
+    let single_line_comment_following_go = "USE some_database;\nGO -- okay"; // comment after GO, then EOF
+    let stmts =
+        ms().statements_parse_to(single_line_comment_following_go, 2, "USE some_database\nGO");
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: None }));
+
+    let multi_line_comment_following = "USE some_database;\nGO/* okay */42"; // block comment between GO and count
+    let stmts =
+        ms().statements_parse_to(multi_line_comment_following, 2, "USE some_database\nGO 42");
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: Some(42) }));
+
+    let cte_following_go =
+        "USE some_database;\nGO\n;WITH cte AS (\nSELECT 1 x\n)\nSELECT * FROM cte;"; // `;WITH` right after GO
+    let stmts = ms().parse_sql_statements(cte_following_go).unwrap();
+    assert_eq!(stmts.len(), 3);
+    assert_eq!(stmts[1], Statement::Go(GoStatement { count: None }));
+
+    let actually_column_alias = "SELECT NULL GO"; // GO on the same line as the expression: an alias
+    let stmt = ms().one_statement_parses_to(actually_column_alias, "SELECT NULL AS GO");
+    match &stmt {
+        Statement::Query(query) => {
+            let select = query.body.as_select().unwrap();
+            assert_eq!(
+                only(select.clone().projection),
+                SelectItem::ExprWithAlias {
+                    expr: Expr::Value(Value::Null.with_empty_span()),
+                    alias: Ident::new("GO"),
+                }
+            );
+        }
+        _ => panic!("Expected Query statement"),
+    }
+
+    let invalid_go_position = "SELECT 1; GO"; // GO after `;` on the same line is rejected
+    let err = ms().parse_sql_statements(invalid_go_position);
+    assert_eq!(
+        err.unwrap_err().to_string(),
+        "sql parser error: GO may only be preceded by whitespace on a line"
+    );
+
+    let invalid_go_count = "SELECT 1\nGO x"; // a non-numeric token after GO is rejected
+    let err = ms().parse_sql_statements(invalid_go_count);
+    assert_eq!(
+        err.unwrap_err().to_string(),
+        "sql parser error: Expected: literal int or newline, found: x"
+    );
+
+    let invalid_go_delimiter = "SELECT 1\nGO;"; // `;` directly after GO is rejected
+    let err = ms().parse_sql_statements(invalid_go_delimiter);
+    assert_eq!(
+        err.unwrap_err().to_string(),
+        "sql parser error: Expected: literal int or newline, found: ;"
+    );
+}
+
+#[test]
+fn test_mssql_if_and_go() {
+    let sql = r#"
+        IF 1 = 2
+        SELECT 3;
+        GO
+    "#;
+    let statements = ms().parse_sql_statements(sql).unwrap();
+    assert_eq!(2, statements.len()); // the IF statement, then GO as a statement of its own
+    assert_eq!(
+        statements[0],
+        Statement::If(IfStatement {
+            if_block: ConditionalStatementBlock {
+                start_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(
+                    sqlparser::tokenizer::Word {
+                        value: "IF".to_string(),
+                        quote_style: None,
+                        keyword: Keyword::IF
+                    }
+                ))),
+                condition: Some(Expr::BinaryOp {
+                    left: Box::new(Expr::Value((number("1")).with_empty_span())),
+                    op: sqlparser::ast::BinaryOperator::Eq,
+                    right: Box::new(Expr::Value((number("2")).with_empty_span())),
+                }),
+                then_token: None,
+                conditional_statements: ConditionalStatements::Sequence {
+                    statements: vec![Statement::Query(Box::new(Query {
+                        with: None,
+                        limit_clause: None,
+                        fetch: None,
+                        locks: vec![],
+                        for_clause: None,
+                        order_by: None,
+                        settings: None,
+                        format_clause: None,
+                        body: Box::new(SetExpr::Select(Box::new(Select {
+                            select_token: AttachedToken::empty(),
+                            distinct: None,
+                            top: None,
+                            top_before_distinct: false,
+                            projection: vec![SelectItem::UnnamedExpr(Expr::Value(
+                                (number("3")).with_empty_span()
+                            ))],
+                            into: None,
+                            from: vec![],
+                            lateral_views: vec![],
+                            prewhere: None,
+                            selection: None,
+                            group_by: GroupByExpr::Expressions(vec![], vec![]),
+                            cluster_by: vec![],
+                            distribute_by: vec![],
+                            sort_by: vec![],
+                            having: None,
+                            named_window: vec![],
+                            window_before_qualify: false,
+                            qualify: None,
+                            value_table_mode: None,
+                            connect_by: None,
+                            flavor: SelectFlavor::Standard,
+                        }))),
+                    }))],
+                },
+            },
+            elseif_blocks: vec![],
+            else_block: None,
+            end_token: None,
+        })
+    );
+    assert_eq!(statements[1], Statement::Go(GoStatement { count: None }));
+}