Skip to content

Commit b098e2d

Browse files
vasilev-alexMazterQyou
authored andcommitted
feat: support multiple set variables (#5)
* feat: support multiple set variables * removed unused code
1 parent 0b8d232 commit b098e2d

File tree

6 files changed

+249
-95
lines changed

6 files changed

+249
-95
lines changed

src/ast/mod.rs

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod value;
2020
#[cfg(not(feature = "std"))]
2121
use alloc::{
2222
boxed::Box,
23+
format,
2324
string::{String, ToString},
2425
vec::Vec,
2526
};
@@ -847,10 +848,7 @@ pub enum Statement {
847848
/// least MySQL and PostgreSQL. Not all MySQL-specific syntatic forms are
848849
/// supported yet.
849850
SetVariable {
850-
local: bool,
851-
hivevar: bool,
852-
variable: Ident,
853-
value: Vec<SetVariableValue>,
851+
key_values: Vec<SetVariableKeyValue>,
854852
},
855853
/// SET NAMES 'charset_name' [COLLATE 'collation_name']
856854
///
@@ -1465,23 +1463,44 @@ impl fmt::Display for Statement {
14651463
if *cascade { " CASCADE" } else { "" },
14661464
if *purge { " PURGE" } else { "" }
14671465
),
1468-
Statement::SetVariable {
1469-
local,
1470-
variable,
1471-
hivevar,
1472-
value,
1473-
} => {
1466+
1467+
Statement::SetVariable { key_values } => {
14741468
f.write_str("SET ")?;
1475-
if *local {
1476-
f.write_str("LOCAL ")?;
1469+
1470+
if let Some(key_value) = key_values.get(0) {
1471+
if key_value.hivevar {
1472+
let values: Vec<String> = key_value
1473+
.value
1474+
.iter()
1475+
.map(|value| value.to_string())
1476+
.collect();
1477+
1478+
return write!(
1479+
f,
1480+
"HIVEVAR:{} = {}",
1481+
key_value.key,
1482+
display_comma_separated(&values)
1483+
);
1484+
}
14771485
}
1478-
write!(
1479-
f,
1480-
"{hivevar}{name} = {value}",
1481-
hivevar = if *hivevar { "HIVEVAR:" } else { "" },
1482-
name = variable,
1483-
value = display_comma_separated(value)
1484-
)
1486+
1487+
let formatted_key_values: Vec<String> = key_values
1488+
.iter()
1489+
.map(|key_value| {
1490+
format!(
1491+
"{}{}",
1492+
if key_value.local { "LOCAL " } else { "" },
1493+
key_value
1494+
.value
1495+
.iter()
1496+
.map(|value| format!("{} = {}", key_value.key, value.to_string()))
1497+
.collect::<Vec<String>>()
1498+
.join(", ")
1499+
)
1500+
})
1501+
.collect();
1502+
1503+
write!(f, "{}", display_comma_separated(&formatted_key_values))
14851504
}
14861505
Statement::SetNames {
14871506
charset_name,
@@ -2181,6 +2200,15 @@ pub enum SetVariableValue {
21812200
Literal(Value),
21822201
}
21832202

2203+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2204+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2205+
pub struct SetVariableKeyValue {
2206+
pub key: Ident,
2207+
pub value: Vec<SetVariableValue>,
2208+
pub local: bool,
2209+
pub hivevar: bool,
2210+
}
2211+
21842212
impl fmt::Display for SetVariableValue {
21852213
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
21862214
use SetVariableValue::*;

src/ast/value.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
// limitations under the License.
1212

1313
#[cfg(not(feature = "std"))]
14-
use alloc::string::String;
14+
use alloc::{
15+
boxed::Box,
16+
string::{String, ToString},
17+
};
1518
use core::fmt;
1619

1720
#[cfg(feature = "bigdecimal")]

src/keywords.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ define_keywords!(
118118
CHANGE,
119119
CHAR,
120120
CHARACTER,
121+
CHARACTERISTICS,
121122
CHARACTER_LENGTH,
122123
CHARSET,
123124
CHAR_LENGTH,
@@ -319,6 +320,7 @@ define_keywords!(
319320
MONTH,
320321
MSCK,
321322
MULTISET,
323+
NAMES,
322324
NATIONAL,
323325
NATURAL,
324326
NCHAR,

src/parser.rs

Lines changed: 111 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3117,76 +3117,136 @@ impl<'a> Parser<'a> {
31173117
Keyword::SESSION,
31183118
Keyword::LOCAL,
31193119
Keyword::HIVEVAR,
3120+
Keyword::TRANSACTION,
31203121
]);
3121-
if let Some(Keyword::HIVEVAR) = modifier {
3122-
self.expect_token(&Token::Colon)?;
3123-
}
31243122

3125-
let global = match modifier {
3126-
Some(Keyword::GLOBAL) => Some(true),
3127-
Some(Keyword::SESSION) => Some(false),
3128-
_ => None,
3129-
};
3123+
match modifier {
3124+
Some(Keyword::GLOBAL) | Some(Keyword::SESSION) | Some(Keyword::TRANSACTION) => {
3125+
let global = if modifier == Some(Keyword::GLOBAL) {
3126+
Some(true)
3127+
} else if modifier == Some(Keyword::SESSION) {
3128+
Some(false)
3129+
} else {
3130+
None
3131+
};
31303132

3131-
let variable = self.parse_identifier()?;
3132-
if variable.value.eq_ignore_ascii_case("NAMES") {
3133+
if self.parse_keyword(Keyword::CHARACTERISTICS) {
3134+
self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?;
3135+
return Ok(Statement::SetTransaction {
3136+
modes: self.parse_transaction_modes()?,
3137+
snapshot: None,
3138+
global,
3139+
characteristics_as: true,
3140+
});
3141+
}
3142+
3143+
if let Some(Keyword::TRANSACTION) = modifier {
3144+
return Ok(Statement::SetTransaction {
3145+
modes: self.parse_transaction_modes()?,
3146+
global,
3147+
snapshot: None,
3148+
characteristics_as: false,
3149+
});
3150+
}
3151+
3152+
let identifier = self.parse_identifier();
3153+
3154+
if identifier.is_ok()
3155+
&& identifier
3156+
.unwrap()
3157+
.value
3158+
.eq_ignore_ascii_case("TRANSACTION")
3159+
{
3160+
return Ok(Statement::SetTransaction {
3161+
modes: self.parse_transaction_modes()?,
3162+
global,
3163+
snapshot: None,
3164+
characteristics_as: false,
3165+
});
3166+
} else {
3167+
self.prev_token();
3168+
}
3169+
}
3170+
_ => (),
3171+
}
3172+
3173+
if self.parse_one_of_keywords(&[Keyword::NAMES]).is_some() {
31333174
let charset_name = self.parse_literal_string()?;
31343175
let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() {
31353176
Some(self.parse_literal_string()?)
31363177
} else {
31373178
None
31383179
};
31393180

3140-
Ok(Statement::SetNames {
3181+
return Ok(Statement::SetNames {
31413182
charset_name,
31423183
collation_name,
3143-
})
3144-
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
3184+
});
3185+
}
3186+
3187+
if let Some(Keyword::HIVEVAR) = modifier {
3188+
self.expect_token(&Token::Colon)?;
3189+
3190+
let variable = self.parse_identifier()?;
3191+
3192+
self.expect_token(&Token::Eq)?;
3193+
31453194
let mut values = vec![];
3195+
31463196
loop {
3147-
let token = self.peek_token();
3148-
let value = match (self.parse_value(), token) {
3149-
(Ok(value), _) => SetVariableValue::Literal(value),
3150-
(Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()),
3151-
(Err(_), unexpected) => self.expected("variable value", unexpected)?,
3197+
let value = if let Ok(value) = self.parse_value() {
3198+
SetVariableValue::Literal(value)
3199+
} else {
3200+
self.expected("variable value", self.peek_token())?
31523201
};
3202+
31533203
values.push(value);
3204+
31543205
if self.consume_token(&Token::Comma) {
31553206
continue;
31563207
}
3208+
31573209
return Ok(Statement::SetVariable {
3158-
local: modifier == Some(Keyword::LOCAL),
3159-
hivevar: Some(Keyword::HIVEVAR) == modifier,
3160-
variable,
3161-
value: values,
3210+
key_values: [SetVariableKeyValue {
3211+
key: variable,
3212+
value: values,
3213+
local: false,
3214+
hivevar: true,
3215+
}]
3216+
.to_vec(),
31623217
});
31633218
}
3164-
} else if variable.value.eq_ignore_ascii_case("CHARACTERISTICS") {
3165-
self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?;
3166-
Ok(Statement::SetTransaction {
3167-
modes: self.parse_transaction_modes()?,
3168-
snapshot: None,
3169-
global,
3170-
characteristics_as: true,
3171-
})
3172-
} else if variable.value.eq_ignore_ascii_case("TRANSACTION") {
3173-
if self.parse_keyword(Keyword::SNAPSHOT) {
3174-
let snaphot_id = self.parse_value()?;
3175-
return Ok(Statement::SetTransaction {
3176-
modes: vec![],
3177-
snapshot: Some(snaphot_id),
3178-
global,
3179-
characteristics_as: false,
3219+
}
3220+
3221+
let mut key_values: Vec<SetVariableKeyValue> = vec![];
3222+
loop {
3223+
let variable = self.parse_identifier()?;
3224+
let mut values = vec![];
3225+
3226+
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
3227+
let value = if let Ok(value) = self.parse_value() {
3228+
SetVariableValue::Literal(value)
3229+
} else {
3230+
self.expected("variable value", self.peek_token())?
3231+
};
3232+
3233+
values.push(value);
3234+
3235+
key_values.push(SetVariableKeyValue {
3236+
key: variable,
3237+
value: values,
3238+
local: modifier == Some(Keyword::LOCAL),
3239+
hivevar: false,
31803240
});
3241+
3242+
if self.consume_token(&Token::Comma) {
3243+
continue;
3244+
}
3245+
3246+
return Ok(Statement::SetVariable { key_values });
3247+
} else {
3248+
return self.expected("equals sign or TO", self.peek_token());
31813249
}
3182-
Ok(Statement::SetTransaction {
3183-
modes: self.parse_transaction_modes()?,
3184-
snapshot: None,
3185-
global,
3186-
characteristics_as: false,
3187-
})
3188-
} else {
3189-
self.expected("equals sign or TO", self.peek_token())
31903250
}
31913251
}
31923252

@@ -3396,15 +3456,13 @@ impl<'a> Parser<'a> {
33963456
// followed by some joins or (B) another level of nesting.
33973457
let mut table_and_joins = self.parse_table_and_joins()?;
33983458

3399-
#[allow(clippy::if_same_then_else)]
3400-
if !table_and_joins.joins.is_empty() {
3459+
// (B): `table_and_joins` (what we found inside the parentheses)
3460+
// is a nested join `(foo JOIN bar)`, not followed by other joins.
3461+
let is_nested_join = matches!(&table_and_joins.relation, TableFactor::NestedJoin(_));
3462+
3463+
if !table_and_joins.joins.is_empty() || is_nested_join {
34013464
self.expect_token(&Token::RParen)?;
34023465
Ok(TableFactor::NestedJoin(Box::new(table_and_joins))) // (A)
3403-
} else if let TableFactor::NestedJoin(_) = &table_and_joins.relation {
3404-
// (B): `table_and_joins` (what we found inside the parentheses)
3405-
// is a nested join `(foo JOIN bar)`, not followed by other joins.
3406-
self.expect_token(&Token::RParen)?;
3407-
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
34083466
} else if dialect_of!(self is SnowflakeDialect | GenericDialect) {
34093467
// Dialect-specific behavior: Snowflake diverges from the
34103468
// standard and from most of the other implementations by

tests/sqlparser_mysql.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,54 @@ fn parse_show_create() {
124124
}
125125
}
126126

127+
#[test]
128+
fn parse_set_transaction() {
129+
mysql_and_generic().verified_stmt("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE");
130+
}
131+
132+
#[test]
133+
fn parse_set_variables() {
134+
let stmt = mysql_and_generic().verified_stmt("SET autocommit = 1, sql_mode = 'test'");
135+
136+
assert_eq!(
137+
stmt,
138+
Statement::SetVariable {
139+
key_values: [
140+
SetVariableKeyValue {
141+
local: false,
142+
hivevar: false,
143+
key: "autocommit".into(),
144+
value: vec![SetVariableValue::Literal(number("1"))],
145+
},
146+
SetVariableKeyValue {
147+
local: false,
148+
hivevar: false,
149+
key: "sql_mode".into(),
150+
value: vec![SetVariableValue::Literal(Value::SingleQuotedString(
151+
"test".into()
152+
))],
153+
}
154+
]
155+
.to_vec()
156+
}
157+
);
158+
159+
let stmt = mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1");
160+
161+
assert_eq!(
162+
stmt,
163+
Statement::SetVariable {
164+
key_values: [SetVariableKeyValue {
165+
local: true,
166+
hivevar: false,
167+
key: "autocommit".into(),
168+
value: vec![SetVariableValue::Literal(number("1"))],
169+
},]
170+
.to_vec()
171+
}
172+
);
173+
}
174+
127175
#[test]
128176
fn parse_create_table_auto_increment() {
129177
let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)";

0 commit comments

Comments
 (0)