Skip to content

Commit 51a2627

Browse files
committed
1 parent f500a42 commit 51a2627

File tree

5 files changed

+95
-0
lines changed

5 files changed

+95
-0
lines changed

src/ast/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ pub enum Expr {
235235
Subquery(Box<Query>),
236236
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
237237
ListAgg(ListAgg),
238+
/// Embed variable inside a query is supported by some databases:
239+
/// - Mysql: https://dev.mysql.com/doc/refman/8.0/en/user-variables.html
240+
/// - Snowflake: https://docs.snowflake.com/en/sql-reference/session-variables.html
241+
SqlVariable { prefix: char, name: Ident },
238242
}
239243

240244
impl fmt::Display for Expr {
@@ -315,6 +319,7 @@ impl fmt::Display for Expr {
315319
Expr::Exists(s) => write!(f, "EXISTS ({})", s),
316320
Expr::Subquery(s) => write!(f, "({})", s),
317321
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
322+
Expr::SqlVariable { prefix, name } => write!(f, "{}{}", prefix, name),
318323
}
319324
}
320325
}

src/parser.rs

+10
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,16 @@ impl<'a> Parser<'a> {
312312
self.expect_token(&Token::RParen)?;
313313
Ok(expr)
314314
}
315+
Token::Dollar if dialect_of!(self is SnowflakeDialect) => {
316+
// Snowflake user defined variables starts with $
317+
let name = self.parse_identifier()?;
318+
Ok(Expr::SqlVariable { prefix: '$', name })
319+
}
320+
Token::At if dialect_of!(self is MySqlDialect) => {
321+
// Mysql user defined variables starts with @
322+
let name = self.parse_identifier()?;
323+
Ok(Expr::SqlVariable { prefix: '@', name })
324+
}
315325
unexpected => self.expected("an expression", unexpected),
316326
}?;
317327

src/tokenizer.rs

+8
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ pub enum Token {
101101
RBrace,
102102
/// Right Arrow `=>`
103103
RArrow,
104+
/// Dollar sign `$`
105+
Dollar,
106+
/// At sign `@`
107+
At,
104108
}
105109

106110
impl fmt::Display for Token {
@@ -142,6 +146,8 @@ impl fmt::Display for Token {
142146
Token::LBrace => f.write_str("{"),
143147
Token::RBrace => f.write_str("}"),
144148
Token::RArrow => f.write_str("=>"),
149+
Token::Dollar => f.write_str("$"),
150+
Token::At => f.write_str("@"),
145151
}
146152
}
147153
}
@@ -448,6 +454,8 @@ impl<'a> Tokenizer<'a> {
448454
'^' => self.consume_and_return(chars, Token::Caret),
449455
'{' => self.consume_and_return(chars, Token::LBrace),
450456
'}' => self.consume_and_return(chars, Token::RBrace),
457+
'$' => self.consume_and_return(chars, Token::Dollar),
458+
'@' => self.consume_and_return(chars, Token::At),
451459
other => self.consume_and_return(chars, Token::Char(other)),
452460
},
453461
None => Ok(None),

tests/sqlparser_mysql.rs

+33
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,39 @@ fn parse_quote_identifiers() {
152152
}
153153
}
154154

155+
#[test]
156+
fn test_query_with_variable_name() {
157+
let sql = "SELECT @var1";
158+
let select = mysql().verified_only_select(sql);
159+
160+
assert_eq!(
161+
only(select.projection),
162+
SelectItem::UnnamedExpr(Expr::SqlVariable {
163+
prefix: '@',
164+
name: Ident::new("var1")
165+
},)
166+
);
167+
168+
let sql = "SELECT c1 FROM t1 WHERE num BETWEEN @min AND @max";
169+
let select = mysql().verified_only_select(sql);
170+
171+
assert_eq!(
172+
select.selection.unwrap(),
173+
Expr::Between {
174+
expr: Box::new(Expr::Identifier("num".into())),
175+
low: Box::new(Expr::SqlVariable {
176+
prefix: '@',
177+
name: Ident::new("min")
178+
}),
179+
high: Box::new(Expr::SqlVariable {
180+
prefix: '@',
181+
name: Ident::new("max")
182+
}),
183+
negated: false,
184+
}
185+
);
186+
}
187+
155188
fn mysql() -> TestedDialects {
156189
TestedDialects {
157190
dialects: vec![Box::new(MySqlDialect {})],

tests/sqlparser_snowflake.rs

+39
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,45 @@ fn test_snowflake_create_table() {
2424
}
2525
}
2626

27+
#[test]
28+
fn test_query_with_variable_name() {
29+
let sql = "SELECT $var1";
30+
let select = snowflake().verified_only_select(sql);
31+
32+
assert_eq!(
33+
only(select.projection),
34+
SelectItem::UnnamedExpr(Expr::SqlVariable {
35+
prefix: '$',
36+
name: Ident::new("var1")
37+
},)
38+
);
39+
40+
let sql = "SELECT c1 FROM t1 WHERE num BETWEEN $min AND $max";
41+
let select = snowflake().verified_only_select(sql);
42+
43+
assert_eq!(
44+
select.selection.unwrap(),
45+
Expr::Between {
46+
expr: Box::new(Expr::Identifier("num".into())),
47+
low: Box::new(Expr::SqlVariable {
48+
prefix: '$',
49+
name: Ident::new("min")
50+
}),
51+
high: Box::new(Expr::SqlVariable {
52+
prefix: '$',
53+
name: Ident::new("max")
54+
}),
55+
negated: false,
56+
}
57+
);
58+
}
59+
60+
fn snowflake() -> TestedDialects {
61+
TestedDialects {
62+
dialects: vec![Box::new(SnowflakeDialect {})],
63+
}
64+
}
65+
2766
fn snowflake_and_generic() -> TestedDialects {
2867
TestedDialects {
2968
dialects: vec![Box::new(SnowflakeDialect {}), Box::new(GenericDialect {})],

0 commit comments

Comments
 (0)