diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 0b5366abe..3357b09f3 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1738,32 +1738,15 @@ impl fmt::Display for ShowStatementFilter { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum SetVariableValue { - Ident(Ident), - Literal(Value), -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct SetVariableKeyValue { pub key: Ident, - pub value: Vec, + pub value: Vec, pub local: bool, pub hivevar: bool, } -impl fmt::Display for SetVariableValue { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use SetVariableValue::*; - match self { - Ident(ident) => write!(f, "{}", ident), - Literal(literal) => write!(f, "{}", literal), - } - } -} - /// Sqlite specific syntax /// /// https://sqlite.org/lang_conflict.html diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 6581195b8..d6095262c 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -24,6 +24,7 @@ impl Dialect for MySqlDialect { || ('A'..='Z').contains(&ch) || ch == '_' || ch == '$' + || ch == '@' || ('\u{0080}'..='\u{ffff}').contains(&ch) } diff --git a/src/keywords.rs b/src/keywords.rs index 7235bbd4f..db6483199 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -292,6 +292,7 @@ define_keywords!( MONTH, MSCK, MULTISET, + NAMES, NATIONAL, NATURAL, NCHAR, diff --git a/src/parser.rs b/src/parser.rs index 819292cbf..c451205c9 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2585,9 +2585,7 @@ impl<'a> Parser<'a> { _ => (), } - let variable = self.parse_identifier()?; - - if variable.value.eq_ignore_ascii_case("NAMES") { + if self.parse_one_of_keywords(&[Keyword::NAMES]).is_some() { let charset_name = self.parse_literal_string()?; let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() { Some(self.parse_literal_string()?) @@ -2599,8 +2597,6 @@ impl<'a> Parser<'a> { charset_name, collation_name, }); - } else { - self.prev_token(); } if let Some(Keyword::HIVEVAR) = modifier { @@ -2613,12 +2609,12 @@ impl<'a> Parser<'a> { let mut values = vec![]; loop { - let token = self.peek_token(); - let value = match (self.parse_value(), token) { - (Ok(value), _) => SetVariableValue::Literal(value), - (Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()), - (Err(_), unexpected) => self.expected("variable value", unexpected)?, + let value = if let Ok(expr) = self.parse_expr() { + expr + } else { + self.expected("variable value", self.peek_token())? }; + values.push(value); if self.consume_token(&Token::Comma) { @@ -2643,12 +2639,12 @@ impl<'a> Parser<'a> { let mut values = vec![]; if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { - let token = self.peek_token(); - let value = match (self.parse_value(), token) { - (Ok(value), _) => SetVariableValue::Literal(value), - (Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()), - (Err(_), unexpected) => self.expected("variable value", unexpected)?, + let value = if let Ok(expr) = self.parse_expr() { + expr + } else { + self.expected("variable value", self.peek_token())? }; + values.push(value); key_values.push(SetVariableKeyValue { diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index d5dcc2f15..8ecd6747d 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -129,41 +129,37 @@ fn parse_set_transaction() { #[test] fn parse_set_variables() { - let stmt = mysql_and_generic().verified_stmt("SET autocommit = 1, sql_mode = 'test'"); - assert_eq!( - stmt, + mysql_and_generic().verified_stmt("SET autocommit = 1, sql_mode = 'test'"), Statement::SetVariable { key_values: [ SetVariableKeyValue { local: false, hivevar: false, key: "autocommit".into(), - value: vec![SetVariableValue::Literal(number("1"))], + value: vec![Expr::Value(Value::Number("1".into(), false))], }, SetVariableKeyValue { local: false, hivevar: false, key: "sql_mode".into(), - value: vec![SetVariableValue::Literal(Value::SingleQuotedString( - "test".into() - ))], + value: vec![Expr::Value(Value::SingleQuotedString("test".into()))], } ] .to_vec() } ); - let stmt = mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"); + mysql_and_generic().verified_stmt("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"); assert_eq!( - stmt, + mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"), Statement::SetVariable { key_values: [SetVariableKeyValue { local: true, hivevar: false, key: "autocommit".into(), - value: vec![SetVariableValue::Literal(number("1"))], + value: vec![Expr::Value(Value::Number("1".into(), false))], },] .to_vec() } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index db90ea6a2..b63a08545 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -368,7 +368,10 @@ fn parse_set() { Statement::SetVariable { key_values: [SetVariableKeyValue { key: "a".into(), - value: vec![SetVariableValue::Ident("b".into())], + value: vec![Expr::Identifier(Ident { + value: "b".into(), + quote_style: None + })], local: false, hivevar: false, }] @@ -384,9 +387,7 @@ fn parse_set() { local: false, hivevar: false, key: "a".into(), - value: vec![SetVariableValue::Literal(Value::SingleQuotedString( - "b".into() - ))], + value: vec![Expr::Value(Value::SingleQuotedString("b".into()))], }] .to_vec() } @@ -400,7 +401,7 @@ fn parse_set() { local: false, hivevar: false, key: "a".into(), - value: vec![SetVariableValue::Literal(number("0"))], + value: vec![Expr::Value(Value::Number("0".into(), false))], }] .to_vec() } @@ -414,7 +415,10 @@ fn parse_set() { local: false, hivevar: false, key: "a".into(), - value: vec![SetVariableValue::Ident("DEFAULT".into())], + value: vec![Expr::Identifier(Ident { + value: "DEFAULT".into(), + quote_style: None + })], }] .to_vec() } @@ -428,7 +432,10 @@ fn parse_set() { local: true, hivevar: false, key: "a".into(), - value: vec![SetVariableValue::Ident("b".into())], + value: vec![Expr::Identifier(Ident { + value: "b".into(), + quote_style: None + })], }] .to_vec() }