Skip to content

Commit 616ab19

Browse files
committed
snowflake: Fix handling of @~% in the stage name
1 parent 83cb734 commit 616ab19

File tree

3 files changed

+46
-13
lines changed

3 files changed

+46
-13
lines changed

src/dialect/snowflake.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub struct SnowflakeDialect;
3434
impl Dialect for SnowflakeDialect {
3535
// see https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html
3636
fn is_identifier_start(&self, ch: char) -> bool {
37-
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' || ch == '@' || ch == '%'
37+
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'
3838
}
3939

4040
fn is_identifier_part(&self, ch: char) -> bool {
@@ -43,8 +43,6 @@ impl Dialect for SnowflakeDialect {
4343
|| ch.is_ascii_digit()
4444
|| ch == '$'
4545
|| ch == '_'
46-
|| ch == '/'
47-
|| ch == '~'
4846
}
4947

5048
fn supports_within_after_array_aggregation(&self) -> bool {

src/tokenizer.rs

+25
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,18 @@ impl<'a> Tokenizer<'a> {
10041004
}
10051005
}
10061006
Some(' ') => Ok(Some(Token::AtSign)),
1007+
// Snowflake stage identifier, this should be consumed as multiple dot separated word tokens
1008+
Some(_) if dialect_of!(self is SnowflakeDialect) => {
1009+
let mut s = "@".to_string();
1010+
s.push_str(&peeking_take_while(chars, |ch| {
1011+
self.dialect.is_identifier_part(ch)
1012+
|| ch == '/'
1013+
|| ch == '~'
1014+
|| ch == '%'
1015+
|| ch == '.'
1016+
}));
1017+
Ok(Some(Token::make_word(&s, None)))
1018+
}
10071019
Some(sch) if self.dialect.is_identifier_start('@') => {
10081020
self.tokenize_identifier_or_keyword([ch, *sch], chars)
10091021
}
@@ -2001,6 +2013,19 @@ mod tests {
20012013
compare(expected, tokens);
20022014
}
20032015

2016+
#[test]
2017+
fn tokenize_snowflake_div() {
2018+
let sql = r#"field/1000"#;
2019+
let dialect = SnowflakeDialect {};
2020+
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
2021+
let expected = vec![
2022+
Token::make_word(r#"field"#, None),
2023+
Token::Div,
2024+
Token::Number("1000".to_string(), false),
2025+
];
2026+
compare(expected, tokens);
2027+
}
2028+
20042029
#[test]
20052030
fn tokenize_quoted_identifier_with_no_escape() {
20062031
let sql = r#" "a "" b" "a """ "c """"" "#;

tests/sqlparser_snowflake.rs

+20-10
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ use test_utils::*;
2626
#[macro_use]
2727
mod test_utils;
2828

29+
#[cfg(test)]
30+
use pretty_assertions::assert_eq;
31+
2932
#[test]
3033
fn test_snowflake_create_table() {
3134
let sql = "CREATE TABLE _my_$table (am00unt number)";
@@ -903,7 +906,7 @@ fn test_copy_into_with_transformations() {
903906
} => {
904907
assert_eq!(
905908
from_stage,
906-
ObjectName(vec![Ident::new("@schema"), Ident::new("general_finished")])
909+
ObjectName(vec![Ident::new("@schema.general_finished")])
907910
);
908911
assert_eq!(
909912
from_transformations.as_ref().unwrap()[0],
@@ -1010,15 +1013,9 @@ fn test_snowflake_stage_object_names() {
10101013
];
10111014
let mut allowed_object_names = vec![
10121015
ObjectName(vec![Ident::new("my_company"), Ident::new("emp_basic")]),
1013-
ObjectName(vec![Ident::new("@namespace"), Ident::new("%table_name")]),
1014-
ObjectName(vec![
1015-
Ident::new("@namespace"),
1016-
Ident::new("%table_name/path"),
1017-
]),
1018-
ObjectName(vec![
1019-
Ident::new("@namespace"),
1020-
Ident::new("stage_name/path"),
1021-
]),
1016+
ObjectName(vec![Ident::new("@namespace.%table_name")]),
1017+
ObjectName(vec![Ident::new("@namespace.%table_name/path")]),
1018+
ObjectName(vec![Ident::new("@namespace.stage_name/path")]),
10221019
ObjectName(vec![Ident::new("@~/path")]),
10231020
];
10241021

@@ -1064,3 +1061,16 @@ fn test_snowflake_trim() {
10641061
snowflake().parse_sql_statements(error_sql).unwrap_err()
10651062
);
10661063
}
1064+
1065+
#[test]
1066+
fn parse_division_correctly() {
1067+
snowflake_and_generic().one_statement_parses_to(
1068+
"SELECT field/1000 FROM tbl1",
1069+
"SELECT field / 1000 FROM tbl1",
1070+
);
1071+
1072+
snowflake_and_generic().one_statement_parses_to(
1073+
"SELECT tbl1.field/tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
1074+
"SELECT tbl1.field / tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
1075+
);
1076+
}

0 commit comments

Comments
 (0)