Skip to content

Commit 7686722

Browse files
MohamedAbdeen21ayman-sigma
authored andcommitted
SET statements: scope modifier for multiple assignments (apache#1772)
1 parent 04419bb commit 7686722

File tree

8 files changed

+134
-95
lines changed

8 files changed

+134
-95
lines changed

src/ast/mod.rs

+18-10
Original file line numberDiff line numberDiff line change
@@ -2656,7 +2656,7 @@ pub enum Set {
26562656
/// SQL Standard-style
26572657
/// SET a = 1;
26582658
SingleAssignment {
2659-
scope: ContextModifier,
2659+
scope: Option<ContextModifier>,
26602660
hivevar: bool,
26612661
variable: ObjectName,
26622662
values: Vec<Expr>,
@@ -2686,7 +2686,7 @@ pub enum Set {
26862686
/// [4]: https://docs.oracle.com/cd/B19306_01/server.102/b14200/statements_10004.htm
26872687
SetRole {
26882688
/// Non-ANSI optional identifier to inform if the role is defined inside the current session (`SESSION`) or transaction (`LOCAL`).
2689-
context_modifier: ContextModifier,
2689+
context_modifier: Option<ContextModifier>,
26902690
/// Role name. If NONE is specified, then the current role name is removed.
26912691
role_name: Option<Ident>,
26922692
},
@@ -2738,7 +2738,13 @@ impl Display for Set {
27382738
role_name,
27392739
} => {
27402740
let role_name = role_name.clone().unwrap_or_else(|| Ident::new("NONE"));
2741-
write!(f, "SET {context_modifier}ROLE {role_name}")
2741+
write!(
2742+
f,
2743+
"SET {modifier}ROLE {role_name}",
2744+
modifier = context_modifier
2745+
.map(|m| format!("{}", m))
2746+
.unwrap_or_default()
2747+
)
27422748
}
27432749
Self::SetSessionParam(kind) => write!(f, "SET {kind}"),
27442750
Self::SetTransaction {
@@ -2793,7 +2799,7 @@ impl Display for Set {
27932799
write!(
27942800
f,
27952801
"SET {}{}{} = {}",
2796-
scope,
2802+
scope.map(|s| format!("{}", s)).unwrap_or_default(),
27972803
if *hivevar { "HIVEVAR:" } else { "" },
27982804
variable,
27992805
display_comma_separated(values)
@@ -5754,13 +5760,20 @@ impl fmt::Display for SequenceOptions {
57545760
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
57555761
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
57565762
pub struct SetAssignment {
5763+
pub scope: Option<ContextModifier>,
57575764
pub name: ObjectName,
57585765
pub value: Expr,
57595766
}
57605767

57615768
impl fmt::Display for SetAssignment {
57625769
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
5763-
write!(f, "{} = {}", self.name, self.value)
5770+
write!(
5771+
f,
5772+
"{}{} = {}",
5773+
self.scope.map(|s| format!("{}", s)).unwrap_or_default(),
5774+
self.name,
5775+
self.value
5776+
)
57645777
}
57655778
}
57665779

@@ -7987,8 +8000,6 @@ impl fmt::Display for FlushLocation {
79878000
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
79888001
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
79898002
pub enum ContextModifier {
7990-
/// No context defined. Each dialect defines the default in this scenario.
7991-
None,
79928003
/// `LOCAL` identifier, usually related to transactional states.
79938004
Local,
79948005
/// `SESSION` identifier
@@ -8000,9 +8011,6 @@ pub enum ContextModifier {
80008011
impl fmt::Display for ContextModifier {
80018012
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80028013
match self {
8003-
Self::None => {
8004-
write!(f, "")
8005-
}
80068014
Self::Local => {
80078015
write!(f, "LOCAL ")
80088016
}

src/dialect/generic.rs

+4
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,8 @@ impl Dialect for GenericDialect {
159159
fn supports_set_names(&self) -> bool {
160160
true
161161
}
162+
163+
fn supports_comma_separated_set_assignments(&self) -> bool {
164+
true
165+
}
162166
}

src/parser/mod.rs

+57-64
Original file line numberDiff line numberDiff line change
@@ -1822,12 +1822,12 @@ impl<'a> Parser<'a> {
18221822
})
18231823
}
18241824

1825-
fn keyword_to_modifier(k: Option<Keyword>) -> ContextModifier {
1825+
fn keyword_to_modifier(k: Keyword) -> Option<ContextModifier> {
18261826
match k {
1827-
Some(Keyword::LOCAL) => ContextModifier::Local,
1828-
Some(Keyword::GLOBAL) => ContextModifier::Global,
1829-
Some(Keyword::SESSION) => ContextModifier::Session,
1830-
_ => ContextModifier::None,
1827+
Keyword::LOCAL => Some(ContextModifier::Local),
1828+
Keyword::GLOBAL => Some(ContextModifier::Global),
1829+
Keyword::SESSION => Some(ContextModifier::Session),
1830+
_ => None,
18311831
}
18321832
}
18331833

@@ -11167,17 +11167,19 @@ impl<'a> Parser<'a> {
1116711167
}
1116811168

1116911169
/// Parse a `SET ROLE` statement. Expects SET to be consumed already.
11170-
fn parse_set_role(&mut self, modifier: Option<Keyword>) -> Result<Statement, ParserError> {
11170+
fn parse_set_role(
11171+
&mut self,
11172+
modifier: Option<ContextModifier>,
11173+
) -> Result<Statement, ParserError> {
1117111174
self.expect_keyword_is(Keyword::ROLE)?;
11172-
let context_modifier = Self::keyword_to_modifier(modifier);
1117311175

1117411176
let role_name = if self.parse_keyword(Keyword::NONE) {
1117511177
None
1117611178
} else {
1117711179
Some(self.parse_identifier()?)
1117811180
};
1117911181
Ok(Statement::Set(Set::SetRole {
11180-
context_modifier,
11182+
context_modifier: modifier,
1118111183
role_name,
1118211184
}))
1118311185
}
@@ -11213,46 +11215,52 @@ impl<'a> Parser<'a> {
1121311215
}
1121411216
}
1121511217

11216-
fn parse_set_assignment(
11217-
&mut self,
11218-
) -> Result<(OneOrManyWithParens<ObjectName>, Expr), ParserError> {
11219-
let variables = if self.dialect.supports_parenthesized_set_variables()
11218+
fn parse_context_modifier(&mut self) -> Option<ContextModifier> {
11219+
let modifier =
11220+
self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::GLOBAL])?;
11221+
11222+
Self::keyword_to_modifier(modifier)
11223+
}
11224+
11225+
/// Parse a single SET statement assignment `var = expr`.
11226+
fn parse_set_assignment(&mut self) -> Result<SetAssignment, ParserError> {
11227+
let scope = self.parse_context_modifier();
11228+
11229+
let name = if self.dialect.supports_parenthesized_set_variables()
1122011230
&& self.consume_token(&Token::LParen)
1122111231
{
11222-
let vars = OneOrManyWithParens::Many(
11223-
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())?
11224-
.into_iter()
11225-
.map(|ident| ObjectName::from(vec![ident]))
11226-
.collect(),
11227-
);
11228-
self.expect_token(&Token::RParen)?;
11229-
vars
11232+
// Parenthesized assignments are handled in the `parse_set` function after
11233+
// trying to parse list of assignments using this function.
11234+
// If a dialect supports both, and we find a LParen, we early exit from this function.
11235+
self.expected("Unparenthesized assignment", self.peek_token())?
1123011236
} else {
11231-
OneOrManyWithParens::One(self.parse_object_name(false)?)
11237+
self.parse_object_name(false)?
1123211238
};
1123311239

1123411240
if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) {
1123511241
return self.expected("assignment operator", self.peek_token());
1123611242
}
1123711243

11238-
let values = self.parse_expr()?;
11244+
let value = self.parse_expr()?;
1123911245

11240-
Ok((variables, values))
11246+
Ok(SetAssignment { scope, name, value })
1124111247
}
1124211248

1124311249
fn parse_set(&mut self) -> Result<Statement, ParserError> {
11244-
let modifier = self.parse_one_of_keywords(&[
11245-
Keyword::SESSION,
11246-
Keyword::LOCAL,
11247-
Keyword::HIVEVAR,
11248-
Keyword::GLOBAL,
11249-
]);
11250-
11251-
if let Some(Keyword::HIVEVAR) = modifier {
11250+
let hivevar = self.parse_keyword(Keyword::HIVEVAR);
11251+
11252+
// Modifier is either HIVEVAR: or a ContextModifier (LOCAL, SESSION, etc), not both
11253+
let scope = if !hivevar {
11254+
self.parse_context_modifier()
11255+
} else {
11256+
None
11257+
};
11258+
11259+
if hivevar {
1125211260
self.expect_token(&Token::Colon)?;
1125311261
}
1125411262

11255-
if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(modifier))? {
11263+
if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(scope))? {
1125611264
return Ok(set_role_stmt);
1125711265
}
1125811266

@@ -11262,8 +11270,8 @@ impl<'a> Parser<'a> {
1126211270
{
1126311271
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
1126411272
return Ok(Set::SingleAssignment {
11265-
scope: Self::keyword_to_modifier(modifier),
11266-
hivevar: modifier == Some(Keyword::HIVEVAR),
11273+
scope,
11274+
hivevar,
1126711275
variable: ObjectName::from(vec!["TIMEZONE".into()]),
1126811276
values: self.parse_set_values(false)?,
1126911277
}
@@ -11273,7 +11281,7 @@ impl<'a> Parser<'a> {
1127311281
// the assignment operator. It's originally PostgreSQL specific,
1127411282
// but we allow it for all the dialects
1127511283
return Ok(Set::SetTimeZone {
11276-
local: modifier == Some(Keyword::LOCAL),
11284+
local: scope == Some(ContextModifier::Local),
1127711285
value: self.parse_expr()?,
1127811286
}
1127911287
.into());
@@ -11321,41 +11329,26 @@ impl<'a> Parser<'a> {
1132111329
}
1132211330

1132311331
if self.dialect.supports_comma_separated_set_assignments() {
11332+
if scope.is_some() {
11333+
self.prev_token();
11334+
}
11335+
1132411336
if let Some(assignments) = self
1132511337
.maybe_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))?
1132611338
{
1132711339
return if assignments.len() > 1 {
11328-
let assignments = assignments
11329-
.into_iter()
11330-
.map(|(var, val)| match var {
11331-
OneOrManyWithParens::One(v) => Ok(SetAssignment {
11332-
name: v,
11333-
value: val,
11334-
}),
11335-
OneOrManyWithParens::Many(_) => {
11336-
self.expected("List of single identifiers", self.peek_token())
11337-
}
11338-
})
11339-
.collect::<Result<_, _>>()?;
11340-
1134111340
Ok(Set::MultipleAssignments { assignments }.into())
1134211341
} else {
11343-
let (vars, values): (Vec<_>, Vec<_>) = assignments.into_iter().unzip();
11344-
11345-
let variable = match vars.into_iter().next() {
11346-
Some(OneOrManyWithParens::One(v)) => Ok(v),
11347-
Some(OneOrManyWithParens::Many(_)) => self.expected(
11348-
"Single assignment or list of assignments",
11349-
self.peek_token(),
11350-
),
11351-
None => self.expected("At least one identifier", self.peek_token()),
11352-
}?;
11342+
let SetAssignment { scope, name, value } =
11343+
assignments.into_iter().next().ok_or_else(|| {
11344+
ParserError::ParserError("Expected at least one assignment".to_string())
11345+
})?;
1135311346

1135411347
Ok(Set::SingleAssignment {
11355-
scope: Self::keyword_to_modifier(modifier),
11356-
hivevar: modifier == Some(Keyword::HIVEVAR),
11357-
variable,
11358-
values,
11348+
scope,
11349+
hivevar,
11350+
variable: name,
11351+
values: vec![value],
1135911352
}
1136011353
.into())
1136111354
};
@@ -11380,8 +11373,8 @@ impl<'a> Parser<'a> {
1138011373
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
1138111374
let stmt = match variables {
1138211375
OneOrManyWithParens::One(var) => Set::SingleAssignment {
11383-
scope: Self::keyword_to_modifier(modifier),
11384-
hivevar: modifier == Some(Keyword::HIVEVAR),
11376+
scope,
11377+
hivevar,
1138511378
variable: var,
1138611379
values: self.parse_set_values(false)?,
1138711380
},

tests/sqlparser_common.rs

+39-4
Original file line numberDiff line numberDiff line change
@@ -8636,7 +8636,7 @@ fn parse_set_variable() {
86368636
variable,
86378637
values,
86388638
}) => {
8639-
assert_eq!(scope, ContextModifier::None);
8639+
assert_eq!(scope, None);
86408640
assert!(!hivevar);
86418641
assert_eq!(variable, ObjectName::from(vec!["SOMETHING".into()]));
86428642
assert_eq!(
@@ -8656,7 +8656,7 @@ fn parse_set_variable() {
86568656
variable,
86578657
values,
86588658
}) => {
8659-
assert_eq!(scope, ContextModifier::Global);
8659+
assert_eq!(scope, Some(ContextModifier::Global));
86608660
assert!(!hivevar);
86618661
assert_eq!(variable, ObjectName::from(vec!["VARIABLE".into()]));
86628662
assert_eq!(
@@ -8748,7 +8748,7 @@ fn parse_set_role_as_variable() {
87488748
variable,
87498749
values,
87508750
}) => {
8751-
assert_eq!(scope, ContextModifier::None);
8751+
assert_eq!(scope, None);
87528752
assert!(!hivevar);
87538753
assert_eq!(variable, ObjectName::from(vec!["role".into()]));
87548754
assert_eq!(
@@ -8795,7 +8795,7 @@ fn parse_set_time_zone() {
87958795
variable,
87968796
values,
87978797
}) => {
8798-
assert_eq!(scope, ContextModifier::None);
8798+
assert_eq!(scope, None);
87998799
assert!(!hivevar);
88008800
assert_eq!(variable, ObjectName::from(vec!["TIMEZONE".into()]));
88018801
assert_eq!(
@@ -14862,10 +14862,12 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> {
1486214862
assignments,
1486314863
vec![
1486414864
SetAssignment {
14865+
scope: None,
1486514866
name: ObjectName::from(vec!["@a".into()]),
1486614867
value: Expr::value(number("1"))
1486714868
},
1486814869
SetAssignment {
14870+
scope: None,
1486914871
name: ObjectName::from(vec!["b".into()]),
1487014872
value: Expr::value(number("2"))
1487114873
}
@@ -14875,6 +14877,39 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> {
1487514877
_ => panic!("Expected SetVariable with 2 variables and 2 values"),
1487614878
};
1487714879

14880+
let stmt = dialects.verified_stmt("SET GLOBAL @a = 1, SESSION b = 2, LOCAL c = 3, d = 4");
14881+
14882+
match stmt {
14883+
Statement::Set(Set::MultipleAssignments { assignments }) => {
14884+
assert_eq!(
14885+
assignments,
14886+
vec![
14887+
SetAssignment {
14888+
scope: Some(ContextModifier::Global),
14889+
name: ObjectName::from(vec!["@a".into()]),
14890+
value: Expr::value(number("1"))
14891+
},
14892+
SetAssignment {
14893+
scope: Some(ContextModifier::Session),
14894+
name: ObjectName::from(vec!["b".into()]),
14895+
value: Expr::value(number("2"))
14896+
},
14897+
SetAssignment {
14898+
scope: Some(ContextModifier::Local),
14899+
name: ObjectName::from(vec!["c".into()]),
14900+
value: Expr::value(number("3"))
14901+
},
14902+
SetAssignment {
14903+
scope: None,
14904+
name: ObjectName::from(vec!["d".into()]),
14905+
value: Expr::value(number("4"))
14906+
}
14907+
]
14908+
);
14909+
}
14910+
_ => panic!("Expected MultipleAssignments with 4 scoped variables and 4 values"),
14911+
};
14912+
1487814913
Ok(())
1487914914
}
1488014915

0 commit comments

Comments
 (0)