Skip to content

Commit db22343

Browse files
authored
feat: support multiple set variables (#5)
* feat: support multiple set variables * removed unused code
1 parent 0b44c7e commit db22343

File tree

5 files changed

+233
-73
lines changed

5 files changed

+233
-73
lines changed

src/ast/mod.rs

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod value;
2121
#[cfg(not(feature = "std"))]
2222
use alloc::{
2323
boxed::Box,
24+
format,
2425
string::{String, ToString},
2526
vec::Vec,
2627
};
@@ -433,15 +434,14 @@ impl fmt::Display for WindowSpec {
433434
write!(f, "ORDER BY {}", display_comma_separated(&self.order_by))?;
434435
}
435436
if let Some(window_frame) = &self.window_frame {
437+
f.write_str(delim)?;
436438
if let Some(end_bound) = &window_frame.end_bound {
437-
f.write_str(delim)?;
438439
write!(
439440
f,
440441
"{} BETWEEN {} AND {}",
441442
window_frame.units, window_frame.start_bound, end_bound
442443
)?;
443444
} else {
444-
f.write_str(delim)?;
445445
write!(f, "{} {}", window_frame.units, window_frame.start_bound)?;
446446
}
447447
}
@@ -714,10 +714,7 @@ pub enum Statement {
714714
/// least MySQL and PostgreSQL. Not all MySQL-specific syntatic forms are
715715
/// supported yet.
716716
SetVariable {
717-
local: bool,
718-
hivevar: bool,
719-
variable: Ident,
720-
value: Vec<SetVariableValue>,
717+
key_values: Vec<SetVariableKeyValue>,
721718
},
722719
/// SET NAMES 'charset_name' [COLLATE 'collation_name']
723720
///
@@ -1260,23 +1257,44 @@ impl fmt::Display for Statement {
12601257
if *cascade { " CASCADE" } else { "" },
12611258
if *purge { " PURGE" } else { "" }
12621259
),
1263-
Statement::SetVariable {
1264-
local,
1265-
variable,
1266-
hivevar,
1267-
value,
1268-
} => {
1260+
1261+
Statement::SetVariable { key_values } => {
12691262
f.write_str("SET ")?;
1270-
if *local {
1271-
f.write_str("LOCAL ")?;
1263+
1264+
if let Some(key_value) = key_values.get(0) {
1265+
if key_value.hivevar {
1266+
let values: Vec<String> = key_value
1267+
.value
1268+
.iter()
1269+
.map(|value| value.to_string())
1270+
.collect();
1271+
1272+
return write!(
1273+
f,
1274+
"HIVEVAR:{} = {}",
1275+
key_value.key,
1276+
display_comma_separated(&values)
1277+
);
1278+
}
12721279
}
1273-
write!(
1274-
f,
1275-
"{hivevar}{name} = {value}",
1276-
hivevar = if *hivevar { "HIVEVAR:" } else { "" },
1277-
name = variable,
1278-
value = display_comma_separated(value)
1279-
)
1280+
1281+
let formatted_key_values: Vec<String> = key_values
1282+
.iter()
1283+
.map(|key_value| {
1284+
format!(
1285+
"{}{}",
1286+
if key_value.local { "LOCAL " } else { "" },
1287+
key_value
1288+
.value
1289+
.iter()
1290+
.map(|value| format!("{} = {}", key_value.key, value.to_string()))
1291+
.collect::<Vec<String>>()
1292+
.join(", ")
1293+
)
1294+
})
1295+
.collect();
1296+
1297+
write!(f, "{}", display_comma_separated(&formatted_key_values))
12801298
}
12811299
Statement::SetNames {
12821300
charset_name,
@@ -1612,6 +1630,7 @@ pub enum HiveRowFormat {
16121630
DELIMITED,
16131631
}
16141632

1633+
#[allow(clippy::large_enum_variant)]
16151634
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
16161635
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16171636
pub enum HiveIOFormat {
@@ -1726,6 +1745,15 @@ pub enum SetVariableValue {
17261745
Literal(Value),
17271746
}
17281747

1748+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1749+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1750+
pub struct SetVariableKeyValue {
1751+
pub key: Ident,
1752+
pub value: Vec<SetVariableValue>,
1753+
pub local: bool,
1754+
pub hivevar: bool,
1755+
}
1756+
17291757
impl fmt::Display for SetVariableValue {
17301758
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17311759
use SetVariableValue::*;

src/ast/value.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
// limitations under the License.
1212

1313
#[cfg(not(feature = "std"))]
14-
use alloc::string::String;
14+
use alloc::{
15+
boxed::Box,
16+
string::{String, ToString},
17+
};
1518
use core::fmt;
1619

1720
#[cfg(feature = "bigdecimal")]

src/parser.rs

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,12 +2546,47 @@ impl<'a> Parser<'a> {
25462546
Keyword::SESSION,
25472547
Keyword::LOCAL,
25482548
Keyword::HIVEVAR,
2549+
Keyword::TRANSACTION,
25492550
]);
2550-
if let Some(Keyword::HIVEVAR) = modifier {
2551-
self.expect_token(&Token::Colon)?;
2551+
2552+
match modifier {
2553+
Some(Keyword::GLOBAL) | Some(Keyword::SESSION) | Some(Keyword::TRANSACTION) => {
2554+
let global = if modifier == Some(Keyword::GLOBAL) {
2555+
Some(true)
2556+
} else if modifier == Some(Keyword::SESSION) {
2557+
Some(false)
2558+
} else {
2559+
None
2560+
};
2561+
2562+
if let Some(Keyword::TRANSACTION) = modifier {
2563+
return Ok(Statement::SetTransaction {
2564+
modes: self.parse_transaction_modes()?,
2565+
global,
2566+
});
2567+
}
2568+
2569+
let identifier = self.parse_identifier();
2570+
2571+
if identifier.is_ok()
2572+
&& identifier
2573+
.unwrap()
2574+
.value
2575+
.eq_ignore_ascii_case("TRANSACTION")
2576+
{
2577+
return Ok(Statement::SetTransaction {
2578+
modes: self.parse_transaction_modes()?,
2579+
global,
2580+
});
2581+
} else {
2582+
self.prev_token();
2583+
}
2584+
}
2585+
_ => (),
25522586
}
25532587

25542588
let variable = self.parse_identifier()?;
2589+
25552590
if variable.value.eq_ignore_ascii_case("NAMES") {
25562591
let charset_name = self.parse_literal_string()?;
25572592
let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() {
@@ -2560,12 +2595,23 @@ impl<'a> Parser<'a> {
25602595
None
25612596
};
25622597

2563-
Ok(Statement::SetNames {
2598+
return Ok(Statement::SetNames {
25642599
charset_name,
25652600
collation_name,
2566-
})
2567-
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
2601+
});
2602+
} else {
2603+
self.prev_token();
2604+
}
2605+
2606+
if let Some(Keyword::HIVEVAR) = modifier {
2607+
self.expect_token(&Token::Colon)?;
2608+
2609+
let variable = self.parse_identifier()?;
2610+
2611+
self.expect_token(&Token::Eq)?;
2612+
25682613
let mut values = vec![];
2614+
25692615
loop {
25702616
let token = self.peek_token();
25712617
let value = match (self.parse_value(), token) {
@@ -2574,31 +2620,52 @@ impl<'a> Parser<'a> {
25742620
(Err(_), unexpected) => self.expected("variable value", unexpected)?,
25752621
};
25762622
values.push(value);
2623+
25772624
if self.consume_token(&Token::Comma) {
25782625
continue;
25792626
}
2627+
25802628
return Ok(Statement::SetVariable {
2581-
local: modifier == Some(Keyword::LOCAL),
2582-
hivevar: Some(Keyword::HIVEVAR) == modifier,
2583-
variable,
2584-
value: values,
2629+
key_values: [SetVariableKeyValue {
2630+
key: variable,
2631+
value: values,
2632+
local: false,
2633+
hivevar: true,
2634+
}]
2635+
.to_vec(),
25852636
});
25862637
}
2587-
} else if variable.value.eq_ignore_ascii_case("TRANSACTION") {
2588-
let global = if modifier == Some(Keyword::GLOBAL) {
2589-
Some(true)
2590-
} else if modifier == Some(Keyword::SESSION) {
2591-
Some(false)
2592-
} else {
2593-
None
2594-
};
2638+
}
25952639

2596-
Ok(Statement::SetTransaction {
2597-
global,
2598-
modes: self.parse_transaction_modes()?,
2599-
})
2600-
} else {
2601-
self.expected("equals sign or TO", self.peek_token())
2640+
let mut key_values: Vec<SetVariableKeyValue> = vec![];
2641+
loop {
2642+
let variable = self.parse_identifier()?;
2643+
let mut values = vec![];
2644+
2645+
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
2646+
let token = self.peek_token();
2647+
let value = match (self.parse_value(), token) {
2648+
(Ok(value), _) => SetVariableValue::Literal(value),
2649+
(Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()),
2650+
(Err(_), unexpected) => self.expected("variable value", unexpected)?,
2651+
};
2652+
values.push(value);
2653+
2654+
key_values.push(SetVariableKeyValue {
2655+
key: variable,
2656+
value: values,
2657+
local: modifier == Some(Keyword::LOCAL),
2658+
hivevar: false,
2659+
});
2660+
2661+
if self.consume_token(&Token::Comma) {
2662+
continue;
2663+
}
2664+
2665+
return Ok(Statement::SetVariable { key_values });
2666+
} else {
2667+
return self.expected("equals sign or TO", self.peek_token());
2668+
}
26022669
}
26032670
}
26042671

@@ -2806,14 +2873,13 @@ impl<'a> Parser<'a> {
28062873
// followed by some joins or (B) another level of nesting.
28072874
let mut table_and_joins = self.parse_table_and_joins()?;
28082875

2809-
if !table_and_joins.joins.is_empty() {
2876+
// (B): `table_and_joins` (what we found inside the parentheses)
2877+
// is a nested join `(foo JOIN bar)`, not followed by other joins.
2878+
let is_nested_join = matches!(&table_and_joins.relation, TableFactor::NestedJoin(_));
2879+
2880+
if !table_and_joins.joins.is_empty() || is_nested_join {
28102881
self.expect_token(&Token::RParen)?;
28112882
Ok(TableFactor::NestedJoin(Box::new(table_and_joins))) // (A)
2812-
} else if let TableFactor::NestedJoin(_) = &table_and_joins.relation {
2813-
// (B): `table_and_joins` (what we found inside the parentheses)
2814-
// is a nested join `(foo JOIN bar)`, not followed by other joins.
2815-
self.expect_token(&Token::RParen)?;
2816-
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
28172883
} else if dialect_of!(self is SnowflakeDialect | GenericDialect) {
28182884
// Dialect-specific behavior: Snowflake diverges from the
28192885
// standard and from most of the other implementations by

tests/sqlparser_mysql.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,54 @@ fn parse_show_create() {
122122
}
123123
}
124124

125+
#[test]
126+
fn parse_set_transaction() {
127+
mysql_and_generic().verified_stmt("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE");
128+
}
129+
130+
#[test]
131+
fn parse_set_variables() {
132+
let stmt = mysql_and_generic().verified_stmt("SET autocommit = 1, sql_mode = 'test'");
133+
134+
assert_eq!(
135+
stmt,
136+
Statement::SetVariable {
137+
key_values: [
138+
SetVariableKeyValue {
139+
local: false,
140+
hivevar: false,
141+
key: "autocommit".into(),
142+
value: vec![SetVariableValue::Literal(number("1"))],
143+
},
144+
SetVariableKeyValue {
145+
local: false,
146+
hivevar: false,
147+
key: "sql_mode".into(),
148+
value: vec![SetVariableValue::Literal(Value::SingleQuotedString(
149+
"test".into()
150+
))],
151+
}
152+
]
153+
.to_vec()
154+
}
155+
);
156+
157+
let stmt = mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1");
158+
159+
assert_eq!(
160+
stmt,
161+
Statement::SetVariable {
162+
key_values: [SetVariableKeyValue {
163+
local: true,
164+
hivevar: false,
165+
key: "autocommit".into(),
166+
value: vec![SetVariableValue::Literal(number("1"))],
167+
},]
168+
.to_vec()
169+
}
170+
);
171+
}
172+
125173
#[test]
126174
fn parse_create_table_auto_increment() {
127175
let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)";

0 commit comments

Comments
 (0)