
Commit 1b46e82

eyalleshem, c7hm4r, and nickolay authored
Enable dialect specific behaviours in the parser (#254)
* Change `Parser { ... }` to store the dialect used: `Parser<'a> { ... dialect: &'a dyn Dialect }`. Thanks to @c7hm4r for the initial version of this, submitted as part of #170
* Introduce a `dialect_of!(parser is SQLiteDialect | GenericDialect)` helper to branch on the dialect's type
* Use the new functionality to make `AUTO_INCREMENT` and `AUTOINCREMENT` parsing dialect-dependent

Co-authored-by: Christoph Müller <[email protected]>
Co-authored-by: Nickolay Ponomarev <[email protected]>
1 parent 3871bbc commit 1b46e82
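
With this change, the dialect travels with the parser itself rather than only with the tokenizer. A minimal sketch of how calling code exercises the two entry points after this commit (the `GenericDialect` choice and the SELECT/expression inputs are illustrative, not taken from the commit):

    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;
    use sqlparser::tokenizer::Tokenizer;

    fn main() {
        let dialect = GenericDialect {};

        // Convenience path: tokenize and parse a full statement in one call.
        let stmts = Parser::parse_sql(&dialect, "SELECT a FROM t WHERE a > 1").expect("parse failed");
        println!("{:#?}", stmts);

        // Manual path: Parser::new now takes the dialect alongside the tokens,
        // so dialect-specific checks stay available throughout parsing.
        let mut tokenizer = Tokenizer::new(&dialect, "a > 1");
        let tokens = tokenizer.tokenize().expect("tokenize failed");
        let mut parser = Parser::new(tokens, &dialect);
        println!("{:#?}", parser.parse_expr());
    }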

File tree

5 files changed: +85 -17 lines changed

* src/dialect/mod.rs
* src/lib.rs
* src/parser.rs
* src/test_utils.rs
* tests/sqlparser_common.rs

src/dialect/mod.rs (+58 -1)

@@ -18,6 +18,7 @@ mod mysql;
 mod postgresql;
 mod sqlite;
 
+use std::any::{Any, TypeId};
 use std::fmt::Debug;
 
 pub use self::ansi::AnsiDialect;
@@ -27,7 +28,15 @@ pub use self::mysql::MySqlDialect;
 pub use self::postgresql::PostgreSqlDialect;
 pub use self::sqlite::SQLiteDialect;
 
-pub trait Dialect: Debug {
+/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
+/// to `true` iff `parser.dialect` is one of the `Dialect`s specified.
+macro_rules! dialect_of {
+    ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
+        ($($parsed_dialect.dialect.is::<$dialect_type>())||+)
+    };
+}
+
+pub trait Dialect: Debug + Any {
     /// Determine if a character starts a quoted identifier. The default
     /// implementation, accepting "double quoted" ids is both ANSI-compliant
     /// and appropriate for most dialects (with the notable exception of
@@ -41,3 +50,51 @@ pub trait Dialect: Debug {
     /// Determine if a character is a valid unquoted identifier character
     fn is_identifier_part(&self, ch: char) -> bool;
 }
+
+impl dyn Dialect {
+    #[inline]
+    pub fn is<T: Dialect>(&self) -> bool {
+        // borrowed from `Any` implementation
+        TypeId::of::<T>() == self.type_id()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::ansi::AnsiDialect;
+    use super::generic::GenericDialect;
+    use super::*;
+
+    struct DialectHolder<'a> {
+        dialect: &'a dyn Dialect,
+    }
+
+    #[test]
+    fn test_is_dialect() {
+        let generic_dialect: &dyn Dialect = &GenericDialect {};
+        let ansi_dialect: &dyn Dialect = &AnsiDialect {};
+
+        let generic_holder = DialectHolder {
+            dialect: generic_dialect,
+        };
+        let ansi_holder = DialectHolder {
+            dialect: ansi_dialect,
+        };
+
+        assert_eq!(
+            dialect_of!(generic_holder is GenericDialect | AnsiDialect),
+            true
+        );
+        assert_eq!(dialect_of!(generic_holder is AnsiDialect), false);
+
+        assert_eq!(dialect_of!(ansi_holder is AnsiDialect), true);
+        assert_eq!(
+            dialect_of!(ansi_holder is GenericDialect | AnsiDialect),
+            true
+        );
+        assert_eq!(
+            dialect_of!(ansi_holder is GenericDialect | MsSqlDialect),
+            false
+        );
+    }
+}
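
Note that `dialect_of!` is a crate-internal macro (it is not `#[macro_export]`ed), so downstream code would use the underlying `is::<T>()` check on `dyn Dialect` directly; the macro invocation simply expands to a `||`-chain of such checks on the holder's `dialect` field. A small sketch of that expansion, where `Holder` is a hypothetical stand-in for `Parser<'a>`:

    use sqlparser::dialect::{Dialect, GenericDialect, SQLiteDialect};

    // Hypothetical struct mirroring the `dialect` field added to Parser<'a>.
    struct Holder<'a> {
        dialect: &'a dyn Dialect,
    }

    fn main() {
        let holder = Holder { dialect: &SQLiteDialect {} };
        // dialect_of!(holder is SQLiteDialect | GenericDialect) expands to:
        let matched = holder.dialect.is::<SQLiteDialect>() || holder.dialect.is::<GenericDialect>();
        assert!(matched);
    }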

src/lib.rs (+1)

@@ -35,6 +35,7 @@
 #![warn(clippy::all)]
 
 pub mod ast;
+#[macro_use]
 pub mod dialect;
 pub mod parser;
 pub mod tokenizer;

src/parser.rs (+18 -11)

@@ -15,9 +15,8 @@
 use log::debug;
 
 use super::ast::*;
-use super::dialect::keywords;
 use super::dialect::keywords::Keyword;
-use super::dialect::Dialect;
+use super::dialect::*;
 use super::tokenizer::*;
 use std::error::Error;
 use std::fmt;
@@ -82,24 +81,28 @@ impl fmt::Display for ParserError {
 
 impl Error for ParserError {}
 
-/// SQL Parser
-pub struct Parser {
+pub struct Parser<'a> {
     tokens: Vec<Token>,
     /// The index of the first unprocessed token in `self.tokens`
     index: usize,
+    dialect: &'a dyn Dialect,
 }
 
-impl Parser {
+impl<'a> Parser<'a> {
     /// Parse the specified tokens
-    pub fn new(tokens: Vec<Token>) -> Self {
-        Parser { tokens, index: 0 }
+    pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
+        Parser {
+            tokens,
+            index: 0,
+            dialect,
+        }
     }
 
     /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
     pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
         let mut tokenizer = Tokenizer::new(dialect, &sql);
         let tokens = tokenizer.tokenize()?;
-        let mut parser = Parser::new(tokens);
+        let mut parser = Parser::new(tokens, dialect);
         let mut stmts = Vec::new();
         let mut expecting_statement_delimiter = false;
         debug!("Parsing sql '{}'...", sql);
@@ -950,7 +953,7 @@ impl Parser {
     /// Parse a comma-separated list of 1+ items accepted by `F`
     pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
     where
-        F: FnMut(&mut Parser) -> Result<T, ParserError>,
+        F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
     {
         let mut values = vec![];
         loop {
@@ -1285,10 +1288,14 @@ impl Parser {
             let expr = self.parse_expr()?;
             self.expect_token(&Token::RParen)?;
             ColumnOption::Check(expr)
-        } else if self.parse_keyword(Keyword::AUTO_INCREMENT) {
+        } else if self.parse_keyword(Keyword::AUTO_INCREMENT)
+            && dialect_of!(self is MySqlDialect | GenericDialect)
+        {
             // Support AUTO_INCREMENT for MySQL
             ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])
-        } else if self.parse_keyword(Keyword::AUTOINCREMENT) {
+        } else if self.parse_keyword(Keyword::AUTOINCREMENT)
+            && dialect_of!(self is SQLiteDialect | GenericDialect)
+        {
             // Support AUTOINCREMENT for SQLite
             ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")])
         } else {
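
The practical effect of the new guards: `AUTO_INCREMENT` and `AUTOINCREMENT` are only recognized as column options under the dialects that opt in, and are preserved verbatim as `ColumnOption::DialectSpecific` tokens. A minimal sketch of the accepted case (the CREATE TABLE input is illustrative, not a test from this commit):

    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    fn main() {
        // SQLiteDialect (and GenericDialect) opt in to SQLite-style AUTOINCREMENT;
        // MySqlDialect (and GenericDialect) opt in to MySQL-style AUTO_INCREMENT.
        let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTOINCREMENT)";
        let stmts = Parser::parse_sql(&SQLiteDialect {}, sql).expect("parse failed");
        // The keyword comes back in the AST as
        // ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")]).
        println!("{:#?}", stmts);
    }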

src/test_utils.rs (+4 -2)

@@ -53,7 +53,7 @@ impl TestedDialects {
         self.one_of_identical_results(|dialect| {
             let mut tokenizer = Tokenizer::new(dialect, sql);
             let tokens = tokenizer.tokenize().unwrap();
-            f(&mut Parser::new(tokens))
+            f(&mut Parser::new(tokens, dialect))
         })
     }
 
@@ -104,7 +104,9 @@ impl TestedDialects {
     /// Ensures that `sql` parses as an expression, and is not modified
     /// after a serialization round-trip.
     pub fn verified_expr(&self, sql: &str) -> Expr {
-        let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
+        let ast = self
+            .run_parser_method(sql, |parser| parser.parse_expr())
+            .unwrap();
         assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
         ast
     }

tests/sqlparser_common.rs (+4 -3)

@@ -22,7 +22,7 @@ use matches::assert_matches;
 
 use sqlparser::ast::*;
 use sqlparser::dialect::keywords::ALL_KEYWORDS;
-use sqlparser::parser::{Parser, ParserError};
+use sqlparser::parser::ParserError;
 use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
 
 #[test]
@@ -147,13 +147,14 @@ fn parse_update() {
 
 #[test]
 fn parse_invalid_table_name() {
-    let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
+    let ast = all_dialects()
+        .run_parser_method("db.public..customer", |parser| parser.parse_object_name());
     assert!(ast.is_err());
 }
 
 #[test]
 fn parse_no_table_name() {
-    let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
+    let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
     assert!(ast.is_err());
 }
