Skip to content

Commit 7cf83c5

Browse files
committed
Use dialects in the parser to support Snowflake's unique parenthesis syntax
Snowflake DB allows a single table to be within parentheses. This behaviour is different from other DBs, and it has some impact on parsing a table factor. To support it we do the following: 1. Add a reference to the dialect in the parser 2. Add a Snowflake dialect 3. Add a function to the dialect trait that identifies whether a single table inside parentheses is allowed 4. When parsing a table factor, allow/deny a single table inside parentheses according to the dialect
1 parent 4452f9b commit 7cf83c5

File tree

6 files changed

+312
-48
lines changed

6 files changed

+312
-48
lines changed

src/dialect/mod.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ pub mod keywords;
1616
mod mssql;
1717
mod mysql;
1818
mod postgresql;
19-
19+
mod snowflake;
2020
use std::fmt::Debug;
2121

2222
pub use self::ansi::AnsiDialect;
2323
pub use self::generic::GenericDialect;
2424
pub use self::mssql::MsSqlDialect;
2525
pub use self::mysql::MySqlDialect;
2626
pub use self::postgresql::PostgreSqlDialect;
27+
pub use self::snowflake::SnowflakeDialect;
2728

2829
pub trait Dialect: Debug {
2930
/// Determine if a character starts a quoted identifier. The default
@@ -38,4 +39,8 @@ pub trait Dialect: Debug {
3839
fn is_identifier_start(&self, ch: char) -> bool;
3940
/// Determine if a character is a valid unquoted identifier character
4041
fn is_identifier_part(&self, ch: char) -> bool;
42+
43+
fn alllow_single_table_in_parenthesis(&self) -> bool {
44+
false
45+
}
4146
}

src/dialect/snowflake.rs

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use crate::dialect::Dialect;
2+
3+
#[derive(Debug, Default)]
4+
pub struct SnowflakeDialect;
5+
6+
impl Dialect for SnowflakeDialect {
7+
// Revisit: currently copied from the Generic dialect
8+
fn is_identifier_start(&self, ch: char) -> bool {
9+
(ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' || ch == '#' || ch == '@'
10+
}
11+
12+
// Revisit: currently copied from the Generic dialect
13+
fn is_identifier_part(&self, ch: char) -> bool {
14+
(ch >= 'a' && ch <= 'z')
15+
|| (ch >= 'A' && ch <= 'Z')
16+
|| (ch >= '0' && ch <= '9')
17+
|| ch == '@'
18+
|| ch == '$'
19+
|| ch == '#'
20+
|| ch == '_'
21+
}
22+
23+
fn alllow_single_table_in_parenthesis(&self) -> bool {
24+
true
25+
}
26+
}

src/parser.rs

+114-9
Original file line numberDiff line numberDiff line change
@@ -83,23 +83,28 @@ impl fmt::Display for ParserError {
8383
impl Error for ParserError {}
8484

8585
/// SQL Parser
86-
pub struct Parser {
86+
pub struct Parser<'a> {
8787
tokens: Vec<Token>,
8888
/// The index of the first unprocessed token in `self.tokens`
8989
index: usize,
90+
dialect: &'a dyn Dialect,
9091
}
9192

92-
impl Parser {
93+
impl<'a> Parser<'a> {
9394
/// Parse the specified tokens
94-
pub fn new(tokens: Vec<Token>) -> Self {
95-
Parser { tokens, index: 0 }
95+
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
96+
Parser {
97+
tokens,
98+
index: 0,
99+
dialect,
100+
}
96101
}
97102

98103
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
99104
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
100105
let mut tokenizer = Tokenizer::new(dialect, &sql);
101106
let tokens = tokenizer.tokenize()?;
102-
let mut parser = Parser::new(tokens);
107+
let mut parser = Parser::new(tokens, dialect);
103108
let mut stmts = Vec::new();
104109
let mut expecting_statement_delimiter = false;
105110
debug!("Parsing sql '{}'...", sql);
@@ -950,7 +955,7 @@ impl Parser {
950955
/// Parse a comma-separated list of 1+ items accepted by `F`
951956
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
952957
where
953-
F: FnMut(&mut Parser) -> Result<T, ParserError>,
958+
F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
954959
{
955960
let mut values = vec![];
956961
loop {
@@ -2056,9 +2061,91 @@ impl Parser {
20562061
};
20572062
joins.push(join);
20582063
}
2064+
20592065
Ok(TableWithJoins { relation, joins })
20602066
}
20612067

2068+
fn add_alias_to_single_table_in_parenthesis(
2069+
&self,
2070+
table_facor: TableFactor,
2071+
consumed_alias: TableAlias,
2072+
) -> Result<TableFactor, ParserError> {
2073+
match table_facor {
2074+
// Add the alias to the derived table
2075+
TableFactor::Derived {
2076+
lateral,
2077+
subquery,
2078+
alias,
2079+
} => match alias {
2080+
None => Ok(TableFactor::Derived {
2081+
lateral,
2082+
subquery,
2083+
alias: Some(consumed_alias),
2084+
}),
2085+
// "Select * from (table1 as alias1) as alias1" - it is prohibited
2086+
Some(alias) => Err(ParserError::ParserError(format!(
2087+
"duplicate alias {}",
2088+
alias
2089+
))),
2090+
},
2091+
// Add The alias to the table
2092+
TableFactor::Table {
2093+
name,
2094+
alias,
2095+
args,
2096+
with_hints,
2097+
} => match alias {
2098+
None => Ok(TableFactor::Table {
2099+
name,
2100+
alias: Some(consumed_alias),
2101+
args,
2102+
with_hints,
2103+
}),
2104+
// "Select * from (table1 as alias1) as alias1" - it is prohibited
2105+
Some(alias) => Err(ParserError::ParserError(format!(
2106+
"duplicate alias {}",
2107+
alias
2108+
))),
2109+
},
2110+
TableFactor::NestedJoin(_) => Err(ParserError::ParserError(
2111+
"aliasing joins is not allowed".to_owned(),
2112+
)),
2113+
}
2114+
}
2115+
2116+
fn remove_redundent_parenthesis(
2117+
&mut self,
2118+
table_and_joins: TableWithJoins,
2119+
) -> Result<TableFactor, ParserError> {
2120+
let table_factor = table_and_joins.relation;
2121+
2122+
// check if we have an alias after the parenthesis
2123+
let alias = match self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)? {
2124+
None => {
2125+
return Ok(table_factor);
2126+
}
2127+
Some(alias) => alias,
2128+
};
2129+
2130+
// if we have an alias, we attach it to the single table that is inside the parentheses
2131+
self.add_alias_to_single_table_in_parenthesis(table_factor, alias)
2132+
}
2133+
2134+
fn validate_nested_join(&self, table_and_joins: &TableWithJoins) -> Result<(), ParserError> {
2135+
match table_and_joins.relation {
2136+
TableFactor::NestedJoin { .. } => (),
2137+
_ => {
2138+
if table_and_joins.joins.is_empty() {
2139+
// validate that it is indeed a join and not a derived
2140+
// or nested table
2141+
self.expected("joined table", self.peek_token())?
2142+
}
2143+
}
2144+
}
2145+
2146+
Ok(())
2147+
}
2148+
20622149
/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
20632150
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
20642151
if self.parse_keyword(Keyword::LATERAL) {
@@ -2102,10 +2189,28 @@ impl Parser {
21022189
// followed by some joins or another level of nesting.
21032190
let table_and_joins = self.parse_table_and_joins()?;
21042191
self.expect_token(&Token::RParen)?;
2192+
21052193
// The SQL spec prohibits derived and bare tables from appearing
2106-
// alone in parentheses. We don't enforce this as some databases
2107-
// (e.g. Snowflake) allow such syntax.
2108-
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
2194+
// alone in parentheses. But as some databases
2195+
// (e.g. Snowflake) allow such syntax - it can be allowed
2196+
// for a specific dialect.
2197+
if self.dialect.alllow_single_table_in_parenthesis() {
2198+
if table_and_joins.joins.is_empty() {
2199+
// For DBs like Snowflake that allow a single derived or bare
2200+
// table in parentheses (for example: `Select * from (a) as b`),
2201+
// the parser will parse it as a nested join, but if it's actually a single table
2202+
// we don't want to treat such a case as a join, because we don't actually join
2203+
// any tables.
2204+
let table_factor = self.remove_redundent_parenthesis(table_and_joins)?;
2205+
Ok(table_factor)
2206+
} else {
2207+
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
2208+
}
2209+
} else {
2210+
// Default behaviour
2211+
self.validate_nested_join(&table_and_joins)?;
2212+
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
2213+
}
21092214
} else {
21102215
let name = self.parse_object_name()?;
21112216
// Postgres, MSSQL: table-valued functions:

src/test_utils.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl TestedDialects {
5353
self.one_of_identical_results(|dialect| {
5454
let mut tokenizer = Tokenizer::new(dialect, sql);
5555
let tokens = tokenizer.tokenize().unwrap();
56-
f(&mut Parser::new(tokens))
56+
f(&mut Parser::new(tokens, dialect))
5757
})
5858
}
5959

@@ -104,7 +104,9 @@ impl TestedDialects {
104104
/// Ensures that `sql` parses as an expression, and is not modified
105105
/// after a serialization round-trip.
106106
pub fn verified_expr(&self, sql: &str) -> Expr {
107-
let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
107+
let ast = self
108+
.run_parser_method(sql, |parser| parser.parse_expr())
109+
.unwrap();
108110
assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
109111
ast
110112
}

tests/sqlparser_common.rs

+10-36
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use matches::assert_matches;
2222

2323
use sqlparser::ast::*;
2424
use sqlparser::dialect::keywords::ALL_KEYWORDS;
25-
use sqlparser::parser::{Parser, ParserError};
25+
use sqlparser::parser::ParserError;
2626
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
2727

2828
#[test]
@@ -147,13 +147,14 @@ fn parse_update() {
147147

148148
#[test]
149149
fn parse_invalid_table_name() {
150-
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
150+
let ast = all_dialects()
151+
.run_parser_method("db.public..customer", |parser| parser.parse_object_name());
151152
assert!(ast.is_err());
152153
}
153154

154155
#[test]
155156
fn parse_no_table_name() {
156-
let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
157+
let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
157158
assert!(ast.is_err());
158159
}
159160

@@ -2273,19 +2274,12 @@ fn parse_join_nesting() {
22732274
vec![join(nest!(nest!(nest!(table("b"), table("c")))))]
22742275
);
22752276

2276-
// Parenthesized table names are non-standard, but supported in Snowflake SQL
2277-
let sql = "SELECT * FROM (a NATURAL JOIN (b))";
2278-
let select = verified_only_select(sql);
2279-
let from = only(select.from);
2280-
2281-
assert_eq!(from.relation, nest!(table("a"), nest!(table("b"))));
2282-
2283-
// Double parentheses around table names are non-standard, but supported in Snowflake SQL
2284-
let sql = "SELECT * FROM (a NATURAL JOIN ((b)))";
2285-
let select = verified_only_select(sql);
2286-
let from = only(select.from);
2287-
2288-
assert_eq!(from.relation, nest!(table("a"), nest!(nest!(table("b")))));
2277+
// Nesting a subquery in parentheses is non-standard, but supported in Snowflake SQL
2278+
let res = parse_sql_statements("SELECT * FROM ((SELECT 1) AS t)");
2279+
assert_eq!(
2280+
ParserError::ParserError("Expected joined table, found: EOF".to_string()),
2281+
res.unwrap_err()
2282+
);
22892283
}
22902284

22912285
#[test]
@@ -2427,26 +2421,6 @@ fn parse_derived_tables() {
24272421
}],
24282422
}))
24292423
);
2430-
2431-
// Nesting a subquery in parentheses is non-standard, but supported in Snowflake SQL
2432-
let sql = "SELECT * FROM ((SELECT 1) AS t)";
2433-
let select = verified_only_select(sql);
2434-
let from = only(select.from);
2435-
2436-
assert_eq!(
2437-
from.relation,
2438-
TableFactor::NestedJoin(Box::new(TableWithJoins {
2439-
relation: TableFactor::Derived {
2440-
lateral: false,
2441-
subquery: Box::new(verified_query("SELECT 1")),
2442-
alias: Some(TableAlias {
2443-
name: "t".into(),
2444-
columns: vec![],
2445-
})
2446-
},
2447-
joins: Vec::new(),
2448-
}))
2449-
);
24502424
}
24512425

24522426
#[test]

0 commit comments

Comments
 (0)