Skip to content

SET with a list of comma separated assignments #1757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 12, 2025
21 changes: 21 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2947,6 +2947,17 @@ pub enum Statement {
variables: OneOrManyWithParens<ObjectName>,
value: Vec<Expr>,
},

/// ```sql
/// SET <variable> = expression [, <variable> = expression]*;
/// ```
///
/// Note: this is a MySQL-specific statement.
/// Refer to [`Dialect.supports_comma_separated_set_assignments`]
SetVariables {
variables: Vec<ObjectName>,
values: Vec<Expr>,
},
/// ```sql
/// SET TIME ZONE <value>
/// ```
Expand Down Expand Up @@ -5334,6 +5345,16 @@ impl fmt::Display for Statement {
Statement::List(command) => write!(f, "LIST {command}"),
Statement::Remove(command) => write!(f, "REMOVE {command}"),
Statement::SetSessionParam(kind) => write!(f, "SET {kind}"),
Statement::SetVariables { variables, values } => {
write!(f, "SET ")?;
variables
.iter()
.zip(values.iter())
.map(|(var, val)| format!("{var} = {val}"))
.collect::<Vec<_>>()
.join(", ")
.fmt(f)
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ impl Spanned for Statement {
Statement::RaisError { .. } => Span::empty(),
Statement::List(..) | Statement::Remove(..) => Span::empty(),
Statement::SetSessionParam { .. } => Span::empty(),
Statement::SetVariables { .. } => Span::empty(),
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,16 @@ pub trait Dialect: Debug + Any {
false
}

/// Returns true if the dialect supports multiple `SET` statements
/// in a single statement.
///
/// ```sql
/// SET variable = expression [, variable = expression];
/// ```
fn supports_comma_separated_set_assignments(&self) -> bool {
false
}

/// Returns true if the dialect supports an `EXCEPT` clause following a
/// wildcard in a select list.
///
Expand Down
1 change: 1 addition & 0 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl Dialect for MsSqlDialect {
fn supports_start_transaction_modifier(&self) -> bool {
true
}

fn supports_end_transaction_modifier(&self) -> bool {
true
}
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ impl Dialect for MySqlDialect {
fn supports_set_names(&self) -> bool {
true
}

fn supports_comma_separated_set_assignments(&self) -> bool {
true
}
}

/// `LOCK TABLES`
Expand Down
2 changes: 2 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ define_keywords!(
CHANNEL,
CHAR,
CHARACTER,
CHARACTERISTICS,
CHARACTERS,
CHARACTER_LENGTH,
CHARSET,
Expand Down Expand Up @@ -557,6 +558,7 @@ define_keywords!(
MULTISET,
MUTATION,
NAME,
NAMES,
NANOSECOND,
NANOSECONDS,
NATIONAL,
Expand Down
221 changes: 145 additions & 76 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10961,41 +10961,98 @@ impl<'a> Parser<'a> {
})
}

pub fn parse_set(&mut self) -> Result<Statement, ParserError> {
let modifier =
self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]);
if let Some(Keyword::HIVEVAR) = modifier {
self.expect_token(&Token::Colon)?;
} else if let Some(set_role_stmt) =
self.maybe_parse(|parser| parser.parse_set_role(modifier))?
{
return Ok(set_role_stmt);
fn parse_set_values(
&mut self,
parenthesized_assignment: bool,
) -> Result<Vec<Expr>, ParserError> {
let mut values = vec![];

if parenthesized_assignment {
self.expect_token(&Token::LParen)?;
}

loop {
let value = if let Some(expr) = self.try_parse_expr_sub_query()? {
expr
} else if let Ok(expr) = self.parse_expr() {
expr
} else {
self.expected("variable value", self.peek_token())?
};

values.push(value);
if self.consume_token(&Token::Comma) {
continue;
}

if parenthesized_assignment {
self.expect_token(&Token::RParen)?;
}
return Ok(values);
}
}

let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) {
OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()]))
} else if self.dialect.supports_parenthesized_set_variables()
fn parse_set_assignment(
&mut self,
) -> Result<(OneOrManyWithParens<ObjectName>, Expr), ParserError> {
let variables = if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen)
{
let variables = OneOrManyWithParens::Many(
let vars = OneOrManyWithParens::Many(
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())?
.into_iter()
.map(|ident| ObjectName::from(vec![ident]))
.collect(),
);
self.expect_token(&Token::RParen)?;
variables
vars
} else {
OneOrManyWithParens::One(self.parse_object_name(false)?)
};

let names = matches!(&variables, OneOrManyWithParens::One(variable) if variable.to_string().eq_ignore_ascii_case("NAMES"));
if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) {
return self.expected("assignment operator", self.peek_token());
}

let values = self.parse_expr()?;

if names && self.dialect.supports_set_names() {
Ok((variables, values))
}

fn parse_set(&mut self) -> Result<Statement, ParserError> {
let modifier =
self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]);

