From 945e4acfc184b4160990755d9e51d61df7f806cf Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 5 Mar 2025 01:08:44 +0200 Subject: [PATCH 01/11] minor refactoring of parse_set --- src/dialect/hive.rs | 4 + src/dialect/mod.rs | 16 ++++ src/dialect/mssql.rs | 5 ++ src/dialect/mysql.rs | 4 + src/keywords.rs | 2 + src/parser/mod.rs | 166 +++++++++++++++++++----------------- t.sql | 1 + tests/sqlparser_common.rs | 43 ++++++---- tests/sqlparser_hive.rs | 2 +- tests/sqlparser_postgres.rs | 14 +++ 10 files changed, 159 insertions(+), 98 deletions(-) create mode 100644 t.sql diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index 3e15d395b..4e838e27f 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -44,6 +44,10 @@ impl Dialect for HiveDialect { true } + fn supports_set_multiple_values(&self) -> bool { + true + } + fn supports_numeric_prefix(&self) -> bool { true } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index aeb097cfd..5322de732 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -352,6 +352,12 @@ pub trait Dialect: Debug + Any { false } + /// Returns true if the dialect supports multiple values in a SET expression + /// e.g. `SET OFFSETS SELECT, FROM, ORDER, TABLE, PROCEDURE, EXECUTE ON` + fn supports_set_multiple_values(&self) -> bool { + false + } + /// Returns true if the dialects supports specifying null treatment /// as part of a window function's parameter list as opposed /// to after the parameter list. @@ -399,6 +405,16 @@ pub trait Dialect: Debug + Any { false } + /// Returns true if the dialect supports multiple `SET` statements + /// in a single statement. + /// + /// ```sql + /// SET variable = expression [, variable = expression]; + /// ``` + fn supports_comma_separated_set_assignments(&self) -> bool { + false + } + /// Returns true if the dialect supports an `EXCEPT` clause following a /// wildcard in a select list. /// diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index aeed1eb79..ba4d78826 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -58,6 +58,10 @@ impl Dialect for MsSqlDialect { true } + fn supports_set_multiple_values(&self) -> bool { + true + } + fn supports_try_convert(&self) -> bool { true } @@ -82,6 +86,7 @@ impl Dialect for MsSqlDialect { fn supports_start_transaction_modifier(&self) -> bool { true } + fn supports_end_transaction_modifier(&self) -> bool { true } diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 0bdfc9bf3..2077ea195 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -141,6 +141,10 @@ impl Dialect for MySqlDialect { fn supports_set_names(&self) -> bool { true } + + fn supports_comma_separated_set_assignments(&self) -> bool { + true + } } /// `LOCK TABLES` diff --git a/src/keywords.rs b/src/keywords.rs index bda817df9..f40ec6d36 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -173,6 +173,7 @@ define_keywords!( CHANNEL, CHAR, CHARACTER, + CHARACTERISTIC, CHARACTERS, CHARACTER_LENGTH, CHARSET, @@ -557,6 +558,7 @@ define_keywords!( MULTISET, MUTATION, NAME, + NAMES, NANOSECOND, NANOSECONDS, NATIONAL, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index b34415388..4e9cd3d43 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10961,6 +10961,37 @@ impl<'a> Parser<'a> { }) } + fn parse_set_values( + &mut self, + parenthesized_assignment: bool, + ) -> Result, ParserError> { + let mut values = vec![]; + + if parenthesized_assignment { + self.expect_token(&Token::LParen)?; + } + + loop { + let value = if let Some(expr) = self.try_parse_expr_sub_query()? { + expr + } else if let Ok(expr) = self.parse_expr() { + expr + } else { + self.expected("variable value", self.peek_token())? + }; + + values.push(value); + if self.consume_token(&Token::Comma) { + continue; + } + + if parenthesized_assignment { + self.expect_token(&Token::RParen)?; + } + return Ok(values); + } + } + pub fn parse_set(&mut self) -> Result { let modifier = self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]); @@ -10989,99 +11020,76 @@ impl<'a> Parser<'a> { OneOrManyWithParens::One(self.parse_object_name(false)?) }; - let names = matches!(&variables, OneOrManyWithParens::One(variable) if variable.to_string().eq_ignore_ascii_case("NAMES")); - - if names && self.dialect.supports_set_names() { - if self.parse_keyword(Keyword::DEFAULT) { - return Ok(Statement::SetNamesDefault {}); - } - - let charset_name = self.parse_identifier()?; - let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() { - Some(self.parse_literal_string()?) - } else { - None - }; - - return Ok(Statement::SetNames { - charset_name, - collation_name, + if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { + let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_)); + let values = self.parse_set_values(parenthesized_assignment); + + return Ok(Statement::SetVariable { + local: modifier == Some(Keyword::LOCAL), + hivevar: modifier == Some(Keyword::HIVEVAR), + variables, + value: values, }); } - let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_)); - - if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { - if parenthesized_assignment { - self.expect_token(&Token::LParen)?; - } + let OneOrManyWithParens::One(variable) = variables else { + return self.expected("set variable", self.peek_token()); + }; - let mut values = vec![]; - loop { - let value = if let Some(expr) = self.try_parse_expr_sub_query()? { - expr - } else if let Ok(expr) = self.parse_expr() { - expr + match variable.to_string().to_ascii_uppercase().as_str() { + "NAMES" if self.dialect.supports_set_names() => { + if self.parse_keyword(Keyword::DEFAULT) { + return Ok(Statement::SetNamesDefault {}); + } + let charset_name = self.parse_identifier()?; + let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() { + Some(self.parse_literal_string()?) } else { - self.expected("variable value", self.peek_token())? + None }; - values.push(value); - if self.consume_token(&Token::Comma) { - continue; - } - - if parenthesized_assignment { - self.expect_token(&Token::RParen)?; + return Ok(Statement::SetNames { + charset_name, + collation_name, + }); + } + "TIMEZONE" => match self.parse_expr() { + Ok(expr) => { + return Ok(Statement::SetTimeZone { + local: modifier == Some(Keyword::LOCAL), + value: expr, + }) } - return Ok(Statement::SetVariable { - local: modifier == Some(Keyword::LOCAL), - hivevar: Some(Keyword::HIVEVAR) == modifier, - variables, - value: values, + _ => return self.expected("timezone value", self.peek_token()), + }, + "CHARACTERISTICS" => { + self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?; + return Ok(Statement::SetTransaction { + modes: self.parse_transaction_modes()?, + snapshot: None, + session: true, }); } - } - - let OneOrManyWithParens::One(variable) = variables else { - return self.expected("set variable", self.peek_token()); - }; - - if variable.to_string().eq_ignore_ascii_case("TIMEZONE") { - // for some db (e.g. postgresql), SET TIME ZONE is an alias for SET TIMEZONE [TO|=] - match self.parse_expr() { - Ok(expr) => Ok(Statement::SetTimeZone { - local: modifier == Some(Keyword::LOCAL), - value: expr, - }), - _ => self.expected("timezone value", self.peek_token())?, - } - } else if variable.to_string() == "CHARACTERISTICS" { - self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?; - Ok(Statement::SetTransaction { - modes: self.parse_transaction_modes()?, - snapshot: None, - session: true, - }) - } else if variable.to_string() == "TRANSACTION" && modifier.is_none() { - if self.parse_keyword(Keyword::SNAPSHOT) { - let snapshot_id = self.parse_value()?.value; + "TRANSACTION" if modifier.is_none() => { + if self.parse_keyword(Keyword::SNAPSHOT) { + let snapshot_id = self.parse_value()?.value; + return Ok(Statement::SetTransaction { + modes: vec![], + snapshot: Some(snapshot_id), + session: false, + }); + } return Ok(Statement::SetTransaction { - modes: vec![], - snapshot: Some(snapshot_id), + modes: self.parse_transaction_modes()?, + snapshot: None, session: false, }); } - Ok(Statement::SetTransaction { - modes: self.parse_transaction_modes()?, - snapshot: None, - session: false, - }) - } else if self.dialect.supports_set_stmt_without_operator() { - self.prev_token(); - self.parse_set_session_params() - } else { - self.expected("equals sign or TO", self.peek_token()) + _ if self.dialect.supports_set_stmt_without_operator() => { + self.prev_token(); + return self.parse_set_session_params(); + } + _ => return self.expected("equals sign or TO", self.peek_token()), } } diff --git a/t.sql b/t.sql new file mode 100644 index 000000000..a822673c4 --- /dev/null +++ b/t.sql @@ -0,0 +1 @@ +SET TIME ZONE TO 'UTC' diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index a8ccd70a7..218185f03 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8618,10 +8618,10 @@ fn parse_set_variable() { "SET (a) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)))", "SET (a) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)))", ), - ( - "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), SELECT 33 FROM tbl3)", - "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), (SELECT 33 FROM tbl3))", - ), + // ( + // "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), SELECT 33 FROM tbl3)", + // "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), (SELECT 33 FROM tbl3))", + // ), ] { multi_variable_dialects.one_statement_parses_to(sql, canonical); } @@ -8728,20 +8728,6 @@ fn parse_set_time_zone() { one_statement_parses_to("SET TIME ZONE TO 'UTC'", "SET TIMEZONE = 'UTC'"); } -#[test] -fn parse_set_time_zone_alias() { - match verified_stmt("SET TIME ZONE 'UTC'") { - Statement::SetTimeZone { local, value } => { - assert!(!local); - assert_eq!( - value, - Expr::Value((Value::SingleQuotedString("UTC".into())).with_empty_span()) - ); - } - _ => unreachable!(), - } -} - #[test] fn parse_commit() { match verified_stmt("COMMIT") { @@ -14654,3 +14640,24 @@ fn parse_set_names() { dialects.verified_stmt("SET NAMES 'utf8'"); dialects.verified_stmt("SET NAMES UTF8 COLLATE bogus"); } + +#[test] +fn parse_multiple_set_statements() -> Result<(), ParserError> { + let dialects = all_dialects_where(|d| d.supports_comma_separated_set_assignments()); + let stmt = dialects.parse_sql_statements("SET @a = 1, b = 2")?; + + let stmt = stmt[0].clone(); + + assert!(matches!(stmt, Statement::SetVariable { .. })); + match stmt { + Statement::SetVariable { + variables, value, .. + } => { + assert_eq!(variables.len(), 2); + assert_eq!(value.len(), 2); + } + _ => assert!(false, "Expected SetVariable with 2 variables and 2 values"), + }; + + Ok(()) +} diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index d7f3c014b..27fa4b405 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -92,7 +92,7 @@ fn parse_msck() { } #[test] -fn parse_set() { +fn parse_set_hivevar() { let set = "SET HIVEVAR:name = a, b, c_d"; hive().verified_stmt(set); } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 0dfcc24ea..659ed9b01 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -5638,6 +5638,20 @@ fn parse_create_type_as_enum() { } } +#[test] +fn parse_set_time_zone_alias() { + match pg().verified_stmt("SET TIME ZONE 'UTC'") { + Statement::SetTimeZone { local, value } => { + assert!(!local); + assert_eq!( + value, + Expr::Value((Value::SingleQuotedString("UTC".into())).with_empty_span()) + ); + } + _ => unreachable!(), + } +} + #[test] fn parse_alter_type() { struct TestCase { From 421381ccf3df9a0ea9cbbb0147b961a787cdb938 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Thu, 6 Mar 2025 22:44:04 +0200 Subject: [PATCH 02/11] setting multiple variables at once (MySQL) --- src/dialect/hive.rs | 4 --- src/dialect/mod.rs | 6 ---- src/dialect/mssql.rs | 4 --- src/keywords.rs | 2 -- src/parser/mod.rs | 59 +++++++++++++++++++++++++++++++++++---- t.sql | 1 - tests/sqlparser_common.rs | 10 +++---- 7 files changed, 59 insertions(+), 27 deletions(-) delete mode 100644 t.sql diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index 4e838e27f..3e15d395b 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -44,10 +44,6 @@ impl Dialect for HiveDialect { true } - fn supports_set_multiple_values(&self) -> bool { - true - } - fn supports_numeric_prefix(&self) -> bool { true } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 5322de732..8d4557e2f 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -352,12 +352,6 @@ pub trait Dialect: Debug + Any { false } - /// Returns true if the dialect supports multiple values in a SET expression - /// e.g. `SET OFFSETS SELECT, FROM, ORDER, TABLE, PROCEDURE, EXECUTE ON` - fn supports_set_multiple_values(&self) -> bool { - false - } - /// Returns true if the dialects supports specifying null treatment /// as part of a window function's parameter list as opposed /// to after the parameter list. diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index ba4d78826..3db34748e 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -58,10 +58,6 @@ impl Dialect for MsSqlDialect { true } - fn supports_set_multiple_values(&self) -> bool { - true - } - fn supports_try_convert(&self) -> bool { true } diff --git a/src/keywords.rs b/src/keywords.rs index f40ec6d36..bda817df9 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -173,7 +173,6 @@ define_keywords!( CHANNEL, CHAR, CHARACTER, - CHARACTERISTIC, CHARACTERS, CHARACTER_LENGTH, CHARSET, @@ -558,7 +557,6 @@ define_keywords!( MULTISET, MUTATION, NAME, - NAMES, NANOSECOND, NANOSECONDS, NATIONAL, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 4e9cd3d43..094329e97 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10992,17 +10992,66 @@ impl<'a> Parser<'a> { } } - pub fn parse_set(&mut self) -> Result { + fn parse_set_assignment( + &mut self, + ) -> Result<(OneOrManyWithParens, Expr), ParserError> { + let variables = if self.dialect.supports_parenthesized_set_variables() + && self.consume_token(&Token::LParen) + { + let vars = OneOrManyWithParens::Many( + self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())? + .into_iter() + .map(|ident| ObjectName::from(vec![ident])) + .collect(), + ); + self.expect_token(&Token::RParen)?; + vars + } else { + OneOrManyWithParens::One(self.parse_object_name(false)?) + }; + + if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) { + return self.expected("assignment operator", self.peek_token()); + } + + let values = self.parse_expr()?; + + Ok((variables, values)) + } + + fn parse_set(&mut self) -> Result { let modifier = self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]); + if let Some(Keyword::HIVEVAR) = modifier { self.expect_token(&Token::Colon)?; - } else if let Some(set_role_stmt) = - self.maybe_parse(|parser| parser.parse_set_role(modifier))? - { + } + + if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(modifier))? { return Ok(set_role_stmt); } + if self.dialect.supports_comma_separated_set_assignments() { + if let Ok(v) = self + .try_parse(|parser| Ok(parser.parse_comma_separated(Parser::parse_set_assignment)?)) + { + let (variables, values): (Vec<_>, Vec<_>) = v.into_iter().unzip(); + + let variables = if variables.len() == 1 { + variables.into_iter().next().unwrap() + } else { + OneOrManyWithParens::Many(variables.into_iter().flatten().map(|v| v).collect()) + }; + + return Ok(Statement::SetVariable { + local: modifier == Some(Keyword::LOCAL), + hivevar: modifier == Some(Keyword::HIVEVAR), + variables, + value: values, + }); + } + } + let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) { OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()])) } else if self.dialect.supports_parenthesized_set_variables() @@ -11022,7 +11071,7 @@ impl<'a> Parser<'a> { if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_)); - let values = self.parse_set_values(parenthesized_assignment); + let values = self.parse_set_values(parenthesized_assignment)?; return Ok(Statement::SetVariable { local: modifier == Some(Keyword::LOCAL), diff --git a/t.sql b/t.sql deleted file mode 100644 index a822673c4..000000000 --- a/t.sql +++ /dev/null @@ -1 +0,0 @@ -SET TIME ZONE TO 'UTC' diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 218185f03..8c0622d63 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8618,10 +8618,10 @@ fn parse_set_variable() { "SET (a) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)))", "SET (a) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)))", ), - // ( - // "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), SELECT 33 FROM tbl3)", - // "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), (SELECT 33 FROM tbl3))", - // ), + ( + "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), SELECT 33 FROM tbl3)", + "SET (a, b) = ((SELECT 22 FROM tbl1, (SELECT 1 FROM tbl2)), (SELECT 33 FROM tbl3))", + ), ] { multi_variable_dialects.one_statement_parses_to(sql, canonical); } @@ -14653,8 +14653,8 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> { Statement::SetVariable { variables, value, .. } => { - assert_eq!(variables.len(), 2); assert_eq!(value.len(), 2); + assert_eq!(variables.len(), 2); } _ => assert!(false, "Expected SetVariable with 2 variables and 2 values"), }; From 13f73d16d300ea68c58311f9e2e18b3fe1e5bd2b Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Fri, 7 Mar 2025 00:13:05 +0200 Subject: [PATCH 03/11] fix clippy --- src/parser/mod.rs | 24 ++++++++++++------------ tests/sqlparser_common.rs | 3 +-- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 094329e97..863b9429d 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11033,14 +11033,14 @@ impl<'a> Parser<'a> { if self.dialect.supports_comma_separated_set_assignments() { if let Ok(v) = self - .try_parse(|parser| Ok(parser.parse_comma_separated(Parser::parse_set_assignment)?)) + .try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment)) { let (variables, values): (Vec<_>, Vec<_>) = v.into_iter().unzip(); let variables = if variables.len() == 1 { variables.into_iter().next().unwrap() } else { - OneOrManyWithParens::Many(variables.into_iter().flatten().map(|v| v).collect()) + OneOrManyWithParens::Many(variables.into_iter().flatten().collect()) }; return Ok(Statement::SetVariable { @@ -11097,27 +11097,27 @@ impl<'a> Parser<'a> { None }; - return Ok(Statement::SetNames { + Ok(Statement::SetNames { charset_name, collation_name, - }); + }) } "TIMEZONE" => match self.parse_expr() { Ok(expr) => { - return Ok(Statement::SetTimeZone { + Ok(Statement::SetTimeZone { local: modifier == Some(Keyword::LOCAL), value: expr, }) } - _ => return self.expected("timezone value", self.peek_token()), + _ => self.expected("timezone value", self.peek_token()), }, "CHARACTERISTICS" => { self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?; - return Ok(Statement::SetTransaction { + Ok(Statement::SetTransaction { modes: self.parse_transaction_modes()?, snapshot: None, session: true, - }); + }) } "TRANSACTION" if modifier.is_none() => { if self.parse_keyword(Keyword::SNAPSHOT) { @@ -11128,17 +11128,17 @@ impl<'a> Parser<'a> { session: false, }); } - return Ok(Statement::SetTransaction { + Ok(Statement::SetTransaction { modes: self.parse_transaction_modes()?, snapshot: None, session: false, - }); + }) } _ if self.dialect.supports_set_stmt_without_operator() => { self.prev_token(); - return self.parse_set_session_params(); + self.parse_set_session_params() } - _ => return self.expected("equals sign or TO", self.peek_token()), + _ => self.expected("equals sign or TO", self.peek_token()), } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 8c0622d63..2358fc64c 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14648,7 +14648,6 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> { let stmt = stmt[0].clone(); - assert!(matches!(stmt, Statement::SetVariable { .. })); match stmt { Statement::SetVariable { variables, value, .. @@ -14656,7 +14655,7 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> { assert_eq!(value.len(), 2); assert_eq!(variables.len(), 2); } - _ => assert!(false, "Expected SetVariable with 2 variables and 2 values"), + _ => panic!("Expected SetVariable with 2 variables and 2 values"), }; Ok(()) From a68f8d09223fe24df9ae9afe51c0b847cedc6b5d Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Fri, 7 Mar 2025 00:14:48 +0200 Subject: [PATCH 04/11] fix cargo fmt --- src/parser/mod.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 863b9429d..976d808fd 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11032,8 +11032,8 @@ impl<'a> Parser<'a> { } if self.dialect.supports_comma_separated_set_assignments() { - if let Ok(v) = self - .try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment)) + if let Ok(v) = + self.try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment)) { let (variables, values): (Vec<_>, Vec<_>) = v.into_iter().unzip(); @@ -11103,12 +11103,10 @@ impl<'a> Parser<'a> { }) } "TIMEZONE" => match self.parse_expr() { - Ok(expr) => { - Ok(Statement::SetTimeZone { - local: modifier == Some(Keyword::LOCAL), - value: expr, - }) - } + Ok(expr) => Ok(Statement::SetTimeZone { + local: modifier == Some(Keyword::LOCAL), + value: expr, + }), _ => self.expected("timezone value", self.peek_token()), }, "CHARACTERISTICS" => { From 221c4cf1b4de749ee36be73c85a070efa85fcfbe Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Fri, 7 Mar 2025 21:50:03 +0200 Subject: [PATCH 05/11] refactor, move special set cases to the top --- src/keywords.rs | 2 + src/parser/mod.rs | 123 ++++++++++++++++++++++++---------------------- 2 files changed, 65 insertions(+), 60 deletions(-) diff --git a/src/keywords.rs b/src/keywords.rs index bda817df9..195bbb172 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -173,6 +173,7 @@ define_keywords!( CHANNEL, CHAR, CHARACTER, + CHARACTERISTICS, CHARACTERS, CHARACTER_LENGTH, CHARSET, @@ -557,6 +558,7 @@ define_keywords!( MULTISET, MUTATION, NAME, + NAMES, NANOSECOND, NANOSECONDS, NATIONAL, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 976d808fd..b81617cb7 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11031,6 +11031,62 @@ impl<'a> Parser<'a> { return Ok(set_role_stmt); } + // Handle special cases first + if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) + || self.parse_keyword(Keyword::TIMEZONE) + { + if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { + return Ok(Statement::SetVariable { + local: modifier == Some(Keyword::LOCAL), + hivevar: modifier == Some(Keyword::HIVEVAR), + variables: OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()])), + value: self.parse_set_values(false)?, + }); + } + + // Special case for Postgres + return Ok(Statement::SetTimeZone { + local: modifier == Some(Keyword::LOCAL), + value: self.parse_expr()?, + }); + } else if self.dialect.supports_set_names() && self.parse_keyword(Keyword::NAMES) { + if self.parse_keyword(Keyword::DEFAULT) { + return Ok(Statement::SetNamesDefault {}); + } + let charset_name = self.parse_identifier()?; + let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() { + Some(self.parse_literal_string()?) + } else { + None + }; + + return Ok(Statement::SetNames { + charset_name, + collation_name, + }); + } else if self.parse_keyword(Keyword::CHARACTERISTICS) { + self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?; + return Ok(Statement::SetTransaction { + modes: self.parse_transaction_modes()?, + snapshot: None, + session: true, + }); + } else if self.parse_keyword(Keyword::TRANSACTION) { + if self.parse_keyword(Keyword::SNAPSHOT) { + let snapshot_id = self.parse_value()?.value; + return Ok(Statement::SetTransaction { + modes: vec![], + snapshot: Some(snapshot_id), + session: false, + }); + } + return Ok(Statement::SetTransaction { + modes: self.parse_transaction_modes()?, + snapshot: None, + session: false, + }); + } + if self.dialect.supports_comma_separated_set_assignments() { if let Ok(v) = self.try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment)) @@ -11052,19 +11108,17 @@ impl<'a> Parser<'a> { } } - let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) { - OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()])) - } else if self.dialect.supports_parenthesized_set_variables() + let variables = if self.dialect.supports_parenthesized_set_variables() && self.consume_token(&Token::LParen) { - let variables = OneOrManyWithParens::Many( + let vars = OneOrManyWithParens::Many( self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())? .into_iter() .map(|ident| ObjectName::from(vec![ident])) .collect(), ); self.expect_token(&Token::RParen)?; - variables + vars } else { OneOrManyWithParens::One(self.parse_object_name(false)?) }; @@ -11081,63 +11135,12 @@ impl<'a> Parser<'a> { }); } - let OneOrManyWithParens::One(variable) = variables else { - return self.expected("set variable", self.peek_token()); + if self.dialect.supports_set_stmt_without_operator() { + self.prev_token(); + return self.parse_set_session_params(); }; - match variable.to_string().to_ascii_uppercase().as_str() { - "NAMES" if self.dialect.supports_set_names() => { - if self.parse_keyword(Keyword::DEFAULT) { - return Ok(Statement::SetNamesDefault {}); - } - let charset_name = self.parse_identifier()?; - let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() { - Some(self.parse_literal_string()?) - } else { - None - }; - - Ok(Statement::SetNames { - charset_name, - collation_name, - }) - } - "TIMEZONE" => match self.parse_expr() { - Ok(expr) => Ok(Statement::SetTimeZone { - local: modifier == Some(Keyword::LOCAL), - value: expr, - }), - _ => self.expected("timezone value", self.peek_token()), - }, - "CHARACTERISTICS" => { - self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?; - Ok(Statement::SetTransaction { - modes: self.parse_transaction_modes()?, - snapshot: None, - session: true, - }) - } - "TRANSACTION" if modifier.is_none() => { - if self.parse_keyword(Keyword::SNAPSHOT) { - let snapshot_id = self.parse_value()?.value; - return Ok(Statement::SetTransaction { - modes: vec![], - snapshot: Some(snapshot_id), - session: false, - }); - } - Ok(Statement::SetTransaction { - modes: self.parse_transaction_modes()?, - snapshot: None, - session: false, - }) - } - _ if self.dialect.supports_set_stmt_without_operator() => { - self.prev_token(); - self.parse_set_session_params() - } - _ => self.expected("equals sign or TO", self.peek_token()), - } + self.expected("equals sign or TO", self.peek_token()) } pub fn parse_set_session_params(&mut self) -> Result { From 36f54abb3ddb4938e70d481984a42a05f74ee84d Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Fri, 7 Mar 2025 23:06:10 +0200 Subject: [PATCH 06/11] enum to represent comma-sep list of SETs --- src/ast/mod.rs | 21 +++++++++++++++++++++ src/ast/spans.rs | 1 + src/parser/mod.rs | 33 ++++++++++++++++++++++----------- tests/sqlparser_common.rs | 10 +++------- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index e5e4aef05..5f15357d5 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2947,6 +2947,17 @@ pub enum Statement { variables: OneOrManyWithParens, value: Vec, }, + + /// ```sql + /// SET = expression [, = expression]*; + /// ``` + /// + /// Note: this is a MySQL-specific statement. + /// Refer to [`Dialect.supports_comma_separated_set_assignments`] + SetVariables { + variables: Vec, + values: Vec, + }, /// ```sql /// SET TIME ZONE /// ``` @@ -5334,6 +5345,16 @@ impl fmt::Display for Statement { Statement::List(command) => write!(f, "LIST {command}"), Statement::Remove(command) => write!(f, "REMOVE {command}"), Statement::SetSessionParam(kind) => write!(f, "SET {kind}"), + Statement::SetVariables { variables, values } => { + write!(f, "SET ")?; + variables + .iter() + .zip(values.iter()) + .map(|(var, val)| format!("{var} = {val}")) + .collect::>() + .join(", ") + .fmt(f) + } } } } diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 0a64fb8ea..6db630742 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -509,6 +509,7 @@ impl Spanned for Statement { Statement::RaisError { .. } => Span::empty(), Statement::List(..) | Statement::Remove(..) => Span::empty(), Statement::SetSessionParam { .. } => Span::empty(), + Statement::SetVariables { .. } => Span::empty(), } } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index b81617cb7..5f0f058df 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11091,20 +11091,31 @@ impl<'a> Parser<'a> { if let Ok(v) = self.try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment)) { - let (variables, values): (Vec<_>, Vec<_>) = v.into_iter().unzip(); + let (vars, values): (Vec<_>, Vec<_>) = v.into_iter().unzip(); + + return if vars.len() > 1 { + let variables = vars + .into_iter() + .map(|v| match v { + OneOrManyWithParens::One(v) => Ok(v), + _ => self.expected("List of single identifiers", self.peek_token()), + }) + .collect::>()?; - let variables = if variables.len() == 1 { - variables.into_iter().next().unwrap() + Ok(Statement::SetVariables { variables, values }) } else { - OneOrManyWithParens::Many(variables.into_iter().flatten().collect()) + let variable = match vars.into_iter().next() { + Some(v) => Ok(v), + None => self.expected("At least one identifier", self.peek_token()), + }?; + + Ok(Statement::SetVariable { + local: modifier == Some(Keyword::LOCAL), + hivevar: modifier == Some(Keyword::HIVEVAR), + variables: variable, + value: values, + }) }; - - return Ok(Statement::SetVariable { - local: modifier == Some(Keyword::LOCAL), - hivevar: modifier == Some(Keyword::HIVEVAR), - variables, - value: values, - }); } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 2358fc64c..cfb7e8929 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14644,15 +14644,11 @@ fn parse_set_names() { #[test] fn parse_multiple_set_statements() -> Result<(), ParserError> { let dialects = all_dialects_where(|d| d.supports_comma_separated_set_assignments()); - let stmt = dialects.parse_sql_statements("SET @a = 1, b = 2")?; - - let stmt = stmt[0].clone(); + let stmt = dialects.verified_stmt("SET @a = 1, b = 2"); match stmt { - Statement::SetVariable { - variables, value, .. - } => { - assert_eq!(value.len(), 2); + Statement::SetVariables { variables, values } => { + assert_eq!(values.len(), 2); assert_eq!(variables.len(), 2); } _ => panic!("Expected SetVariable with 2 variables and 2 values"), From 161658ac0eabf7889959c4409607ba49cda1e1e2 Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Fri, 7 Mar 2025 23:44:18 +0200 Subject: [PATCH 07/11] set time zone shorthand --- src/parser/mod.rs | 14 ++++++++------ tests/sqlparser_common.rs | 8 ++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 5f0f058df..52778ebcb 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11042,13 +11042,15 @@ impl<'a> Parser<'a> { variables: OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()])), value: self.parse_set_values(false)?, }); + } else if self.dialect.is::() { + // Special case for Postgres + return Ok(Statement::SetTimeZone { + local: modifier == Some(Keyword::LOCAL), + value: self.parse_expr()?, + }); + } else { + return self.expected("assignment operator", self.peek_token()); } - - // Special case for Postgres - return Ok(Statement::SetTimeZone { - local: modifier == Some(Keyword::LOCAL), - value: self.parse_expr()?, - }); } else if self.dialect.supports_set_names() && self.parse_keyword(Keyword::NAMES) { if self.parse_keyword(Keyword::DEFAULT) { return Ok(Statement::SetNamesDefault {}); diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index cfb7e8929..6c9e6d481 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14656,3 +14656,11 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> { Ok(()) } + +#[test] +fn parse_set_time_zone_alias() { + // not sure what other dialects support this + all_dialects_but_pg() + .parse_sql_statements("SET TIME ZONE 'UTC'") + .unwrap_err(); +} From 88e56665295782dda3a5376eaccd04f31f9789f0 Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Sat, 8 Mar 2025 00:39:14 +0200 Subject: [PATCH 08/11] use display_comma_separated --- src/ast/mod.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 5f15357d5..028ab6c74 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -5345,16 +5345,17 @@ impl fmt::Display for Statement { Statement::List(command) => write!(f, "LIST {command}"), Statement::Remove(command) => write!(f, "REMOVE {command}"), Statement::SetSessionParam(kind) => write!(f, "SET {kind}"), - Statement::SetVariables { variables, values } => { - write!(f, "SET ")?; - variables - .iter() - .zip(values.iter()) - .map(|(var, val)| format!("{var} = {val}")) - .collect::>() - .join(", ") - .fmt(f) - } + Statement::SetVariables { variables, values } => write!( + f, + "SET {}", + display_comma_separated( + &variables + .iter() + .zip(values.iter()) + .map(|(var, val)| format!("{var} = {val}")) + .collect::>() + ) + ), } } } From 4ac5b00c1698b4c798536d663662896e08cfca58 Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Mon, 10 Mar 2025 22:51:19 +0200 Subject: [PATCH 09/11] Extract Set statements into their own enum --- src/ast/dml.rs | 14 +- src/ast/mod.rs | 394 ++++++++++++++++++++---------------- src/ast/spans.rs | 48 ++--- src/parser/mod.rs | 140 ++++++++----- tests/sqlparser_bigquery.rs | 4 +- tests/sqlparser_common.rs | 95 ++++----- tests/sqlparser_hive.rs | 15 +- tests/sqlparser_mssql.rs | 8 +- tests/sqlparser_mysql.rs | 34 ++-- tests/sqlparser_postgres.rs | 94 ++++----- tests/sqlparser_sqlite.rs | 2 +- 11 files changed, 454 insertions(+), 394 deletions(-) diff --git a/src/ast/dml.rs b/src/ast/dml.rs index ccea7fbcb..bc3a9546d 100644 --- a/src/ast/dml.rs +++ b/src/ast/dml.rs @@ -32,12 +32,12 @@ use sqlparser_derive::{Visit, VisitMut}; pub use super::ddl::{ColumnDef, TableConstraint}; use super::{ - display_comma_separated, display_separated, query::InputFormatClause, Assignment, ClusteredBy, - CommentDef, Expr, FileFormat, FromTable, HiveDistributionStyle, HiveFormat, HiveIOFormat, - HiveRowFormat, Ident, IndexType, InsertAliases, MysqlInsertPriority, ObjectName, OnCommit, - OnInsert, OneOrManyWithParens, OrderByExpr, Query, RowAccessPolicy, SelectItem, Setting, - SqlOption, SqliteOnConflict, StorageSerializationPolicy, TableEngine, TableObject, - TableWithJoins, Tag, WrappedCollection, + display_comma_separated, display_separated, query::InputFormatClause, ClusteredBy, CommentDef, + Expr, FileFormat, FromTable, HiveDistributionStyle, HiveFormat, HiveIOFormat, HiveRowFormat, + Ident, IndexType, InsertAliases, MysqlInsertPriority, ObjectName, OnCommit, OnInsert, + OneOrManyWithParens, OrderByExpr, Query, RowAccessPolicy, SelectItem, Setting, SqlOption, + SqliteOnConflict, StorageSerializationPolicy, TableEngine, TableObject, TableWithJoins, Tag, + UpdateAssignment, WrappedCollection, }; /// Index column type. @@ -544,7 +544,7 @@ pub struct Insert { pub source: Option>, /// MySQL `INSERT INTO ... SET` /// See: - pub assignments: Vec, + pub assignments: Vec, /// partitioned insert (Hive) pub partitioned: Option>, /// Columns defined after PARTITION diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 028ab6c74..ab08961fb 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2394,6 +2394,167 @@ pub enum CreatePolicyCommand { Delete, } +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum Set { + /// SQL Standard-style + /// SET a = 1; + SingleAssignment { + local: bool, + hivevar: bool, + variable: ObjectName, + values: Vec, + }, + /// Snowflake-style + /// SET (a, b, ..) = (1, 2, ..); + ParenthesizedAssignments { + variables: Vec, + values: Vec, + }, + /// MySQL-style + /// SET a = 1, b = 2, ..; + MultipleAssignments { assignments: Vec }, + /// MS-SQL session + /// + /// See + SetSessionParam(SetSessionParamKind), + /// ```sql + /// SET [ SESSION | LOCAL ] ROLE role_name + /// ``` + /// + /// Sets session state. Examples: [ANSI][1], [Postgresql][2], [MySQL][3], and [Oracle][4] + /// + /// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#set-role-statement + /// [2]: https://www.postgresql.org/docs/14/sql-set-role.html + /// [3]: https://dev.mysql.com/doc/refman/8.0/en/set-role.html + /// [4]: https://docs.oracle.com/cd/B19306_01/server.102/b14200/statements_10004.htm + SetRole { + /// Non-ANSI optional identifier to inform if the role is defined inside the current session (`SESSION`) or transaction (`LOCAL`). + context_modifier: ContextModifier, + /// Role name. If NONE is specified, then the current role name is removed. + role_name: Option, + }, + /// ```sql + /// SET TIME ZONE + /// ``` + /// + /// Note: this is a PostgreSQL-specific statements + /// `SET TIME ZONE ` is an alias for `SET timezone TO ` in PostgreSQL + SetTimeZone { local: bool, value: Expr }, + /// ```sql + /// SET NAMES 'charset_name' [COLLATE 'collation_name'] + /// ``` + SetNames { + charset_name: Ident, + collation_name: Option, + }, + /// ```sql + /// SET NAMES DEFAULT + /// ``` + /// + /// Note: this is a MySQL-specific statement. + SetNamesDefault {}, + /// ```sql + /// SET TRANSACTION ... + /// ``` + SetTransaction { + modes: Vec, + snapshot: Option, + session: bool, + }, +} + +impl Display for Set { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::ParenthesizedAssignments { variables, values } => write!( + f, + "SET ({}) = ({})", + display_comma_separated(variables), + display_comma_separated(values) + ), + Self::MultipleAssignments { assignments } => { + write!(f, "SET {}", display_comma_separated(assignments)) + } + Self::SetRole { + context_modifier, + role_name, + } => { + let role_name = role_name.clone().unwrap_or_else(|| Ident::new("NONE")); + write!(f, "SET{context_modifier} ROLE {role_name}") + } + Self::SetSessionParam(kind) => write!(f, "SET {kind}"), + Self::SetTransaction { + modes, + snapshot, + session, + } => { + if *session { + write!(f, "SET SESSION CHARACTERISTICS AS TRANSACTION")?; + } else { + write!(f, "SET TRANSACTION")?; + } + if !modes.is_empty() { + write!(f, " {}", display_comma_separated(modes))?; + } + if let Some(snapshot_id) = snapshot { + write!(f, " SNAPSHOT {snapshot_id}")?; + } + Ok(()) + } + Self::SetTimeZone { local, value } => { + f.write_str("SET ")?; + if *local { + f.write_str("LOCAL ")?; + } + write!(f, "TIME ZONE {value}") + } + Self::SetNames { + charset_name, + collation_name, + } => { + write!(f, "SET NAMES {}", charset_name)?; + + if let Some(collation) = collation_name { + f.write_str(" COLLATE ")?; + f.write_str(collation)?; + }; + + Ok(()) + } + Self::SetNamesDefault {} => { + f.write_str("SET NAMES DEFAULT")?; + + Ok(()) + } + Set::SingleAssignment { + local, + hivevar, + variable, + values, + } => { + write!( + f, + "SET {}{}{} = {}", + if *local { "LOCAL " } else { "" }, + if *hivevar { "HIVEVAR:" } else { "" }, + variable, + display_comma_separated(values) + ) + } + } + } +} + +/// Convert a `Set` into a `Statement`. +/// Convenience function, instead of writing `Statement::Set(Set::Set...{...})` +impl From for Statement { + fn from(set: Set) -> Self { + Statement::Set(set) + } +} + /// A top-level statement (SELECT, INSERT, CREATE, etc.) #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -2419,6 +2580,7 @@ pub enum Statement { compute_statistics: bool, has_table_keyword: bool, }, + Set(Set), /// ```sql /// TRUNCATE /// ``` @@ -2545,7 +2707,7 @@ pub enum Statement { /// TABLE table: TableWithJoins, /// Column assignments - assignments: Vec, + assignments: Vec, /// Table which provide value to be set from: Option, /// WHERE @@ -2846,7 +3008,10 @@ pub enum Statement { /// DROP CONNECTOR /// ``` /// See [Hive](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362034#LanguageManualDDL-DropConnector) - DropConnector { if_exists: bool, name: Ident }, + DropConnector { + if_exists: bool, + name: Ident, + }, /// ```sql /// DECLARE /// ``` @@ -2854,7 +3019,9 @@ pub enum Statement { /// /// Note: this is a PostgreSQL-specific statement, /// but may also compatible with other SQL. - Declare { stmts: Vec }, + Declare { + stmts: Vec, + }, /// ```sql /// CREATE EXTENSION [ IF NOT EXISTS ] extension_name /// [ WITH ] [ SCHEMA schema_name ] @@ -2916,78 +3083,23 @@ pub enum Statement { /// /// Note: this is a PostgreSQL-specific statement, /// but may also compatible with other SQL. - Discard { object_type: DiscardObject }, - /// ```sql - /// SET [ SESSION | LOCAL ] ROLE role_name - /// ``` - /// - /// Sets session state. Examples: [ANSI][1], [Postgresql][2], [MySQL][3], and [Oracle][4] - /// - /// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#set-role-statement - /// [2]: https://www.postgresql.org/docs/14/sql-set-role.html - /// [3]: https://dev.mysql.com/doc/refman/8.0/en/set-role.html - /// [4]: https://docs.oracle.com/cd/B19306_01/server.102/b14200/statements_10004.htm - SetRole { - /// Non-ANSI optional identifier to inform if the role is defined inside the current session (`SESSION`) or transaction (`LOCAL`). - context_modifier: ContextModifier, - /// Role name. If NONE is specified, then the current role name is removed. - role_name: Option, - }, - /// ```sql - /// SET = expression; - /// SET (variable[, ...]) = (expression[, ...]); - /// ``` - /// - /// Note: this is not a standard SQL statement, but it is supported by at - /// least MySQL and PostgreSQL. Not all MySQL-specific syntactic forms are - /// supported yet. - SetVariable { - local: bool, - hivevar: bool, - variables: OneOrManyWithParens, - value: Vec, + Discard { + object_type: DiscardObject, }, - - /// ```sql - /// SET = expression [, = expression]*; - /// ``` - /// - /// Note: this is a MySQL-specific statement. - /// Refer to [`Dialect.supports_comma_separated_set_assignments`] - SetVariables { - variables: Vec, - values: Vec, - }, - /// ```sql - /// SET TIME ZONE - /// ``` - /// - /// Note: this is a PostgreSQL-specific statements - /// `SET TIME ZONE ` is an alias for `SET timezone TO ` in PostgreSQL - SetTimeZone { local: bool, value: Expr }, - /// ```sql - /// SET NAMES 'charset_name' [COLLATE 'collation_name'] - /// ``` - SetNames { - charset_name: Ident, - collation_name: Option, - }, - /// ```sql - /// SET NAMES DEFAULT - /// ``` - /// - /// Note: this is a MySQL-specific statement. - SetNamesDefault {}, /// `SHOW FUNCTIONS` /// /// Note: this is a Presto-specific statement. - ShowFunctions { filter: Option }, + ShowFunctions { + filter: Option, + }, /// ```sql /// SHOW /// ``` /// /// Note: this is a PostgreSQL-specific statement. - ShowVariable { variable: Vec }, + ShowVariable { + variable: Vec, + }, /// ```sql /// SHOW [GLOBAL | SESSION] STATUS [LIKE 'pattern' | WHERE expr] /// ``` @@ -3071,7 +3183,9 @@ pub enum Statement { /// ``` /// /// Note: this is a MySQL-specific statement. - ShowCollation { filter: Option }, + ShowCollation { + filter: Option, + }, /// ```sql /// `USE ...` /// ``` @@ -3114,14 +3228,6 @@ pub enum Statement { has_end_keyword: bool, }, /// ```sql - /// SET TRANSACTION ... - /// ``` - SetTransaction { - modes: Vec, - snapshot: Option, - session: bool, - }, - /// ```sql /// COMMENT ON ... /// ``` /// @@ -3340,7 +3446,10 @@ pub enum Statement { /// ``` /// /// Note: this is a PostgreSQL-specific statement. - Deallocate { name: Ident, prepare: bool }, + Deallocate { + name: Ident, + prepare: bool, + }, /// ```sql /// An `EXECUTE` statement /// ``` @@ -3426,11 +3535,15 @@ pub enum Statement { /// SAVEPOINT /// ``` /// Define a new savepoint within the current transaction - Savepoint { name: Ident }, + Savepoint { + name: Ident, + }, /// ```sql /// RELEASE [ SAVEPOINT ] savepoint_name /// ``` - ReleaseSavepoint { name: Ident }, + ReleaseSavepoint { + name: Ident, + }, /// A `MERGE` statement. /// /// ```sql @@ -3510,7 +3623,9 @@ pub enum Statement { /// LOCK TABLES [READ [LOCAL] | [LOW_PRIORITY] WRITE] /// ``` /// Note: this is a MySQL-specific statement. See - LockTables { tables: Vec }, + LockTables { + tables: Vec, + }, /// ```sql /// UNLOCK TABLES /// ``` @@ -3544,14 +3659,18 @@ pub enum Statement { /// listen for a notification channel /// /// See Postgres - LISTEN { channel: Ident }, + LISTEN { + channel: Ident, + }, /// ```sql /// UNLISTEN /// ``` /// stop listening for a notification /// /// See Postgres - UNLISTEN { channel: Ident }, + UNLISTEN { + channel: Ident, + }, /// ```sql /// NOTIFY channel [ , payload ] /// ``` @@ -3591,10 +3710,6 @@ pub enum Statement { /// Snowflake `REMOVE` /// See: Remove(FileStagingCommand), - /// MS-SQL session - /// - /// See - SetSessionParam(SetSessionParamKind), /// RaiseError (MSSQL) /// RAISERROR ( { msg_id | msg_str | @local_variable } /// { , severity , state } @@ -4655,59 +4770,7 @@ impl fmt::Display for Statement { write!(f, "DISCARD {object_type}")?; Ok(()) } - Self::SetRole { - context_modifier, - role_name, - } => { - let role_name = role_name.clone().unwrap_or_else(|| Ident::new("NONE")); - write!(f, "SET{context_modifier} ROLE {role_name}") - } - Statement::SetVariable { - local, - variables, - hivevar, - value, - } => { - f.write_str("SET ")?; - if *local { - f.write_str("LOCAL ")?; - } - let parenthesized = matches!(variables, OneOrManyWithParens::Many(_)); - write!( - f, - "{hivevar}{name} = {l_paren}{value}{r_paren}", - hivevar = if *hivevar { "HIVEVAR:" } else { "" }, - name = variables, - l_paren = parenthesized.then_some("(").unwrap_or_default(), - value = display_comma_separated(value), - r_paren = parenthesized.then_some(")").unwrap_or_default(), - ) - } - Statement::SetTimeZone { local, value } => { - f.write_str("SET ")?; - if *local { - f.write_str("LOCAL ")?; - } - write!(f, "TIME ZONE {value}") - } - Statement::SetNames { - charset_name, - collation_name, - } => { - write!(f, "SET NAMES {}", charset_name)?; - - if let Some(collation) = collation_name { - f.write_str(" COLLATE ")?; - f.write_str(collation)?; - }; - - Ok(()) - } - Statement::SetNamesDefault {} => { - f.write_str("SET NAMES DEFAULT")?; - - Ok(()) - } + Self::Set(set) => write!(f, "{set}"), Statement::ShowVariable { variable } => { write!(f, "SHOW")?; if !variable.is_empty() { @@ -4896,24 +4959,6 @@ impl fmt::Display for Statement { } Ok(()) } - Statement::SetTransaction { - modes, - snapshot, - session, - } => { - if *session { - write!(f, "SET SESSION CHARACTERISTICS AS TRANSACTION")?; - } else { - write!(f, "SET TRANSACTION")?; - } - if !modes.is_empty() { - write!(f, " {}", display_comma_separated(modes))?; - } - if let Some(snapshot_id) = snapshot { - write!(f, " SNAPSHOT {snapshot_id}")?; - } - Ok(()) - } Statement::Commit { chain, end: end_syntax, @@ -5344,18 +5389,6 @@ impl fmt::Display for Statement { Statement::List(command) => write!(f, "LIST {command}"), Statement::Remove(command) => write!(f, "REMOVE {command}"), - Statement::SetSessionParam(kind) => write!(f, "SET {kind}"), - Statement::SetVariables { variables, values } => write!( - f, - "SET {}", - display_comma_separated( - &variables - .iter() - .zip(values.iter()) - .map(|(var, val)| format!("{var} = {val}")) - .collect::>() - ) - ), } } } @@ -5419,6 +5452,21 @@ impl fmt::Display for SequenceOptions { } } +/// Assignment for a `SET` statement (name [=|TO] value) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct SetAssignment { + pub name: ObjectName, + pub value: Expr, +} + +impl fmt::Display for SetAssignment { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} = {}", self.name, self.value) + } +} + /// Target of a `TRUNCATE TABLE` command /// /// Note this is its own struct because `visit_relation` requires an `ObjectName` (not a `Vec`) @@ -5504,7 +5552,7 @@ pub enum MinMaxValue { #[non_exhaustive] pub enum OnInsert { /// ON DUPLICATE KEY UPDATE (MySQL when the key already exists, then execute an update instead) - DuplicateKeyUpdate(Vec), + DuplicateKeyUpdate(Vec), /// ON CONFLICT is a PostgreSQL and Sqlite extension OnConflict(OnConflict), } @@ -5544,7 +5592,7 @@ pub enum OnConflictAction { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct DoUpdate { /// Column assignments - pub assignments: Vec, + pub assignments: Vec, /// WHERE pub selection: Option, } @@ -6170,12 +6218,12 @@ impl fmt::Display for GrantObjects { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct Assignment { +pub struct UpdateAssignment { pub target: AssignmentTarget, pub value: Expr, } -impl fmt::Display for Assignment { +impl fmt::Display for UpdateAssignment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} = {}", self.target, self.value) } @@ -7504,7 +7552,7 @@ pub enum MergeAction { /// ```sql /// UPDATE SET quantity = T.quantity + S.quantity /// ``` - Update { assignments: Vec }, + Update { assignments: Vec }, /// A plain `DELETE` clause Delete, } diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 6db630742..f74e615f6 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -22,20 +22,20 @@ use crate::tokenizer::Span; use super::{ dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, - AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CloseCursor, - ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy, - ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte, - Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, - Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, - FunctionArguments, GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate, - InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, - MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, - OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, - PivotValueSource, ProjectionSelect, Query, ReferentialAction, RenameSelectItem, - ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, - Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, - TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, - Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill, + AlterIndexOperation, AlterTableOperation, Array, AssignmentTarget, CloseCursor, ClusteredIndex, + ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy, ConstraintCharacteristics, + CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte, Delete, DoUpdate, + ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, Function, + FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, + GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, + JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, MatchRecognizePattern, + Measure, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, + OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, + ProjectionSelect, Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, + ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, + SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject, + TableOptionsClustered, TableWithJoins, UpdateAssignment, UpdateTableFromKind, Use, Value, + Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill, }; /// Given an iterator of spans, return the [Span::union] of all spans. @@ -229,11 +229,7 @@ impl Spanned for Values { /// - [Statement::Fetch] /// - [Statement::Flush] /// - [Statement::Discard] -/// - [Statement::SetRole] -/// - [Statement::SetVariable] -/// - [Statement::SetTimeZone] -/// - [Statement::SetNames] -/// - [Statement::SetNamesDefault] +/// - [Statement::Set] /// - [Statement::ShowFunctions] /// - [Statement::ShowVariable] /// - [Statement::ShowStatus] @@ -243,7 +239,6 @@ impl Spanned for Values { /// - [Statement::ShowTables] /// - [Statement::ShowCollation] /// - [Statement::StartTransaction] -/// - [Statement::SetTransaction] /// - [Statement::Comment] /// - [Statement::Commit] /// - [Statement::Rollback] @@ -444,11 +439,7 @@ impl Spanned for Statement { Statement::Fetch { .. } => Span::empty(), Statement::Flush { .. } => Span::empty(), Statement::Discard { .. } => Span::empty(), - Statement::SetRole { .. } => Span::empty(), - Statement::SetVariable { .. } => Span::empty(), - Statement::SetTimeZone { .. } => Span::empty(), - Statement::SetNames { .. } => Span::empty(), - Statement::SetNamesDefault {} => Span::empty(), + Statement::Set(_) => Span::empty(), Statement::ShowFunctions { .. } => Span::empty(), Statement::ShowVariable { .. } => Span::empty(), Statement::ShowStatus { .. } => Span::empty(), @@ -459,7 +450,6 @@ impl Spanned for Statement { Statement::ShowCollation { .. } => Span::empty(), Statement::Use(u) => u.span(), Statement::StartTransaction { .. } => Span::empty(), - Statement::SetTransaction { .. } => Span::empty(), Statement::Comment { .. } => Span::empty(), Statement::Commit { .. } => Span::empty(), Statement::Rollback { .. } => Span::empty(), @@ -508,8 +498,6 @@ impl Spanned for Statement { Statement::RenameTable { .. } => Span::empty(), Statement::RaisError { .. } => Span::empty(), Statement::List(..) | Statement::Remove(..) => Span::empty(), - Statement::SetSessionParam { .. } => Span::empty(), - Statement::SetVariables { .. } => Span::empty(), } } } @@ -1256,9 +1244,9 @@ impl Spanned for DoUpdate { } } -impl Spanned for Assignment { +impl Spanned for UpdateAssignment { fn span(&self) -> Span { - let Assignment { target, value } = self; + let UpdateAssignment { target, value } = self; target.span().union(&value.span()) } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 52778ebcb..74bc845d2 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10955,10 +10955,10 @@ impl<'a> Parser<'a> { } else { Some(self.parse_identifier()?) }; - Ok(Statement::SetRole { + Ok(Statement::Set(Set::SetRole { context_modifier, role_name, - }) + })) } fn parse_set_values( @@ -11036,24 +11036,26 @@ impl<'a> Parser<'a> { || self.parse_keyword(Keyword::TIMEZONE) { if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { - return Ok(Statement::SetVariable { + return Ok(Set::SingleAssignment { local: modifier == Some(Keyword::LOCAL), hivevar: modifier == Some(Keyword::HIVEVAR), - variables: OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()])), - value: self.parse_set_values(false)?, - }); + variable: ObjectName::from(vec!["TIMEZONE".into()]), + values: self.parse_set_values(false)?, + } + .into()); } else if self.dialect.is::() { // Special case for Postgres - return Ok(Statement::SetTimeZone { + return Ok(Set::SetTimeZone { local: modifier == Some(Keyword::LOCAL), value: self.parse_expr()?, - }); + } + .into()); } else { return self.expected("assignment operator", self.peek_token()); } } else if self.dialect.supports_set_names() && self.parse_keyword(Keyword::NAMES) { if self.parse_keyword(Keyword::DEFAULT) { - return Ok(Statement::SetNamesDefault {}); + return Ok(Set::SetNamesDefault {}.into()); } let charset_name = self.parse_identifier()?; let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() { @@ -11062,61 +11064,75 @@ impl<'a> Parser<'a> { None }; - return Ok(Statement::SetNames { + return Ok(Set::SetNames { charset_name, collation_name, - }); + } + .into()); } else if self.parse_keyword(Keyword::CHARACTERISTICS) { self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?; - return Ok(Statement::SetTransaction { + return Ok(Set::SetTransaction { modes: self.parse_transaction_modes()?, snapshot: None, session: true, - }); + } + .into()); } else if self.parse_keyword(Keyword::TRANSACTION) { if self.parse_keyword(Keyword::SNAPSHOT) { let snapshot_id = self.parse_value()?.value; - return Ok(Statement::SetTransaction { + return Ok(Set::SetTransaction { modes: vec![], snapshot: Some(snapshot_id), session: false, - }); + } + .into()); } - return Ok(Statement::SetTransaction { + return Ok(Set::SetTransaction { modes: self.parse_transaction_modes()?, snapshot: None, session: false, - }); + } + .into()); } if self.dialect.supports_comma_separated_set_assignments() { - if let Ok(v) = + if let Ok(assignments) = self.try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment)) { - let (vars, values): (Vec<_>, Vec<_>) = v.into_iter().unzip(); - - return if vars.len() > 1 { - let variables = vars + return if assignments.len() > 1 { + let assignments = assignments .into_iter() - .map(|v| match v { - OneOrManyWithParens::One(v) => Ok(v), - _ => self.expected("List of single identifiers", self.peek_token()), + .map(|(var, val)| match var { + OneOrManyWithParens::One(v) => Ok(SetAssignment { + name: v, + value: val, + }), + OneOrManyWithParens::Many(_) => { + self.expected("List of single identifiers", self.peek_token()) + } }) .collect::>()?; - Ok(Statement::SetVariables { variables, values }) + Ok(Set::MultipleAssignments { assignments }.into()) } else { + let (vars, values): (Vec<_>, Vec<_>) = assignments.into_iter().unzip(); + let variable = match vars.into_iter().next() { - Some(v) => Ok(v), + Some(OneOrManyWithParens::One(v)) => Ok(v), + Some(OneOrManyWithParens::Many(_)) => self.expected( + "Single assignment or list of assignments", + self.peek_token(), + ), None => self.expected("At least one identifier", self.peek_token()), }?; - Ok(Statement::SetVariable { + Ok(Set::SingleAssignment { local: modifier == Some(Keyword::LOCAL), hivevar: modifier == Some(Keyword::HIVEVAR), - variables: variable, - value: values, - }) + variable, + values, + } + .into()) }; } } @@ -11137,15 +11153,20 @@ impl<'a> Parser<'a> { }; if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { - let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_)); - let values = self.parse_set_values(parenthesized_assignment)?; - - return Ok(Statement::SetVariable { - local: modifier == Some(Keyword::LOCAL), - hivevar: modifier == Some(Keyword::HIVEVAR), - variables, - value: values, - }); + let stmt = match variables { + OneOrManyWithParens::One(var) => Set::SingleAssignment { + local: modifier == Some(Keyword::LOCAL), + hivevar: modifier == Some(Keyword::HIVEVAR), + variable: var, + values: self.parse_set_values(false)?, + }, + OneOrManyWithParens::Many(vars) => Set::ParenthesizedAssignments { + variables: vars, + values: self.parse_set_values(true)?, + }, + }; + + return Ok(stmt.into()); } if self.dialect.supports_set_stmt_without_operator() { @@ -11171,15 +11192,20 @@ impl<'a> Parser<'a> { _ => return self.expected("IO, PROFILE, TIME or XML", self.peek_token()), }; let value = self.parse_session_param_value()?; - Ok(Statement::SetSessionParam(SetSessionParamKind::Statistics( - SetSessionParamStatistics { topic, value }, - ))) + Ok( + Set::SetSessionParam(SetSessionParamKind::Statistics(SetSessionParamStatistics { + topic, + value, + })) + .into(), + ) } else if self.parse_keyword(Keyword::IDENTITY_INSERT) { let obj = self.parse_object_name(false)?; let value = self.parse_session_param_value()?; - Ok(Statement::SetSessionParam( - SetSessionParamKind::IdentityInsert(SetSessionParamIdentityInsert { obj, value }), + Ok(Set::SetSessionParam(SetSessionParamKind::IdentityInsert( + SetSessionParamIdentityInsert { obj, value }, )) + .into()) } else if self.parse_keyword(Keyword::OFFSETS) { let keywords = self.parse_comma_separated(|parser| { let next_token = parser.next_token(); @@ -11189,9 +11215,13 @@ impl<'a> Parser<'a> { } })?; let value = self.parse_session_param_value()?; - Ok(Statement::SetSessionParam(SetSessionParamKind::Offsets( - SetSessionParamOffsets { keywords, value }, - ))) + Ok( + Set::SetSessionParam(SetSessionParamKind::Offsets(SetSessionParamOffsets { + keywords, + value, + })) + .into(), + ) } else { let names = self.parse_comma_separated(|parser| { let next_token = parser.next_token(); @@ -11201,9 +11231,13 @@ impl<'a> Parser<'a> { } })?; let value = self.parse_expr()?.to_string(); - Ok(Statement::SetSessionParam(SetSessionParamKind::Generic( - SetSessionParamGeneric { names, value }, - ))) + Ok( + Set::SetSessionParam(SetSessionParamKind::Generic(SetSessionParamGeneric { + names, + value, + })) + .into(), + ) } } @@ -13277,11 +13311,11 @@ impl<'a> Parser<'a> { } /// Parse a `var = expr` assignment, used in an UPDATE statement - pub fn parse_assignment(&mut self) -> Result { + pub fn parse_assignment(&mut self) -> Result { let target = self.parse_assignment_target()?; self.expect_token(&Token::Eq)?; let value = self.parse_expr()?; - Ok(Assignment { target, value }) + Ok(UpdateAssignment { target, value }) } /// Parse the left-hand side of an assignment, used in an UPDATE statement diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 3037d4ae5..52e2ed552 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -1725,11 +1725,11 @@ fn parse_merge() { }); let update_action = MergeAction::Update { assignments: vec![ - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new("a")])), value: Expr::value(number("1")), }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new("b")])), value: Expr::value(number("2")), }, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 6c9e6d481..174d1d87f 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -384,15 +384,15 @@ fn parse_update() { assert_eq!( assignments, vec![ - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec!["a".into()])), value: Expr::value(number("1")), }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec!["b".into()])), value: Expr::value(number("2")), }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec!["c".into()])), value: Expr::value(number("3")), }, @@ -441,7 +441,7 @@ fn parse_update_set_from() { relation: table_from_name(ObjectName::from(vec![Ident::new("t1")])), joins: vec![], }, - assignments: vec![Assignment { + assignments: vec![UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new("name")])), value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")]) }], @@ -553,7 +553,7 @@ fn parse_update_with_table_alias() { table ); assert_eq!( - vec![Assignment { + vec![UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ Ident::new("u"), Ident::new("username") @@ -8528,11 +8528,11 @@ fn parse_set_transaction() { // TRANSACTION, so no need to duplicate the tests here. We just do a quick // sanity check. match verified_stmt("SET TRANSACTION READ ONLY, READ WRITE, ISOLATION LEVEL SERIALIZABLE") { - Statement::SetTransaction { + Statement::Set(Set::SetTransaction { modes, session, snapshot, - } => { + }) => { assert_eq!( modes, vec![ @@ -8551,20 +8551,17 @@ fn parse_set_transaction() { #[test] fn parse_set_variable() { match verified_stmt("SET SOMETHING = '1'") { - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local, hivevar, - variables, - value, - } => { + variable, + values, + }) => { assert!(!local); assert!(!hivevar); + assert_eq!(variable, ObjectName::from(vec!["SOMETHING".into()])); assert_eq!( - variables, - OneOrManyWithParens::One(ObjectName::from(vec!["SOMETHING".into()])) - ); - assert_eq!( - value, + values, vec![Expr::Value( (Value::SingleQuotedString("1".into())).with_empty_span() )] @@ -8576,24 +8573,17 @@ fn parse_set_variable() { let multi_variable_dialects = all_dialects_where(|d| d.supports_parenthesized_set_variables()); let sql = r#"SET (a, b, c) = (1, 2, 3)"#; match multi_variable_dialects.verified_stmt(sql) { - Statement::SetVariable { - local, - hivevar, - variables, - value, - } => { - assert!(!local); - assert!(!hivevar); + Statement::Set(Set::ParenthesizedAssignments { variables, values }) => { assert_eq!( variables, - OneOrManyWithParens::Many(vec![ + vec![ ObjectName::from(vec!["a".into()]), ObjectName::from(vec!["b".into()]), ObjectName::from(vec!["c".into()]), - ]) + ] ); assert_eq!( - value, + values, vec![ Expr::value(number("1")), Expr::value(number("2")), @@ -8653,20 +8643,17 @@ fn parse_set_variable() { #[test] fn parse_set_role_as_variable() { match verified_stmt("SET role = 'foobar'") { - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local, hivevar, - variables, - value, - } => { + variable, + values, + }) => { assert!(!local); assert!(!hivevar); + assert_eq!(variable, ObjectName::from(vec!["role".into()])); assert_eq!( - variables, - OneOrManyWithParens::One(ObjectName::from(vec!["role".into()])) - ); - assert_eq!( - value, + values, vec![Expr::Value( (Value::SingleQuotedString("foobar".into())).with_empty_span() )] @@ -8703,20 +8690,17 @@ fn parse_double_colon_cast_at_timezone() { #[test] fn parse_set_time_zone() { match verified_stmt("SET TIMEZONE = 'UTC'") { - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local, hivevar, - variables: variable, - value, - } => { + variable, + values, + }) => { assert!(!local); assert!(!hivevar); + assert_eq!(variable, ObjectName::from(vec!["TIMEZONE".into()])); assert_eq!( - variable, - OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()])) - ); - assert_eq!( - value, + values, vec![Expr::Value( (Value::SingleQuotedString("UTC".into())).with_empty_span() )] @@ -9439,7 +9423,7 @@ fn parse_merge() { }), action: MergeAction::Update { assignments: vec![ - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ Ident::new("dest"), Ident::new("F") @@ -9449,7 +9433,7 @@ fn parse_merge() { Ident::new("F"), ]), }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ Ident::new("dest"), Ident::new("G") @@ -14647,9 +14631,20 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> { let stmt = dialects.verified_stmt("SET @a = 1, b = 2"); match stmt { - Statement::SetVariables { variables, values } => { - assert_eq!(values.len(), 2); - assert_eq!(variables.len(), 2); + Statement::Set(Set::MultipleAssignments { assignments }) => { + assert_eq!( + assignments, + vec![ + SetAssignment { + name: ObjectName::from(vec!["@a".into()]), + value: Expr::value(number("1")) + }, + SetAssignment { + name: ObjectName::from(vec!["b".into()]), + value: Expr::value(number("2")) + } + ] + ); } _ => panic!("Expected SetVariable with 2 variables and 2 values"), }; diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 27fa4b405..56fe22a0d 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -22,9 +22,8 @@ use sqlparser::ast::{ ClusteredBy, CommentDef, CreateFunction, CreateFunctionBody, CreateFunctionUsing, CreateTable, - Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, - OneOrManyWithParens, OrderByExpr, OrderByOptions, SelectItem, Statement, TableFactor, - UnaryOperator, Use, Value, + Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr, + OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value, }; use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect}; use sqlparser::parser::ParserError; @@ -369,20 +368,20 @@ fn from_cte() { fn set_statement_with_minus() { assert_eq!( hive().verified_stmt("SET hive.tez.java.opts = -Xmx4g"), - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: false, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![ + variable: ObjectName::from(vec![ Ident::new("hive"), Ident::new("tez"), Ident::new("java"), Ident::new("opts") - ])), - value: vec![Expr::UnaryOp { + ]), + values: vec![Expr::UnaryOp { op: UnaryOperator::Minus, expr: Box::new(Expr::Identifier(Ident::new("Xmx4g"))) }], - } + }) ); assert_eq!( diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 3f313af4f..386bd1788 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -1254,14 +1254,14 @@ fn parse_mssql_declare() { for_query: None }] }, - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: false, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![Ident::new("@bar")])), - value: vec![Expr::Value( + variable: ObjectName::from(vec![Ident::new("@bar")]), + values: vec![Expr::Value( (Value::Number("2".parse().unwrap(), false)).with_empty_span() )], - }, + }), Statement::Query(Box::new(Query { with: None, limit: None, diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 8d89ce4eb..cfd3090b0 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -617,12 +617,12 @@ fn parse_set_variables() { mysql_and_generic().verified_stmt("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"); assert_eq!( mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"), - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: true, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec!["autocommit".into()])), - value: vec![Expr::value(number("1"))], - } + variable: ObjectName::from(vec!["autocommit".into()]), + values: vec![Expr::value(number("1"))], + }) ); } @@ -1870,31 +1870,31 @@ fn parse_insert_with_on_duplicate_update() { ); assert_eq!( Some(OnInsert::DuplicateKeyUpdate(vec![ - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "description".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("description"))]), }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "perm_create".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_create"))]), }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "perm_read".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_read"))]), }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "perm_update".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_update"))]), }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "perm_delete".to_string() )])), @@ -2086,7 +2086,7 @@ fn parse_update_with_joins() { table ); assert_eq!( - vec![Assignment { + vec![UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ Ident::new("o"), Ident::new("completed") @@ -2695,19 +2695,19 @@ fn parse_set_names() { let stmt = mysql_and_generic().verified_stmt("SET NAMES utf8mb4"); assert_eq!( stmt, - Statement::SetNames { + Statement::Set(Set::SetNames { charset_name: "utf8mb4".into(), collation_name: None, - } + }) ); let stmt = mysql_and_generic().verified_stmt("SET NAMES utf8mb4 COLLATE bogus"); assert_eq!( stmt, - Statement::SetNames { + Statement::Set(Set::SetNames { charset_name: "utf8mb4".into(), collation_name: Some("bogus".to_string()), - } + }) ); let stmt = mysql_and_generic() @@ -2715,14 +2715,14 @@ fn parse_set_names() { .unwrap(); assert_eq!( stmt, - vec![Statement::SetNames { + vec![Statement::Set(Set::SetNames { charset_name: "utf8mb4".into(), collation_name: Some("bogus".to_string()), - }] + })] ); let stmt = mysql_and_generic().verified_stmt("SET NAMES DEFAULT"); - assert_eq!(stmt, Statement::SetNamesDefault {}); + assert_eq!(stmt, Statement::Set(Set::SetNamesDefault {})); } #[test] diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 659ed9b01..1d2096f1f 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1432,81 +1432,77 @@ fn parse_set() { let stmt = pg_and_generic().verified_stmt("SET a = b"); assert_eq!( stmt, - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: false, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![Ident::new("a")])), - value: vec![Expr::Identifier(Ident { + variable: ObjectName::from(vec![Ident::new("a")]), + values: vec![Expr::Identifier(Ident { value: "b".into(), quote_style: None, span: Span::empty(), })], - } + }) ); let stmt = pg_and_generic().verified_stmt("SET a = 'b'"); assert_eq!( stmt, - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: false, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![Ident::new("a")])), - value: vec![Expr::Value( + variable: ObjectName::from(vec![Ident::new("a")]), + values: vec![Expr::Value( (Value::SingleQuotedString("b".into())).with_empty_span() )], - } + }) ); let stmt = pg_and_generic().verified_stmt("SET a = 0"); assert_eq!( stmt, - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: false, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![Ident::new("a")])), - value: vec![Expr::value(number("0"))], - } + variable: ObjectName::from(vec![Ident::new("a")]), + values: vec![Expr::value(number("0"))], + }) ); let stmt = pg_and_generic().verified_stmt("SET a = DEFAULT"); assert_eq!( stmt, - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: false, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![Ident::new("a")])), - value: vec![Expr::Identifier(Ident::new("DEFAULT"))], - } + variable: ObjectName::from(vec![Ident::new("a")]), + values: vec![Expr::Identifier(Ident::new("DEFAULT"))], + }) ); let stmt = pg_and_generic().verified_stmt("SET LOCAL a = b"); assert_eq!( stmt, - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: true, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![Ident::new("a")])), - value: vec![Expr::Identifier("b".into())], - } + variable: ObjectName::from(vec![Ident::new("a")]), + values: vec![Expr::Identifier("b".into())], + }) ); let stmt = pg_and_generic().verified_stmt("SET a.b.c = b"); assert_eq!( stmt, - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: false, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![ - Ident::new("a"), - Ident::new("b"), - Ident::new("c") - ])), - value: vec![Expr::Identifier(Ident { + variable: ObjectName::from(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]), + values: vec![Expr::Identifier(Ident { value: "b".into(), quote_style: None, span: Span::empty(), })], - } + }) ); let stmt = pg_and_generic().one_statement_parses_to( @@ -1515,18 +1511,18 @@ fn parse_set() { ); assert_eq!( stmt, - Statement::SetVariable { + Statement::Set(Set::SingleAssignment { local: false, hivevar: false, - variables: OneOrManyWithParens::One(ObjectName::from(vec![ + variable: ObjectName::from(vec![ Ident::new("hive"), Ident::new("tez"), Ident::new("auto"), Ident::new("reducer"), Ident::new("parallelism") - ])), - value: vec![Expr::Value((Value::Boolean(false)).with_empty_span())], - } + ]), + values: vec![Expr::Value((Value::Boolean(false)).with_empty_span())], + }) ); pg_and_generic().one_statement_parses_to("SET a TO b", "SET a = b"); @@ -1560,10 +1556,10 @@ fn parse_set_role() { let stmt = pg_and_generic().verified_stmt(query); assert_eq!( stmt, - Statement::SetRole { + Statement::Set(Set::SetRole { context_modifier: ContextModifier::Session, role_name: None, - } + }) ); assert_eq!(query, stmt.to_string()); @@ -1571,14 +1567,14 @@ fn parse_set_role() { let stmt = pg_and_generic().verified_stmt(query); assert_eq!( stmt, - Statement::SetRole { + Statement::Set(Set::SetRole { context_modifier: ContextModifier::Local, role_name: Some(Ident { value: "rolename".to_string(), quote_style: Some('\"'), span: Span::empty(), }), - } + }) ); assert_eq!(query, stmt.to_string()); @@ -1586,14 +1582,14 @@ fn parse_set_role() { let stmt = pg_and_generic().verified_stmt(query); assert_eq!( stmt, - Statement::SetRole { + Statement::Set(Set::SetRole { context_modifier: ContextModifier::None, role_name: Some(Ident { value: "rolename".to_string(), quote_style: Some('\''), span: Span::empty(), }), - } + }) ); assert_eq!(query, stmt.to_string()); } @@ -1812,7 +1808,7 @@ fn parse_pg_on_conflict() { assert_eq!(vec![Ident::from("did")], cols); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { - assignments: vec![Assignment { + assignments: vec![UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from( vec!["dname".into()] )), @@ -1845,7 +1841,7 @@ fn parse_pg_on_conflict() { assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![ - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ "dname".into() ])), @@ -1854,7 +1850,7 @@ fn parse_pg_on_conflict() { "dname".into() ]) }, - Assignment { + UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ "area".into() ])), @@ -1906,7 +1902,7 @@ fn parse_pg_on_conflict() { assert_eq!(vec![Ident::from("did")], cols); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { - assignments: vec![Assignment { + assignments: vec![UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from( vec!["dname".into()] )), @@ -1953,7 +1949,7 @@ fn parse_pg_on_conflict() { ); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { - assignments: vec![Assignment { + assignments: vec![UpdateAssignment { target: AssignmentTarget::ColumnName(ObjectName::from( vec!["dname".into()] )), @@ -2982,16 +2978,16 @@ fn test_transaction_statement() { let statement = pg().verified_stmt("SET TRANSACTION SNAPSHOT '000003A1-1'"); assert_eq!( statement, - Statement::SetTransaction { + Statement::Set(Set::SetTransaction { modes: vec![], snapshot: Some(Value::SingleQuotedString(String::from("000003A1-1"))), session: false - } + }) ); let statement = pg().verified_stmt("SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ WRITE, ISOLATION LEVEL SERIALIZABLE"); assert_eq!( statement, - Statement::SetTransaction { + Statement::Set(Set::SetTransaction { modes: vec![ TransactionMode::AccessMode(TransactionAccessMode::ReadOnly), TransactionMode::AccessMode(TransactionAccessMode::ReadWrite), @@ -2999,7 +2995,7 @@ fn test_transaction_statement() { ], snapshot: None, session: true - } + }) ); } @@ -5641,7 +5637,7 @@ fn parse_create_type_as_enum() { #[test] fn parse_set_time_zone_alias() { match pg().verified_stmt("SET TIME ZONE 'UTC'") { - Statement::SetTimeZone { local, value } => { + Statement::Set(Set::SetTimeZone { local, value }) => { assert!(!local); assert_eq!( value, diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 361c9b051..418fdfe58 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -469,7 +469,7 @@ fn parse_update_tuple_row_values() { sqlite().verified_stmt("UPDATE x SET (a, b) = (1, 2)"), Statement::Update { or: None, - assignments: vec![Assignment { + assignments: vec![UpdateAssignment { target: AssignmentTarget::Tuple(vec![ ObjectName::from(vec![Ident::new("a"),]), ObjectName::from(vec![Ident::new("b"),]), From d982eba07c8581efc806daa75953ec6246015f90 Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Wed, 12 Mar 2025 20:49:18 +0200 Subject: [PATCH 10/11] re-enable alias for all dialects --- src/ast/mod.rs | 1 + src/parser/mod.rs | 15 ++++++++------- tests/sqlparser_common.rs | 14 ++++++++++---- tests/sqlparser_postgres.rs | 14 -------------- 4 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index ab08961fb..b0fea9509 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2441,6 +2441,7 @@ pub enum Set { /// /// Note: this is a PostgreSQL-specific statements /// `SET TIME ZONE ` is an alias for `SET timezone TO ` in PostgreSQL + /// However, we allow it for all dialects. SetTimeZone { local: bool, value: Expr }, /// ```sql /// SET NAMES 'charset_name' [COLLATE 'collation_name'] diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 74bc845d2..b9d79bbf7 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -4314,7 +4314,8 @@ impl<'a> Parser<'a> { } /// Run a parser method `f`, reverting back to the current position if unsuccessful. - /// Returns `None` if `f` returns an error + /// Returns `ParserError::RecursionLimitExceeded` if `f` returns a `RecursionLimitExceeded`. + /// Returns `Ok(None)` if `f` returns any other error. pub fn maybe_parse(&mut self, f: F) -> Result, ParserError> where F: FnMut(&mut Parser) -> Result, @@ -11043,15 +11044,15 @@ impl<'a> Parser<'a> { values: self.parse_set_values(false)?, } .into()); - } else if self.dialect.is::() { - // Special case for Postgres + } else { + // A shorthand alias for SET TIME ZONE that doesn't require + // the assignment operator. It's originally PostgreSQL specific, + // but we allow it for all the dialects return Ok(Set::SetTimeZone { local: modifier == Some(Keyword::LOCAL), value: self.parse_expr()?, } .into()); - } else { - return self.expected("assignment operator", self.peek_token()); } } else if self.dialect.supports_set_names() && self.parse_keyword(Keyword::NAMES) { if self.parse_keyword(Keyword::DEFAULT) { @@ -11096,8 +11097,8 @@ impl<'a> Parser<'a> { } if self.dialect.supports_comma_separated_set_assignments() { - if let Ok(assignments) = - self.try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment)) + if let Some(assignments) = self + .maybe_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))? { return if assignments.len() > 1 { let assignments = assignments diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 174d1d87f..a6058e077 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14654,8 +14654,14 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> { #[test] fn parse_set_time_zone_alias() { - // not sure what other dialects support this - all_dialects_but_pg() - .parse_sql_statements("SET TIME ZONE 'UTC'") - .unwrap_err(); + match all_dialects().verified_stmt("SET TIME ZONE 'UTC'") { + Statement::Set(Set::SetTimeZone { local, value }) => { + assert!(!local); + assert_eq!( + value, + Expr::Value((Value::SingleQuotedString("UTC".into())).with_empty_span()) + ); + } + _ => unreachable!(), + } } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 1d2096f1f..4c8a337b7 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -5634,20 +5634,6 @@ fn parse_create_type_as_enum() { } } -#[test] -fn parse_set_time_zone_alias() { - match pg().verified_stmt("SET TIME ZONE 'UTC'") { - Statement::Set(Set::SetTimeZone { local, value }) => { - assert!(!local); - assert_eq!( - value, - Expr::Value((Value::SingleQuotedString("UTC".into())).with_empty_span()) - ); - } - _ => unreachable!(), - } -} - #[test] fn parse_alter_type() { struct TestCase { From aa15be26b4ae0a94b584c29f70bf33e92011d90c Mon Sep 17 00:00:00 2001 From: MohamedAbdeen21 Date: Wed, 12 Mar 2025 21:14:49 +0200 Subject: [PATCH 11/11] revert Assignment rename --- src/ast/dml.rs | 14 +++++++------- src/ast/mod.rs | 12 ++++++------ src/ast/spans.rs | 32 ++++++++++++++++---------------- src/parser/mod.rs | 4 ++-- tests/sqlparser_bigquery.rs | 4 ++-- tests/sqlparser_common.rs | 14 +++++++------- tests/sqlparser_mysql.rs | 12 ++++++------ tests/sqlparser_postgres.rs | 10 +++++----- tests/sqlparser_sqlite.rs | 2 +- 9 files changed, 52 insertions(+), 52 deletions(-) diff --git a/src/ast/dml.rs b/src/ast/dml.rs index bc3a9546d..ccea7fbcb 100644 --- a/src/ast/dml.rs +++ b/src/ast/dml.rs @@ -32,12 +32,12 @@ use sqlparser_derive::{Visit, VisitMut}; pub use super::ddl::{ColumnDef, TableConstraint}; use super::{ - display_comma_separated, display_separated, query::InputFormatClause, ClusteredBy, CommentDef, - Expr, FileFormat, FromTable, HiveDistributionStyle, HiveFormat, HiveIOFormat, HiveRowFormat, - Ident, IndexType, InsertAliases, MysqlInsertPriority, ObjectName, OnCommit, OnInsert, - OneOrManyWithParens, OrderByExpr, Query, RowAccessPolicy, SelectItem, Setting, SqlOption, - SqliteOnConflict, StorageSerializationPolicy, TableEngine, TableObject, TableWithJoins, Tag, - UpdateAssignment, WrappedCollection, + display_comma_separated, display_separated, query::InputFormatClause, Assignment, ClusteredBy, + CommentDef, Expr, FileFormat, FromTable, HiveDistributionStyle, HiveFormat, HiveIOFormat, + HiveRowFormat, Ident, IndexType, InsertAliases, MysqlInsertPriority, ObjectName, OnCommit, + OnInsert, OneOrManyWithParens, OrderByExpr, Query, RowAccessPolicy, SelectItem, Setting, + SqlOption, SqliteOnConflict, StorageSerializationPolicy, TableEngine, TableObject, + TableWithJoins, Tag, WrappedCollection, }; /// Index column type. @@ -544,7 +544,7 @@ pub struct Insert { pub source: Option>, /// MySQL `INSERT INTO ... SET` /// See: - pub assignments: Vec, + pub assignments: Vec, /// partitioned insert (Hive) pub partitioned: Option>, /// Columns defined after PARTITION diff --git a/src/ast/mod.rs b/src/ast/mod.rs index b0fea9509..8ab3fc0f1 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2708,7 +2708,7 @@ pub enum Statement { /// TABLE table: TableWithJoins, /// Column assignments - assignments: Vec, + assignments: Vec, /// Table which provide value to be set from: Option, /// WHERE @@ -5553,7 +5553,7 @@ pub enum MinMaxValue { #[non_exhaustive] pub enum OnInsert { /// ON DUPLICATE KEY UPDATE (MySQL when the key already exists, then execute an update instead) - DuplicateKeyUpdate(Vec), + DuplicateKeyUpdate(Vec), /// ON CONFLICT is a PostgreSQL and Sqlite extension OnConflict(OnConflict), } @@ -5593,7 +5593,7 @@ pub enum OnConflictAction { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct DoUpdate { /// Column assignments - pub assignments: Vec, + pub assignments: Vec, /// WHERE pub selection: Option, } @@ -6219,12 +6219,12 @@ impl fmt::Display for GrantObjects { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct UpdateAssignment { +pub struct Assignment { pub target: AssignmentTarget, pub value: Expr, } -impl fmt::Display for UpdateAssignment { +impl fmt::Display for Assignment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} = {}", self.target, self.value) } @@ -7553,7 +7553,7 @@ pub enum MergeAction { /// ```sql /// UPDATE SET quantity = T.quantity + S.quantity /// ``` - Update { assignments: Vec }, + Update { assignments: Vec }, /// A plain `DELETE` clause Delete, } diff --git a/src/ast/spans.rs b/src/ast/spans.rs index f74e615f6..da5ee2f02 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -22,20 +22,20 @@ use crate::tokenizer::Span; use super::{ dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, - AlterIndexOperation, AlterTableOperation, Array, AssignmentTarget, CloseCursor, ClusteredIndex, - ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy, ConstraintCharacteristics, - CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte, Delete, DoUpdate, - ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, Function, - FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, - GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, - JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, MatchRecognizePattern, - Measure, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, - OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, - ProjectionSelect, Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, - ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, - SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject, - TableOptionsClustered, TableWithJoins, UpdateAssignment, UpdateTableFromKind, Use, Value, - Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill, + AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CloseCursor, + ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy, + ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte, + Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, + Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, + FunctionArguments, GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate, + InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView, + MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, + OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, + PivotValueSource, ProjectionSelect, Query, ReferentialAction, RenameSelectItem, + ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, + Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, + TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, + Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill, }; /// Given an iterator of spans, return the [Span::union] of all spans. @@ -1244,9 +1244,9 @@ impl Spanned for DoUpdate { } } -impl Spanned for UpdateAssignment { +impl Spanned for Assignment { fn span(&self) -> Span { - let UpdateAssignment { target, value } = self; + let Assignment { target, value } = self; target.span().union(&value.span()) } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index b9d79bbf7..8791808e0 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -13312,11 +13312,11 @@ impl<'a> Parser<'a> { } /// Parse a `var = expr` assignment, used in an UPDATE statement - pub fn parse_assignment(&mut self) -> Result { + pub fn parse_assignment(&mut self) -> Result { let target = self.parse_assignment_target()?; self.expect_token(&Token::Eq)?; let value = self.parse_expr()?; - Ok(UpdateAssignment { target, value }) + Ok(Assignment { target, value }) } /// Parse the left-hand side of an assignment, used in an UPDATE statement diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 52e2ed552..3037d4ae5 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -1725,11 +1725,11 @@ fn parse_merge() { }); let update_action = MergeAction::Update { assignments: vec![ - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new("a")])), value: Expr::value(number("1")), }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new("b")])), value: Expr::value(number("2")), }, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index a6058e077..ef7d243d8 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -384,15 +384,15 @@ fn parse_update() { assert_eq!( assignments, vec![ - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec!["a".into()])), value: Expr::value(number("1")), }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec!["b".into()])), value: Expr::value(number("2")), }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec!["c".into()])), value: Expr::value(number("3")), }, @@ -441,7 +441,7 @@ fn parse_update_set_from() { relation: table_from_name(ObjectName::from(vec![Ident::new("t1")])), joins: vec![], }, - assignments: vec![UpdateAssignment { + assignments: vec![Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new("name")])), value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")]) }], @@ -553,7 +553,7 @@ fn parse_update_with_table_alias() { table ); assert_eq!( - vec![UpdateAssignment { + vec![Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ Ident::new("u"), Ident::new("username") @@ -9423,7 +9423,7 @@ fn parse_merge() { }), action: MergeAction::Update { assignments: vec![ - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ Ident::new("dest"), Ident::new("F") @@ -9433,7 +9433,7 @@ fn parse_merge() { Ident::new("F"), ]), }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ Ident::new("dest"), Ident::new("G") diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index cfd3090b0..86edfa538 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -1870,31 +1870,31 @@ fn parse_insert_with_on_duplicate_update() { ); assert_eq!( Some(OnInsert::DuplicateKeyUpdate(vec![ - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "description".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("description"))]), }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "perm_create".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_create"))]), }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "perm_read".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_read"))]), }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "perm_update".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_update"))]), }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![Ident::new( "perm_delete".to_string() )])), @@ -2086,7 +2086,7 @@ fn parse_update_with_joins() { table ); assert_eq!( - vec![UpdateAssignment { + vec![Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ Ident::new("o"), Ident::new("completed") diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 4c8a337b7..a65c4fa38 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1808,7 +1808,7 @@ fn parse_pg_on_conflict() { assert_eq!(vec![Ident::from("did")], cols); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { - assignments: vec![UpdateAssignment { + assignments: vec![Assignment { target: AssignmentTarget::ColumnName(ObjectName::from( vec!["dname".into()] )), @@ -1841,7 +1841,7 @@ fn parse_pg_on_conflict() { assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![ - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ "dname".into() ])), @@ -1850,7 +1850,7 @@ fn parse_pg_on_conflict() { "dname".into() ]) }, - UpdateAssignment { + Assignment { target: AssignmentTarget::ColumnName(ObjectName::from(vec![ "area".into() ])), @@ -1902,7 +1902,7 @@ fn parse_pg_on_conflict() { assert_eq!(vec![Ident::from("did")], cols); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { - assignments: vec![UpdateAssignment { + assignments: vec![Assignment { target: AssignmentTarget::ColumnName(ObjectName::from( vec!["dname".into()] )), @@ -1949,7 +1949,7 @@ fn parse_pg_on_conflict() { ); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { - assignments: vec![UpdateAssignment { + assignments: vec![Assignment { target: AssignmentTarget::ColumnName(ObjectName::from( vec!["dname".into()] )), diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 418fdfe58..361c9b051 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -469,7 +469,7 @@ fn parse_update_tuple_row_values() { sqlite().verified_stmt("UPDATE x SET (a, b) = (1, 2)"), Statement::Update { or: None, - assignments: vec![UpdateAssignment { + assignments: vec![Assignment { target: AssignmentTarget::Tuple(vec![ ObjectName::from(vec![Ident::new("a"),]), ObjectName::from(vec![Ident::new("b"),]),