diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 9f895ee64..6bf789559 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2629,7 +2629,7 @@ pub enum Set { /// SQL Standard-style /// SET a = 1; SingleAssignment { - scope: ContextModifier, + scope: Option, hivevar: bool, variable: ObjectName, values: Vec, @@ -2659,7 +2659,7 @@ pub enum Set { /// [4]: https://docs.oracle.com/cd/B19306_01/server.102/b14200/statements_10004.htm SetRole { /// Non-ANSI optional identifier to inform if the role is defined inside the current session (`SESSION`) or transaction (`LOCAL`). - context_modifier: ContextModifier, + context_modifier: Option, /// Role name. If NONE is specified, then the current role name is removed. role_name: Option, }, @@ -2711,7 +2711,13 @@ impl Display for Set { role_name, } => { let role_name = role_name.clone().unwrap_or_else(|| Ident::new("NONE")); - write!(f, "SET {context_modifier}ROLE {role_name}") + write!( + f, + "SET {modifier}ROLE {role_name}", + modifier = context_modifier + .map(|m| format!("{}", m)) + .unwrap_or_default() + ) } Self::SetSessionParam(kind) => write!(f, "SET {kind}"), Self::SetTransaction { @@ -2766,7 +2772,7 @@ impl Display for Set { write!( f, "SET {}{}{} = {}", - scope, + scope.map(|s| format!("{}", s)).unwrap_or_default(), if *hivevar { "HIVEVAR:" } else { "" }, variable, display_comma_separated(values) @@ -5727,13 +5733,20 @@ impl fmt::Display for SequenceOptions { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct SetAssignment { + pub scope: Option, pub name: ObjectName, pub value: Expr, } impl fmt::Display for SetAssignment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} = {}", self.name, self.value) + write!( + f, + "{}{} = {}", + self.scope.map(|s| format!("{}", s)).unwrap_or_default(), + self.name, + self.value + ) } } @@ -7960,8 +7973,6 @@ impl fmt::Display for FlushLocation { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub enum ContextModifier { - /// No context defined. Each dialect defines the default in this scenario. - None, /// `LOCAL` identifier, usually related to transactional states. Local, /// `SESSION` identifier @@ -7973,9 +7984,6 @@ pub enum ContextModifier { impl fmt::Display for ContextModifier { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Self::None => { - write!(f, "") - } Self::Local => { write!(f, "LOCAL ") } diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index c13d5aa69..92cfca8fd 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -159,4 +159,8 @@ impl Dialect for GenericDialect { fn supports_set_names(&self) -> bool { true } + + fn supports_comma_separated_set_assignments(&self) -> bool { + true + } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index dcf7a4a8f..b25166053 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1819,12 +1819,12 @@ impl<'a> Parser<'a> { }) } - fn keyword_to_modifier(k: Option) -> ContextModifier { + fn keyword_to_modifier(k: Keyword) -> Option { match k { - Some(Keyword::LOCAL) => ContextModifier::Local, - Some(Keyword::GLOBAL) => ContextModifier::Global, - Some(Keyword::SESSION) => ContextModifier::Session, - _ => ContextModifier::None, + Keyword::LOCAL => Some(ContextModifier::Local), + Keyword::GLOBAL => Some(ContextModifier::Global), + Keyword::SESSION => Some(ContextModifier::Session), + _ => None, } } @@ -11145,9 +11145,11 @@ impl<'a> Parser<'a> { } /// Parse a `SET ROLE` statement. Expects SET to be consumed already. - fn parse_set_role(&mut self, modifier: Option) -> Result { + fn parse_set_role( + &mut self, + modifier: Option, + ) -> Result { self.expect_keyword_is(Keyword::ROLE)?; - let context_modifier = Self::keyword_to_modifier(modifier); let role_name = if self.parse_keyword(Keyword::NONE) { None @@ -11155,7 +11157,7 @@ impl<'a> Parser<'a> { Some(self.parse_identifier()?) }; Ok(Statement::Set(Set::SetRole { - context_modifier, + context_modifier: modifier, role_name, })) } @@ -11191,46 +11193,52 @@ impl<'a> Parser<'a> { } } - fn parse_set_assignment( - &mut self, - ) -> Result<(OneOrManyWithParens, Expr), ParserError> { - let variables = if self.dialect.supports_parenthesized_set_variables() + fn parse_context_modifier(&mut self) -> Option { + let modifier = + self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::GLOBAL])?; + + Self::keyword_to_modifier(modifier) + } + + /// Parse a single SET statement assignment `var = expr`. + fn parse_set_assignment(&mut self) -> Result { + let scope = self.parse_context_modifier(); + + let name = if self.dialect.supports_parenthesized_set_variables() && self.consume_token(&Token::LParen) { - let vars = OneOrManyWithParens::Many( - self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())? - .into_iter() - .map(|ident| ObjectName::from(vec![ident])) - .collect(), - ); - self.expect_token(&Token::RParen)?; - vars + // Parenthesized assignments are handled in the `parse_set` function after + // trying to parse list of assignments using this function. + // If a dialect supports both, and we find a LParen, we early exit from this function. + self.expected("Unparenthesized assignment", self.peek_token())? } else { - OneOrManyWithParens::One(self.parse_object_name(false)?) + self.parse_object_name(false)? }; if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) { return self.expected("assignment operator", self.peek_token()); } - let values = self.parse_expr()?; + let value = self.parse_expr()?; - Ok((variables, values)) + Ok(SetAssignment { scope, name, value }) } fn parse_set(&mut self) -> Result { - let modifier = self.parse_one_of_keywords(&[ - Keyword::SESSION, - Keyword::LOCAL, - Keyword::HIVEVAR, - Keyword::GLOBAL, - ]); - - if let Some(Keyword::HIVEVAR) = modifier { + let hivevar = self.parse_keyword(Keyword::HIVEVAR); + + // Modifier is either HIVEVAR: or a ContextModifier (LOCAL, SESSION, etc), not both + let scope = if !hivevar { + self.parse_context_modifier() + } else { + None + }; + + if hivevar { self.expect_token(&Token::Colon)?; } - if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(modifier))? { + if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(scope))? { return Ok(set_role_stmt); } @@ -11240,8 +11248,8 @@ impl<'a> Parser<'a> { { if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { return Ok(Set::SingleAssignment { - scope: Self::keyword_to_modifier(modifier), - hivevar: modifier == Some(Keyword::HIVEVAR), + scope, + hivevar, variable: ObjectName::from(vec!["TIMEZONE".into()]), values: self.parse_set_values(false)?, } @@ -11251,7 +11259,7 @@ impl<'a> Parser<'a> { // the assignment operator. It's originally PostgreSQL specific, // but we allow it for all the dialects return Ok(Set::SetTimeZone { - local: modifier == Some(Keyword::LOCAL), + local: scope == Some(ContextModifier::Local), value: self.parse_expr()?, } .into()); @@ -11299,41 +11307,26 @@ impl<'a> Parser<'a> { } if self.dialect.supports_comma_separated_set_assignments() { + if scope.is_some() { + self.prev_token(); + } + if let Some(assignments) = self .maybe_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))? { return if assignments.len() > 1 { - let assignments = assignments - .into_iter() - .map(|(var, val)| match var { - OneOrManyWithParens::One(v) => Ok(SetAssignment { - name: v, - value: val, - }), - OneOrManyWithParens::Many(_) => { - self.expected("List of single identifiers", self.peek_token()) - } - }) - .collect::>()?; - Ok(Set::MultipleAssignments { assignments }.into()) } else { - let (vars, values): (Vec<_>, Vec<_>) = assignments.into_iter().unzip(); - - let variable = match vars.into_iter().next() { - Some(OneOrManyWithParens::One(v)) => Ok(v), - Some(OneOrManyWithParens::Many(_)) => self.expected( - "Single assignment or list of assignments", - self.peek_token(), - ), - None => self.expected("At least one identifier", self.peek_token()), - }?; + let SetAssignment { scope, name, value } = + assignments.into_iter().next().ok_or_else(|| { + ParserError::ParserError("Expected at least one assignment".to_string()) + })?; Ok(Set::SingleAssignment { - scope: Self::keyword_to_modifier(modifier), - hivevar: modifier == Some(Keyword::HIVEVAR), - variable, - values, + scope, + hivevar, + variable: name, + values: vec![value], } .into()) }; @@ -11358,8 +11351,8 @@ impl<'a> Parser<'a> { if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { let stmt = match variables { OneOrManyWithParens::One(var) => Set::SingleAssignment { - scope: Self::keyword_to_modifier(modifier), - hivevar: modifier == Some(Keyword::HIVEVAR), + scope, + hivevar, variable: var, values: self.parse_set_values(false)?, }, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 4ba8df7fb..17d4a42ae 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8632,7 +8632,7 @@ fn parse_set_variable() { variable, values, }) => { - assert_eq!(scope, ContextModifier::None); + assert_eq!(scope, None); assert!(!hivevar); assert_eq!(variable, ObjectName::from(vec!["SOMETHING".into()])); assert_eq!( @@ -8652,7 +8652,7 @@ fn parse_set_variable() { variable, values, }) => { - assert_eq!(scope, ContextModifier::Global); + assert_eq!(scope, Some(ContextModifier::Global)); assert!(!hivevar); assert_eq!(variable, ObjectName::from(vec!["VARIABLE".into()])); assert_eq!( @@ -8744,7 +8744,7 @@ fn parse_set_role_as_variable() { variable, values, }) => { - assert_eq!(scope, ContextModifier::None); + assert_eq!(scope, None); assert!(!hivevar); assert_eq!(variable, ObjectName::from(vec!["role".into()])); assert_eq!( @@ -8791,7 +8791,7 @@ fn parse_set_time_zone() { variable, values, }) => { - assert_eq!(scope, ContextModifier::None); + assert_eq!(scope, None); assert!(!hivevar); assert_eq!(variable, ObjectName::from(vec!["TIMEZONE".into()])); assert_eq!( @@ -14856,10 +14856,12 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> { assignments, vec![ SetAssignment { + scope: None, name: ObjectName::from(vec!["@a".into()]), value: Expr::value(number("1")) }, SetAssignment { + scope: None, name: ObjectName::from(vec!["b".into()]), value: Expr::value(number("2")) } @@ -14869,6 +14871,39 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> { _ => panic!("Expected SetVariable with 2 variables and 2 values"), }; + let stmt = dialects.verified_stmt("SET GLOBAL @a = 1, SESSION b = 2, LOCAL c = 3, d = 4"); + + match stmt { + Statement::Set(Set::MultipleAssignments { assignments }) => { + assert_eq!( + assignments, + vec![ + SetAssignment { + scope: Some(ContextModifier::Global), + name: ObjectName::from(vec!["@a".into()]), + value: Expr::value(number("1")) + }, + SetAssignment { + scope: Some(ContextModifier::Session), + name: ObjectName::from(vec!["b".into()]), + value: Expr::value(number("2")) + }, + SetAssignment { + scope: Some(ContextModifier::Local), + name: ObjectName::from(vec!["c".into()]), + value: Expr::value(number("3")) + }, + SetAssignment { + scope: None, + name: ObjectName::from(vec!["d".into()]), + value: Expr::value(number("4")) + } + ] + ); + } + _ => panic!("Expected MultipleAssignments with 4 scoped variables and 4 values"), + }; + Ok(()) } diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index a9549cb60..2af93db7d 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -21,10 +21,9 @@ //! is also tested (on the inputs it can handle). use sqlparser::ast::{ - ClusteredBy, CommentDef, ContextModifier, CreateFunction, CreateFunctionBody, - CreateFunctionUsing, CreateTable, Expr, Function, FunctionArgumentList, FunctionArguments, - Ident, ObjectName, OrderByExpr, OrderByOptions, SelectItem, Set, Statement, TableFactor, - UnaryOperator, Use, Value, + ClusteredBy, CommentDef, CreateFunction, CreateFunctionBody, CreateFunctionUsing, CreateTable, + Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr, + OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value, }; use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect}; use sqlparser::parser::ParserError; @@ -370,7 +369,7 @@ fn set_statement_with_minus() { assert_eq!( hive().verified_stmt("SET hive.tez.java.opts = -Xmx4g"), Statement::Set(Set::SingleAssignment { - scope: ContextModifier::None, + scope: None, hivevar: false, variable: ObjectName::from(vec![ Ident::new("hive"), diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index d4e5fa719..f305e7f78 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -1251,7 +1251,7 @@ fn parse_mssql_declare() { }] }, Statement::Set(Set::SingleAssignment { - scope: ContextModifier::None, + scope: None, hivevar: false, variable: ObjectName::from(vec![Ident::new("@bar")]), values: vec![Expr::Value( diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 884351491..580e41858 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -618,7 +618,7 @@ fn parse_set_variables() { assert_eq!( mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"), Statement::Set(Set::SingleAssignment { - scope: ContextModifier::Local, + scope: Some(ContextModifier::Local), hivevar: false, variable: ObjectName::from(vec!["autocommit".into()]), values: vec![Expr::value(number("1"))], diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index cf66af74e..a6d65ec75 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1432,7 +1432,7 @@ fn parse_set() { assert_eq!( stmt, Statement::Set(Set::SingleAssignment { - scope: ContextModifier::None, + scope: None, hivevar: false, variable: ObjectName::from(vec![Ident::new("a")]), values: vec![Expr::Identifier(Ident { @@ -1447,7 +1447,7 @@ fn parse_set() { assert_eq!( stmt, Statement::Set(Set::SingleAssignment { - scope: ContextModifier::None, + scope: None, hivevar: false, variable: ObjectName::from(vec![Ident::new("a")]), values: vec![Expr::Value( @@ -1460,7 +1460,7 @@ fn parse_set() { assert_eq!( stmt, Statement::Set(Set::SingleAssignment { - scope: ContextModifier::None, + scope: None, hivevar: false, variable: ObjectName::from(vec![Ident::new("a")]), values: vec![Expr::value(number("0"))], @@ -1471,7 +1471,7 @@ fn parse_set() { assert_eq!( stmt, Statement::Set(Set::SingleAssignment { - scope: ContextModifier::None, + scope: None, hivevar: false, variable: ObjectName::from(vec![Ident::new("a")]), values: vec![Expr::Identifier(Ident::new("DEFAULT"))], @@ -1482,7 +1482,7 @@ fn parse_set() { assert_eq!( stmt, Statement::Set(Set::SingleAssignment { - scope: ContextModifier::Local, + scope: Some(ContextModifier::Local), hivevar: false, variable: ObjectName::from(vec![Ident::new("a")]), values: vec![Expr::Identifier("b".into())], @@ -1493,7 +1493,7 @@ fn parse_set() { assert_eq!( stmt, Statement::Set(Set::SingleAssignment { - scope: ContextModifier::None, + scope: None, hivevar: false, variable: ObjectName::from(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]), values: vec![Expr::Identifier(Ident { @@ -1511,7 +1511,7 @@ fn parse_set() { assert_eq!( stmt, Statement::Set(Set::SingleAssignment { - scope: ContextModifier::None, + scope: None, hivevar: false, variable: ObjectName::from(vec![ Ident::new("hive"), @@ -1555,7 +1555,7 @@ fn parse_set_role() { assert_eq!( stmt, Statement::Set(Set::SetRole { - context_modifier: ContextModifier::Session, + context_modifier: Some(ContextModifier::Session), role_name: None, }) ); @@ -1566,7 +1566,7 @@ fn parse_set_role() { assert_eq!( stmt, Statement::Set(Set::SetRole { - context_modifier: ContextModifier::Local, + context_modifier: Some(ContextModifier::Local), role_name: Some(Ident { value: "rolename".to_string(), quote_style: Some('\"'), @@ -1581,7 +1581,7 @@ fn parse_set_role() { assert_eq!( stmt, Statement::Set(Set::SetRole { - context_modifier: ContextModifier::None, + context_modifier: None, role_name: Some(Ident { value: "rolename".to_string(), quote_style: Some('\''),