if let Some(Keyword::HIVEVAR) = modifier {
self.expect_token(&Token::Colon)?;
}

if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(modifier))? {
return Ok(set_role_stmt);
}

// Handle special cases first
if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE])
|| self.parse_keyword(Keyword::TIMEZONE)
{
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: modifier == Some(Keyword::HIVEVAR),
variables: OneOrManyWithParens::One(ObjectName::from(vec!["TIMEZONE".into()])),
value: self.parse_set_values(false)?,
});
}

// Special case for Postgres
return Ok(Statement::SetTimeZone {
local: modifier == Some(Keyword::LOCAL),
value: self.parse_expr()?,
});
} else if self.dialect.supports_set_names() && self.parse_keyword(Keyword::NAMES) {
if self.parse_keyword(Keyword::DEFAULT) {
return Ok(Statement::SetNamesDefault {});
}

let charset_name = self.parse_identifier()?;
let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() {
Some(self.parse_literal_string()?)
Expand All @@ -11007,63 +11064,14 @@ impl<'a> Parser<'a> {
charset_name,
collation_name,
});
}

let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_));

if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
if parenthesized_assignment {
self.expect_token(&Token::LParen)?;
}

let mut values = vec![];
loop {
let value = if let Some(expr) = self.try_parse_expr_sub_query()? {
expr
} else if let Ok(expr) = self.parse_expr() {
expr
} else {
self.expected("variable value", self.peek_token())?
};

values.push(value);
if self.consume_token(&Token::Comma) {
continue;
}

if parenthesized_assignment {
self.expect_token(&Token::RParen)?;
}
return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: Some(Keyword::HIVEVAR) == modifier,
variables,
value: values,
});
}
}

let OneOrManyWithParens::One(variable) = variables else {
return self.expected("set variable", self.peek_token());
};

if variable.to_string().eq_ignore_ascii_case("TIMEZONE") {
// for some db (e.g. postgresql), SET TIME ZONE <value> is an alias for SET TIMEZONE [TO|=] <value>
match self.parse_expr() {
Ok(expr) => Ok(Statement::SetTimeZone {
local: modifier == Some(Keyword::LOCAL),
value: expr,
}),
_ => self.expected("timezone value", self.peek_token())?,
}
} else if variable.to_string() == "CHARACTERISTICS" {
} else if self.parse_keyword(Keyword::CHARACTERISTICS) {
self.expect_keywords(&[Keyword::AS, Keyword::TRANSACTION])?;
Ok(Statement::SetTransaction {
return Ok(Statement::SetTransaction {
modes: self.parse_transaction_modes()?,
snapshot: None,
session: true,
})
} else if variable.to_string() == "TRANSACTION" && modifier.is_none() {
});
} else if self.parse_keyword(Keyword::TRANSACTION) {
if self.parse_keyword(Keyword::SNAPSHOT) {
let snapshot_id = self.parse_value()?.value;
return Ok(Statement::SetTransaction {
Expand All @@ -11072,17 +11080,78 @@ impl<'a> Parser<'a> {
session: false,
});
}
Ok(Statement::SetTransaction {
return Ok(Statement::SetTransaction {
modes: self.parse_transaction_modes()?,
snapshot: None,
session: false,
})
} else if self.dialect.supports_set_stmt_without_operator() {
self.prev_token();
self.parse_set_session_params()
});
}

if self.dialect.supports_comma_separated_set_assignments() {
if let Ok(v) =
self.try_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))
{
let (vars, values): (Vec<_>, Vec<_>) = v.into_iter().unzip();

return if vars.len() > 1 {
let variables = vars
.into_iter()
.map(|v| match v {
OneOrManyWithParens::One(v) => Ok(v),
_ => self.expected("List of single identifiers", self.peek_token()),
})
.collect::<Result<_, _>>()?;

Ok(Statement::SetVariables { variables, values })
} else {
let variable = match vars.into_iter().next() {
Some(v) => Ok(v),
None => self.expected("At least one identifier", self.peek_token()),
}?;

Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: modifier == Some(Keyword::HIVEVAR),
variables: variable,
value: values,
})
};
}
}

let variables = if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen)
{
let vars = OneOrManyWithParens::Many(
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())?
.into_iter()
.map(|ident| ObjectName::from(vec![ident]))
.collect(),
);
self.expect_token(&Token::RParen)?;
vars
} else {
self.expected("equals sign or TO", self.peek_token())
OneOrManyWithParens::One(self.parse_object_name(false)?)
};

if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_));
let values = self.parse_set_values(parenthesized_assignment)?;

return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: modifier == Some(Keyword::HIVEVAR),
variables,
value: values,
});
}

if self.dialect.supports_set_stmt_without_operator() {
self.prev_token();
return self.parse_set_session_params();
};

self.expected("equals sign or TO", self.peek_token())
}

pub fn parse_set_session_params(&mut self) -> Result<Statement, ParserError> {
Expand Down
Loading