Skip to content

Commit b256730

Browse files
lustefaniakserprex
authored andcommitted
snowflake: Fix handling of /~% in the stage name (apache#1009)
1 parent 47bd477 commit b256730

File tree

3 files changed

+72
-5
lines changed

3 files changed

+72
-5
lines changed

src/dialect/snowflake.rs

+43-5
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub struct SnowflakeDialect;
3535
impl Dialect for SnowflakeDialect {
3636
// see https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html
3737
fn is_identifier_start(&self, ch: char) -> bool {
38-
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' || ch == '@' || ch == '%'
38+
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'
3939
}
4040

4141
fn is_identifier_part(&self, ch: char) -> bool {
@@ -44,8 +44,6 @@ impl Dialect for SnowflakeDialect {
4444
|| ch.is_ascii_digit()
4545
|| ch == '$'
4646
|| ch == '_'
47-
|| ch == '/'
48-
|| ch == '~'
4947
}
5048

5149
fn supports_within_after_array_aggregation(&self) -> bool {
@@ -148,8 +146,48 @@ pub fn parse_create_stage(
148146
})
149147
}
150148

149+
pub fn parse_stage_name_identifier(parser: &mut Parser) -> Result<Ident, ParserError> {
150+
let mut ident = String::new();
151+
while let Some(next_token) = parser.next_token_no_skip() {
152+
match &next_token.token {
153+
Token::Whitespace(_) => break,
154+
Token::Period => {
155+
parser.prev_token();
156+
break;
157+
}
158+
Token::AtSign => ident.push('@'),
159+
Token::Tilde => ident.push('~'),
160+
Token::Mod => ident.push('%'),
161+
Token::Div => ident.push('/'),
162+
Token::Word(w) => ident.push_str(&w.value),
163+
_ => return parser.expected("stage name identifier", parser.peek_token()),
164+
}
165+
}
166+
Ok(Ident::new(ident))
167+
}
168+
169+
pub fn parse_snowflake_stage_name(parser: &mut Parser) -> Result<ObjectName, ParserError> {
170+
match parser.next_token().token {
171+
Token::AtSign => {
172+
parser.prev_token();
173+
let mut idents = vec![];
174+
loop {
175+
idents.push(parse_stage_name_identifier(parser)?);
176+
if !parser.consume_token(&Token::Period) {
177+
break;
178+
}
179+
}
180+
Ok(ObjectName(idents))
181+
}
182+
_ => {
183+
parser.prev_token();
184+
Ok(parser.parse_object_name()?)
185+
}
186+
}
187+
}
188+
151189
pub fn parse_copy_into(parser: &mut Parser) -> Result<Statement, ParserError> {
152-
let into: ObjectName = parser.parse_object_name()?;
190+
let into: ObjectName = parse_snowflake_stage_name(parser)?;
153191
let mut files: Vec<String> = vec![];
154192
let mut from_transformations: Option<Vec<StageLoadSelectItem>> = None;
155193
let from_stage_alias;
@@ -165,7 +203,7 @@ pub fn parse_copy_into(parser: &mut Parser) -> Result<Statement, ParserError> {
165203
from_transformations = parse_select_items_for_data_load(parser)?;
166204

167205
parser.expect_keyword(Keyword::FROM)?;
168-
from_stage = parser.parse_object_name()?;
206+
from_stage = parse_snowflake_stage_name(parser)?;
169207
stage_params = parse_stage_params(parser)?;
170208

171209
// as

src/tokenizer.rs

+13
Original file line numberDiff line numberDiff line change
@@ -2001,6 +2001,19 @@ mod tests {
20012001
compare(expected, tokens);
20022002
}
20032003

2004+
#[test]
2005+
fn tokenize_snowflake_div() {
2006+
let sql = r#"field/1000"#;
2007+
let dialect = SnowflakeDialect {};
2008+
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
2009+
let expected = vec![
2010+
Token::make_word(r#"field"#, None),
2011+
Token::Div,
2012+
Token::Number("1000".to_string(), false),
2013+
];
2014+
compare(expected, tokens);
2015+
}
2016+
20042017
#[test]
20052018
fn tokenize_quoted_identifier_with_no_escape() {
20062019
let sql = r#" "a "" b" "a """ "c """"" "#;

tests/sqlparser_snowflake.rs

+16
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ use test_utils::*;
2626
#[macro_use]
2727
mod test_utils;
2828

29+
#[cfg(test)]
30+
use pretty_assertions::assert_eq;
31+
2932
#[test]
3033
fn test_snowflake_create_table() {
3134
let sql = "CREATE TABLE _my_$table (am00unt number)";
@@ -1118,3 +1121,16 @@ fn parse_subquery_function_argument() {
11181121
// the function.
11191122
snowflake().one_statement_parses_to("SELECT func(SELECT 1, 2)", "SELECT func((SELECT 1, 2))");
11201123
}
1124+
1125+
#[test]
1126+
fn parse_division_correctly() {
1127+
snowflake_and_generic().one_statement_parses_to(
1128+
"SELECT field/1000 FROM tbl1",
1129+
"SELECT field / 1000 FROM tbl1",
1130+
);
1131+
1132+
snowflake_and_generic().one_statement_parses_to(
1133+
"SELECT tbl1.field/tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
1134+
"SELECT tbl1.field / tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
1135+
);
1136+
}

0 commit comments

Comments
 (0)