Skip to content

Commit 3a8a3bb

Browse files
SET statements: scope modifier for multiple assignments (#1772)
1 parent 939fbdd commit 3a8a3bb

File tree

8 files changed

+134
-95
lines changed

8 files changed

+134
-95
lines changed

src/ast/mod.rs

+18-10
Original file line numberDiff line numberDiff line change
@@ -2638,7 +2638,7 @@ pub enum Set {
26382638
/// SQL Standard-style
26392639
/// SET a = 1;
26402640
SingleAssignment {
2641-
scope: ContextModifier,
2641+
scope: Option<ContextModifier>,
26422642
hivevar: bool,
26432643
variable: ObjectName,
26442644
values: Vec<Expr>,
@@ -2668,7 +2668,7 @@ pub enum Set {
26682668
/// [4]: https://docs.oracle.com/cd/B19306_01/server.102/b14200/statements_10004.htm
26692669
SetRole {
26702670
/// Non-ANSI optional identifier to inform if the role is defined inside the current session (`SESSION`) or transaction (`LOCAL`).
2671-
context_modifier: ContextModifier,
2671+
context_modifier: Option<ContextModifier>,
26722672
/// Role name. If NONE is specified, then the current role name is removed.
26732673
role_name: Option<Ident>,
26742674
},
@@ -2720,7 +2720,13 @@ impl Display for Set {
27202720
role_name,
27212721
} => {
27222722
let role_name = role_name.clone().unwrap_or_else(|| Ident::new("NONE"));
2723-
write!(f, "SET {context_modifier}ROLE {role_name}")
2723+
write!(
2724+
f,
2725+
"SET {modifier}ROLE {role_name}",
2726+
modifier = context_modifier
2727+
.map(|m| format!("{}", m))
2728+
.unwrap_or_default()
2729+
)
27242730
}
27252731
Self::SetSessionParam(kind) => write!(f, "SET {kind}"),
27262732
Self::SetTransaction {
@@ -2775,7 +2781,7 @@ impl Display for Set {
27752781
write!(
27762782
f,
27772783
"SET {}{}{} = {}",
2778-
scope,
2784+
scope.map(|s| format!("{}", s)).unwrap_or_default(),
27792785
if *hivevar { "HIVEVAR:" } else { "" },
27802786
variable,
27812787
display_comma_separated(values)
@@ -5736,13 +5742,20 @@ impl fmt::Display for SequenceOptions {
57365742
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
57375743
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
57385744
pub struct SetAssignment {
5745+
pub scope: Option<ContextModifier>,
57395746
pub name: ObjectName,
57405747
pub value: Expr,
57415748
}
57425749

57435750
impl fmt::Display for SetAssignment {
57445751
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
5745-
write!(f, "{} = {}", self.name, self.value)
5752+
write!(
5753+
f,
5754+
"{}{} = {}",
5755+
self.scope.map(|s| format!("{}", s)).unwrap_or_default(),
5756+
self.name,
5757+
self.value
5758+
)
57465759
}
57475760
}
57485761

@@ -7969,8 +7982,6 @@ impl fmt::Display for FlushLocation {
79697982
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
79707983
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
79717984
pub enum ContextModifier {
7972-
/// No context defined. Each dialect defines the default in this scenario.
7973-
None,
79747985
/// `LOCAL` identifier, usually related to transactional states.
79757986
Local,
79767987
/// `SESSION` identifier
@@ -7982,9 +7993,6 @@ pub enum ContextModifier {
79827993
impl fmt::Display for ContextModifier {
79837994
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79847995
match self {
7985-
Self::None => {
7986-
write!(f, "")
7987-
}
79887996
Self::Local => {
79897997
write!(f, "LOCAL ")
79907998
}

src/dialect/generic.rs

+4
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,8 @@ impl Dialect for GenericDialect {
159159
fn supports_set_names(&self) -> bool {
160160
true
161161
}
162+
163+
fn supports_comma_separated_set_assignments(&self) -> bool {
164+
true
165+
}
162166
}

src/parser/mod.rs

+57-64
Original file line numberDiff line numberDiff line change
@@ -1822,12 +1822,12 @@ impl<'a> Parser<'a> {
18221822
})
18231823
}
18241824

1825-
fn keyword_to_modifier(k: Option<Keyword>) -> ContextModifier {
1825+
fn keyword_to_modifier(k: Keyword) -> Option<ContextModifier> {
18261826
match k {
1827-
Some(Keyword::LOCAL) => ContextModifier::Local,
1828-
Some(Keyword::GLOBAL) => ContextModifier::Global,
1829-
Some(Keyword::SESSION) => ContextModifier::Session,
1830-
_ => ContextModifier::None,
1827+
Keyword::LOCAL => Some(ContextModifier::Local),
1828+
Keyword::GLOBAL => Some(ContextModifier::Global),
1829+
Keyword::SESSION => Some(ContextModifier::Session),
1830+
_ => None,
18311831
}
18321832
}
18331833

@@ -11157,17 +11157,19 @@ impl<'a> Parser<'a> {
1115711157
}
1115811158

1115911159
/// Parse a `SET ROLE` statement. Expects SET to be consumed already.
11160-
fn parse_set_role(&mut self, modifier: Option<Keyword>) -> Result<Statement, ParserError> {
11160+
fn parse_set_role(
11161+
&mut self,
11162+
modifier: Option<ContextModifier>,
11163+
) -> Result<Statement, ParserError> {
1116111164
self.expect_keyword_is(Keyword::ROLE)?;
11162-
let context_modifier = Self::keyword_to_modifier(modifier);
1116311165

1116411166
let role_name = if self.parse_keyword(Keyword::NONE) {
1116511167
None
1116611168
} else {
1116711169
Some(self.parse_identifier()?)
1116811170
};
1116911171
Ok(Statement::Set(Set::SetRole {
11170-
context_modifier,
11172+
context_modifier: modifier,
1117111173
role_name,
1117211174
}))
1117311175
}
@@ -11203,46 +11205,52 @@ impl<'a> Parser<'a> {
1120311205
}
1120411206
}
1120511207

11206-
fn parse_set_assignment(
11207-
&mut self,
11208-
) -> Result<(OneOrManyWithParens<ObjectName>, Expr), ParserError> {
11209-
let variables = if self.dialect.supports_parenthesized_set_variables()
11208+
fn parse_context_modifier(&mut self) -> Option<ContextModifier> {
11209+
let modifier =
11210+
self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::GLOBAL])?;
11211+
11212+
Self::keyword_to_modifier(modifier)
11213+
}
11214+
11215+
/// Parse a single SET statement assignment `var = expr`.
11216+
fn parse_set_assignment(&mut self) -> Result<SetAssignment, ParserError> {
11217+
let scope = self.parse_context_modifier();
11218+
11219+
let name = if self.dialect.supports_parenthesized_set_variables()
1121011220
&& self.consume_token(&Token::LParen)
1121111221
{
11212-
let vars = OneOrManyWithParens::Many(
11213-
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())?
11214-
.into_iter()
11215-
.map(|ident| ObjectName::from(vec![ident]))
11216-
.collect(),
11217-
);
11218-
self.expect_token(&Token::RParen)?;
11219-
vars
11222+
// Parenthesized assignments are handled in the `parse_set` function after
11223+
// trying to parse list of assignments using this function.
11224+
// If a dialect supports both, and we find a LParen, we early exit from this function.
11225+
self.expected("Unparenthesized assignment", self.peek_token())?
1122011226
} else {
11221-
OneOrManyWithParens::One(self.parse_object_name(false)?)
11227+
self.parse_object_name(false)?
1122211228
};
1122311229

1122411230
if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) {
1122511231
return self.expected("assignment operator", self.peek_token());
1122611232
}
1122711233

11228-
let values = self.parse_expr()?;
11234+
let value = self.parse_expr()?;
1122911235

11230-
Ok((variables, values))
11236+
Ok(SetAssignment { scope, name, value })
1123111237
}
1123211238

1123311239
fn parse_set(&mut self) -> Result<Statement, ParserError> {
11234-
let modifier = self.parse_one_of_keywords(&[
11235-
Keyword::SESSION,
11236-
Keyword::LOCAL,
11237-
Keyword::HIVEVAR,
11238-
Keyword::GLOBAL,
11239-
]);
11240-
11241-
if let Some(Keyword::HIVEVAR) = modifier {
11240+
let hivevar = self.parse_keyword(Keyword::HIVEVAR);
11241+
11242+
// Modifier is either HIVEVAR: or a ContextModifier (LOCAL, SESSION, etc), not both
11243+
let scope = if !hivevar {
11244+
self.parse_context_modifier()
11245+
} else {
11246+
None
11247+
};
11248+
11249+
if hivevar {
1124211250
self.expect_token(&Token::Colon)?;
1124311251
}
1124411252

11245-
if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(modifier))? {
11253+
if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(scope))? {
1124611254
return Ok(set_role_stmt);
1124711255
}
1124811256

@@ -11252,8 +11260,8 @@ impl<'a> Parser<'a> {
1125211260
{
1125311261
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
1125411262
return Ok(Set::SingleAssignment {
11255-
scope: Self::keyword_to_modifier(modifier),
11256-
hivevar: modifier == Some(Keyword::HIVEVAR),
11263+
scope,
11264+
hivevar,
1125711265
variable: ObjectName::from(vec!["TIMEZONE".into()]),
1125811266
values: self.parse_set_values(false)?,
1125911267
}
@@ -11263,7 +11271,7 @@ impl<'a> Parser<'a> {
1126311271
// the assignment operator. It's originally PostgreSQL specific,
1126411272
// but we allow it for all the dialects
1126511273
return Ok(Set::SetTimeZone {
11266-
local: modifier == Some(Keyword::LOCAL),
11274+
local: scope == Some(ContextModifier::Local),
1126711275
value: self.parse_expr()?,
1126811276
}
1126911277
.into());
@@ -11311,41 +11319,26 @@ impl<'a> Parser<'a> {
1131111319
}
1131211320

1131311321
if self.dialect.supports_comma_separated_set_assignments() {
11322+
if scope.is_some() {
11323+
self.prev_token();
11324+
}
11325+
1131411326
if let Some(assignments) = self
1131511327
.maybe_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))?
1131611328
{
1131711329
return if assignments.len() > 1 {
11318-
let assignments = assignments
11319-
.into_iter()
11320-
.map(|(var, val)| match var {
11321-
OneOrManyWithParens::One(v) => Ok(SetAssignment {
11322-
name: v,
11323-
value: val,
11324-
}),
11325-
OneOrManyWithParens::Many(_) => {
11326-
self.expected("List of single identifiers", self.peek_token())
11327-
}
11328-
})
11329-
.collect::<Result<_, _>>()?;
11330-
1133111330
Ok(Set::MultipleAssignments { assignments }.into())
1133211331
} else {
11333-
let (vars, values): (Vec<_>, Vec<_>) = assignments.into_iter().unzip();
11334-
11335-
let variable = match vars.into_iter().next() {
11336-
Some(OneOrManyWithParens::One(v)) => Ok(v),
11337-
Some(OneOrManyWithParens::Many(_)) => self.expected(
11338-
"Single assignment or list of assignments",
11339-
self.peek_token(),
11340-
),
11341-
None => self.expected("At least one identifier", self.peek_token()),
11342-
}?;
11332+
let SetAssignment { scope, name, value } =
11333+
assignments.into_iter().next().ok_or_else(|| {
11334+
ParserError::ParserError("Expected at least one assignment".to_string())
11335+
})?;
1134311336

1134411337
Ok(Set::SingleAssignment {
11345-
scope: Self::keyword_to_modifier(modifier),
11346-
hivevar: modifier == Some(Keyword::HIVEVAR),
11347-
variable,
11348-
values,
11338+
scope,
11339+
hivevar,
11340+
variable: name,
11341+
values: vec![value],
1134911342
}
1135011343
.into())
1135111344
};
@@ -11370,8 +11363,8 @@ impl<'a> Parser<'a> {
1137011363
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
1137111364
let stmt = match variables {
1137211365
OneOrManyWithParens::One(var) => Set::SingleAssignment {
11373-
scope: Self::keyword_to_modifier(modifier),
11374-
hivevar: modifier == Some(Keyword::HIVEVAR),
11366+
scope,
11367+
hivevar,
1137511368
variable: var,
1137611369
values: self.parse_set_values(false)?,
1137711370
},

tests/sqlparser_common.rs

+39-4
Original file line numberDiff line numberDiff line change
@@ -8635,7 +8635,7 @@ fn parse_set_variable() {
86358635
variable,
86368636
values,
86378637
}) => {
8638-
assert_eq!(scope, ContextModifier::None);
8638+
assert_eq!(scope, None);
86398639
assert!(!hivevar);
86408640
assert_eq!(variable, ObjectName::from(vec!["SOMETHING".into()]));
86418641
assert_eq!(
@@ -8655,7 +8655,7 @@ fn parse_set_variable() {
86558655
variable,
86568656
values,
86578657
}) => {
8658-
assert_eq!(scope, ContextModifier::Global);
8658+
assert_eq!(scope, Some(ContextModifier::Global));
86598659
assert!(!hivevar);
86608660
assert_eq!(variable, ObjectName::from(vec!["VARIABLE".into()]));
86618661
assert_eq!(
@@ -8747,7 +8747,7 @@ fn parse_set_role_as_variable() {
87478747
variable,
87488748
values,
87498749
}) => {
8750-
assert_eq!(scope, ContextModifier::None);
8750+
assert_eq!(scope, None);
87518751
assert!(!hivevar);
87528752
assert_eq!(variable, ObjectName::from(vec!["role".into()]));
87538753
assert_eq!(
@@ -8794,7 +8794,7 @@ fn parse_set_time_zone() {
87948794
variable,
87958795
values,
87968796
}) => {
8797-
assert_eq!(scope, ContextModifier::None);
8797+
assert_eq!(scope, None);
87988798
assert!(!hivevar);
87998799
assert_eq!(variable, ObjectName::from(vec!["TIMEZONE".into()]));
88008800
assert_eq!(
@@ -14859,10 +14859,12 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> {
1485914859
assignments,
1486014860
vec![
1486114861
SetAssignment {
14862+
scope: None,
1486214863
name: ObjectName::from(vec!["@a".into()]),
1486314864
value: Expr::value(number("1"))
1486414865
},
1486514866
SetAssignment {
14867+
scope: None,
1486614868
name: ObjectName::from(vec!["b".into()]),
1486714869
value: Expr::value(number("2"))
1486814870
}
@@ -14872,6 +14874,39 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> {
1487214874
_ => panic!("Expected SetVariable with 2 variables and 2 values"),
1487314875
};
1487414876

14877+
let stmt = dialects.verified_stmt("SET GLOBAL @a = 1, SESSION b = 2, LOCAL c = 3, d = 4");
14878+
14879+
match stmt {
14880+
Statement::Set(Set::MultipleAssignments { assignments }) => {
14881+
assert_eq!(
14882+
assignments,
14883+
vec![
14884+
SetAssignment {
14885+
scope: Some(ContextModifier::Global),
14886+
name: ObjectName::from(vec!["@a".into()]),
14887+
value: Expr::value(number("1"))
14888+
},
14889+
SetAssignment {
14890+
scope: Some(ContextModifier::Session),
14891+
name: ObjectName::from(vec!["b".into()]),
14892+
value: Expr::value(number("2"))
14893+
},
14894+
SetAssignment {
14895+
scope: Some(ContextModifier::Local),
14896+
name: ObjectName::from(vec!["c".into()]),
14897+
value: Expr::value(number("3"))
14898+
},
14899+
SetAssignment {
14900+
scope: None,
14901+
name: ObjectName::from(vec!["d".into()]),
14902+
value: Expr::value(number("4"))
14903+
}
14904+
]
14905+
);
14906+
}
14907+
_ => panic!("Expected MultipleAssignments with 4 scoped variables and 4 values"),
14908+
};
14909+
1487514910
Ok(())
1487614911
}
1487714912

0 commit comments

Comments
 (0)