Skip to content

Enable dialect specific behaviours in the parser #254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod mysql;
mod postgresql;
mod sqlite;

use std::any::{Any, TypeId};
use std::fmt::Debug;

pub use self::ansi::AnsiDialect;
Expand All @@ -27,7 +28,15 @@ pub use self::mysql::MySqlDialect;
pub use self::postgresql::PostgreSqlDialect;
pub use self::sqlite::SQLiteDialect;

pub trait Dialect: Debug {
/// Tests whether the dialect held by a parser-like value is one of the
/// listed `Dialect` types.
///
/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` expands to an
/// `||` chain of `parser.dialect.is::<T>()` calls, one per listed type, so
/// it evaluates to `true` iff `parser.dialect` is one of the `Dialect`s named.
macro_rules! dialect_of {
    ( $holder:ident is $($candidate:ty)|+ ) => {
        ($($holder.dialect.is::<$candidate>())||+)
    };
}

pub trait Dialect: Debug + Any {
/// Determine if a character starts a quoted identifier. The default
/// implementation, accepting "double quoted" ids is both ANSI-compliant
/// and appropriate for most dialects (with the notable exception of
Expand All @@ -41,3 +50,51 @@ pub trait Dialect: Debug {
/// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool;
}

impl dyn Dialect {
    /// Returns `true` when the concrete type behind this trait object is `T`.
    ///
    /// The check compares the `TypeId` of `T` with the `TypeId` reported by
    /// the value itself, mirroring the approach of the standard library's
    /// `Any` implementation.
    #[inline]
    pub fn is<T: Dialect>(&self) -> bool {
        let wanted = TypeId::of::<T>();
        let actual = self.type_id();
        actual == wanted
    }
}

#[cfg(test)]
mod tests {
    use super::ansi::AnsiDialect;
    use super::generic::GenericDialect;
    use super::*;

    /// Minimal stand-in for a parser: `dialect_of!` only reads the
    /// `dialect` field of the value it is given.
    struct DialectHolder<'a> {
        dialect: &'a dyn Dialect,
    }

    /// Exercises `dialect_of!` with single and `|`-separated dialect lists,
    /// covering both matching and non-matching cases.
    #[test]
    fn test_is_dialect() {
        let generic_dialect: &dyn Dialect = &GenericDialect {};
        let ansi_dialect: &dyn Dialect = &AnsiDialect {};

        let generic_holder = DialectHolder {
            dialect: generic_dialect,
        };
        let ansi_holder = DialectHolder {
            dialect: ansi_dialect,
        };

        // A holder matches when its dialect appears anywhere in the list.
        assert!(dialect_of!(generic_holder is GenericDialect | AnsiDialect));
        assert!(!dialect_of!(generic_holder is AnsiDialect));

        assert!(dialect_of!(ansi_holder is AnsiDialect));
        assert!(dialect_of!(ansi_holder is GenericDialect | AnsiDialect));
        // None of the listed dialects is the holder's dialect.
        assert!(!dialect_of!(ansi_holder is GenericDialect | MsSqlDialect));
    }
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#![warn(clippy::all)]

pub mod ast;
#[macro_use]
pub mod dialect;
pub mod parser;
pub mod tokenizer;
Expand Down
29 changes: 18 additions & 11 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
use log::debug;

use super::ast::*;
use super::dialect::keywords;
use super::dialect::keywords::Keyword;
use super::dialect::Dialect;
use super::dialect::*;
use super::tokenizer::*;
use std::error::Error;
use std::fmt;
Expand Down Expand Up @@ -82,24 +81,28 @@ impl fmt::Display for ParserError {

impl Error for ParserError {}

/// SQL Parser
pub struct Parser {
pub struct Parser<'a> {
tokens: Vec<Token>,
/// The index of the first unprocessed token in `self.tokens`
index: usize,
dialect: &'a dyn Dialect,
}

impl Parser {
impl<'a> Parser<'a> {
/// Parse the specified tokens
pub fn new(tokens: Vec<Token>) -> Self {
Parser { tokens, index: 0 }
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
Parser {
tokens,
index: 0,
dialect,
}
}

/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
let mut tokenizer = Tokenizer::new(dialect, &sql);
let tokens = tokenizer.tokenize()?;
let mut parser = Parser::new(tokens);
let mut parser = Parser::new(tokens, dialect);
let mut stmts = Vec::new();
let mut expecting_statement_delimiter = false;
debug!("Parsing sql '{}'...", sql);
Expand Down Expand Up @@ -950,7 +953,7 @@ impl Parser {
/// Parse a comma-separated list of 1+ items accepted by `F`
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
where
F: FnMut(&mut Parser) -> Result<T, ParserError>,
F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
{
let mut values = vec![];
loop {
Expand Down Expand Up @@ -1285,10 +1288,14 @@ impl Parser {
let expr = self.parse_expr()?;
self.expect_token(&Token::RParen)?;
ColumnOption::Check(expr)
} else if self.parse_keyword(Keyword::AUTO_INCREMENT) {
} else if self.parse_keyword(Keyword::AUTO_INCREMENT)
&& dialect_of!(self is MySqlDialect | GenericDialect)
{
// Support AUTO_INCREMENT for MySQL
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])
} else if self.parse_keyword(Keyword::AUTOINCREMENT) {
} else if self.parse_keyword(Keyword::AUTOINCREMENT)
&& dialect_of!(self is SQLiteDialect | GenericDialect)
{
// Support AUTOINCREMENT for SQLite
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")])
} else {
Expand Down
6 changes: 4 additions & 2 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl TestedDialects {
self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize().unwrap();
f(&mut Parser::new(tokens))
f(&mut Parser::new(tokens, dialect))
})
}

Expand Down Expand Up @@ -104,7 +104,9 @@ impl TestedDialects {
/// Ensures that `sql` parses as an expression, and is not modified
/// after a serialization round-trip.
pub fn verified_expr(&self, sql: &str) -> Expr {
let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
let ast = self
.run_parser_method(sql, |parser| parser.parse_expr())
.unwrap();
assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
ast
}
Expand Down
7 changes: 4 additions & 3 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use matches::assert_matches;

use sqlparser::ast::*;
use sqlparser::dialect::keywords::ALL_KEYWORDS;
use sqlparser::parser::{Parser, ParserError};
use sqlparser::parser::ParserError;
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};

#[test]
Expand Down Expand Up @@ -147,13 +147,14 @@ fn parse_update() {

#[test]
fn parse_invalid_table_name() {
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
let ast = all_dialects()
.run_parser_method("db.public..customer", |parser| parser.parse_object_name());
assert!(ast.is_err());
}

#[test]
fn parse_no_table_name() {
let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
assert!(ast.is_err());
}

Expand Down