Support dialect-specific trailing commas.

* DuckDB accepts trailing commas anywhere in a query
  (`Dialect::supports_trailing_commas`).
* BigQuery and Snowflake accept them only in the projection list
  (`Dialect::supports_projection_trailing_commas`).
* Other dialects reject them by default, including in `CREATE TABLE` column lists.

NOTE(review): this patch was rebuilt from a flattened copy. Indentation and
blank-line placement in context lines were reconstructed by convention, the
generic arguments of `d.is::<...>()` in `parse_window_clause` are a best guess,
and the `index` hashes predate the review edits. Confirm against the tree
before applying.

diff --git a/src/dialect/bigquery.rs b/src/dialect/bigquery.rs
index b945587c6..9bdfd195c 100644
--- a/src/dialect/bigquery.rs
+++ b/src/dialect/bigquery.rs
@@ -22,6 +22,10 @@ impl Dialect for BigQueryDialect {
         ch == '`'
     }
 
+    fn supports_projection_trailing_commas(&self) -> bool {
+        true
+    }
+
     fn is_identifier_start(&self, ch: char) -> bool {
         ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'
     }
diff --git a/src/dialect/duckdb.rs b/src/dialect/duckdb.rs
index e141f941f..c6edeac14 100644
--- a/src/dialect/duckdb.rs
+++ b/src/dialect/duckdb.rs
@@ -18,6 +18,10 @@ pub struct DuckDbDialect;
 
 // In most cases the redshift dialect is identical to [`PostgresSqlDialect`].
 impl Dialect for DuckDbDialect {
+    fn supports_trailing_commas(&self) -> bool {
+        true
+    }
+
     fn is_identifier_start(&self, ch: char) -> bool {
         ch.is_alphabetic() || ch == '_'
     }
diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs
index a04390570..f4b9359a8 100644
--- a/src/dialect/mod.rs
+++ b/src/dialect/mod.rs
@@ -219,6 +219,16 @@ pub trait Dialect: Debug + Any {
         // return None to fall back to the default behavior
         None
     }
+    /// Returns true if the dialect supports trailing commas anywhere in the query,
+    /// e.g. `SELECT a, b, FROM t` or `CREATE TABLE t (a INT,)`.
+    fn supports_trailing_commas(&self) -> bool {
+        false
+    }
+    /// Returns true if the dialect supports trailing commas in the projection list,
+    /// e.g. `SELECT a, b, FROM t`. Defaults to [`Dialect::supports_trailing_commas`].
+    fn supports_projection_trailing_commas(&self) -> bool {
+        self.supports_trailing_commas()
+    }
     /// Dialect-specific infix parser override
     fn parse_infix(
         &self,
diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs
index 3ec84f602..cf55166c0 100644
--- a/src/dialect/snowflake.rs
+++ b/src/dialect/snowflake.rs
@@ -38,6 +38,10 @@ impl Dialect for SnowflakeDialect {
         ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'
     }
 
+    fn supports_projection_trailing_commas(&self) -> bool {
+        true
+    }
+
     fn is_identifier_part(&self, ch: char) -> bool {
         ch.is_ascii_lowercase()
             || ch.is_ascii_uppercase()
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 5f6f28088..c07577f92 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -305,7 +305,7 @@ impl<'a> Parser<'a> {
             state: ParserState::Normal,
             dialect,
             recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
-            options: ParserOptions::default(),
+            options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
         }
     }
 
@@ -3073,7 +3073,7 @@ impl<'a> Parser<'a> {
         // This pattern could be captured better with RAII type semantics, but it's quite a bit of
         // code to add for just one case, so we'll just do it manually here.
         let old_value = self.options.trailing_commas;
-        self.options.trailing_commas |= dialect_of!(self is BigQueryDialect | SnowflakeDialect);
+        self.options.trailing_commas |= self.dialect.supports_projection_trailing_commas();
 
         let ret = self.parse_comma_separated(|p| p.parse_select_item());
         self.options.trailing_commas = old_value;
@@ -5107,12 +5107,19 @@ impl<'a> Parser<'a> {
             } else {
                 return self.expected("column name or constraint definition", self.peek_token());
             }
+
             let comma = self.consume_token(&Token::Comma);
-            if self.consume_token(&Token::RParen) {
-                // allow a trailing comma, even though it's not in standard
-                break;
-            } else if !comma {
+            let rparen = self.peek_token().token == Token::RParen;
+
+            if !comma && !rparen {
                 return self.expected("',' or ')' after column definition", self.peek_token());
+            }
+
+            // A `)` right after a `,` only ends the list when trailing commas are enabled;
+            // otherwise keep looping so the next iteration reports the unexpected `)`.
+            if rparen && (!comma || self.options.trailing_commas) {
+                let _ = self.consume_token(&Token::RParen);
+                break;
             }
         }
 
@@ -8955,6 +8962,12 @@ impl<'a> Parser<'a> {
                 with_privileges_keyword: self.parse_keyword(Keyword::PRIVILEGES),
             }
         } else {
-            let (actions, err): (Vec<_>, Vec<_>) = self
-                .parse_comma_separated(Parser::parse_grant_permission)?
+            // Permission lists never accept a trailing comma. Restore the option right
+            // after parsing so that an early `?` return cannot leave it disabled.
+            let old_value = self.options.trailing_commas;
+            self.options.trailing_commas = false;
+            let permissions = self.parse_comma_separated(Parser::parse_grant_permission);
+            self.options.trailing_commas = old_value;
+
+            let (actions, err): (Vec<_>, Vec<_>) = permissions?
                 .into_iter()
@@ -9463,6 +9476,14 @@ impl<'a> Parser<'a> {
             Expr::Wildcard => Ok(SelectItem::Wildcard(
                 self.parse_wildcard_additional_options()?,
             )),
+            // An unquoted `from` in item position means the item itself is missing
+            // (e.g. `SELECT a, FROM t`); a quoted identifier named "from" is a valid item.
+            Expr::Identifier(v) if v.value.to_lowercase() == "from" && v.quote_style.is_none() => {
+                parser_err!(
+                    format!("Expected an expression, found: {}", v),
+                    self.peek_token().location
+                )
+            }
             expr => self
                 .parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)
                 .map(|alias| match alias {
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index 9baf380a0..2635b29d6 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -3548,8 +3548,13 @@ fn parse_create_table_clone() {
 
 #[test]
 fn parse_create_table_trailing_comma() {
-    let sql = "CREATE TABLE foo (bar int,)";
-    all_dialects().one_statement_parses_to(sql, "CREATE TABLE foo (bar INT)");
+    let dialect = TestedDialects {
+        dialects: vec![Box::new(DuckDbDialect {})],
+        options: None,
+    };
+
+    let sql = "CREATE TABLE foo (bar int,);";
+    dialect.one_statement_parses_to(sql, "CREATE TABLE foo (bar INT)");
 }
 
 #[test]
@@ -4414,7 +4419,7 @@ fn parse_window_clause() {
     ORDER BY C3";
     verified_only_select(sql);
 
-    let sql = "SELECT from mytable WINDOW window1 AS window2";
+    let sql = "SELECT * from mytable WINDOW window1 AS window2";
     let dialects = all_dialects_except(|d| d.is::<BigQueryDialect>() || d.is::<GenericDialect>());
     let res = dialects.parse_sql_statements(sql);
     assert_eq!(
@@ -8746,9 +8751,11 @@ fn parse_non_latin_identifiers() {
 
 #[test]
 fn parse_trailing_comma() {
+    // At the moment, Duck DB is the only dialect that allows
+    // trailing commas anywhere in the query
     let trailing_commas = TestedDialects {
-        dialects: vec![Box::new(GenericDialect {})],
-        options: Some(ParserOptions::new().with_trailing_commas(true)),
+        dialects: vec![Box::new(DuckDbDialect {})],
+        options: None,
     };
 
     trailing_commas.one_statement_parses_to(
@@ -8766,11 +8773,77 @@ fn parse_trailing_comma() {
         "SELECT DISTINCT ON (album_id) name FROM track",
     );
 
+    trailing_commas.one_statement_parses_to(
+        "CREATE TABLE employees (name text, age int,)",
+        "CREATE TABLE employees (name TEXT, age INT)",
+    );
+
     trailing_commas.verified_stmt("SELECT album_id, name FROM track");
 
     trailing_commas.verified_stmt("SELECT * FROM track ORDER BY milliseconds");
 
     trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track");
+
+    // doesn't allow any trailing commas
+    let trailing_commas = TestedDialects {
+        dialects: vec![Box::new(GenericDialect {})],
+        options: None,
+    };
+
+    assert_eq!(
+        trailing_commas
+            .parse_sql_statements("SELECT name, age, from employees;")
+            .unwrap_err(),
+        ParserError::ParserError("Expected an expression, found: from".to_string())
+    );
+
+    assert_eq!(
+        trailing_commas
+            .parse_sql_statements("CREATE TABLE employees (name text, age int,)")
+            .unwrap_err(),
+        ParserError::ParserError(
+            "Expected column name or constraint definition, found: )".to_string()
+        )
+    );
+
+    // A quoted identifier named "from" is still a valid projection item.
+    trailing_commas.verified_stmt(r#"SELECT "from" FROM "from""#);
+}
+
+#[test]
+fn parse_projection_trailing_comma() {
+    // Some dialects allow trailing commas only in the projection
+    let trailing_commas = TestedDialects {
+        dialects: vec![Box::new(SnowflakeDialect {}), Box::new(BigQueryDialect {})],
+        options: None,
+    };
+
+    trailing_commas.one_statement_parses_to(
+        "SELECT album_id, name, FROM track",
+        "SELECT album_id, name FROM track",
+    );
+
+    trailing_commas.verified_stmt("SELECT album_id, name FROM track");
+
+    trailing_commas.verified_stmt("SELECT * FROM track ORDER BY milliseconds");
+
+    trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track");
+
+    assert_eq!(
+        trailing_commas
+            .parse_sql_statements("SELECT * FROM track ORDER BY milliseconds,")
+            .unwrap_err(),
+        ParserError::ParserError("Expected an expression:, found: EOF".to_string())
+    );
+
+    assert_eq!(
+        trailing_commas
+            .parse_sql_statements("CREATE TABLE employees (name text, age int,)")
+            .unwrap_err(),
+        ParserError::ParserError(
+            "Expected column name or constraint definition, found: )".to_string()
+        ),
+    );
 }
 
 #[test]
diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs
index 5701f6b2b..9c435e04e 100644
--- a/tests/sqlparser_postgres.rs
+++ b/tests/sqlparser_postgres.rs
@@ -3561,7 +3561,7 @@ fn parse_create_table_with_alias() {
               int2_col INT2,
               float8_col FLOAT8,
               float4_col FLOAT4,
-              bool_col BOOL,
+              bool_col BOOL
              );";
     match pg_and_generic().one_statement_parses_to(sql, "") {
         Statement::CreateTable {