diff --git a/examples/cli.rs b/examples/cli.rs
index 38b3de841..eb0cbde9b 100644
--- a/examples/cli.rs
+++ b/examples/cli.rs
@@ -36,18 +36,24 @@ $ cargo run --features json_example --example cli FILENAME.sql [--dialectname]
 "#,
     );

-    let dialect: Box<dyn Dialect> = match std::env::args().nth(2).unwrap_or_default().as_ref() {
-        "--ansi" => Box::new(AnsiDialect {}),
-        "--postgres" => Box::new(PostgreSqlDialect {}),
-        "--ms" => Box::new(MsSqlDialect {}),
-        "--mysql" => Box::new(MySqlDialect {}),
-        "--snowflake" => Box::new(SnowflakeDialect {}),
-        "--hive" => Box::new(HiveDialect {}),
-        "--generic" | "" => Box::new(GenericDialect {}),
+    match std::env::args().nth(2).unwrap_or_default().as_ref() {
+        "--ansi" => parse::<AnsiDialect>(filename),
+        "--postgres" => parse::<PostgreSqlDialect>(filename),
+        "--ms" => parse::<MsSqlDialect>(filename),
+        "--mysql" => parse::<MySqlDialect>(filename),
+        "--snowflake" => parse::<SnowflakeDialect>(filename),
+        "--hive" => parse::<HiveDialect>(filename),
+        "--generic" | "" => parse::<GenericDialect>(filename),
         s => panic!("Unexpected parameter: {}", s),
     };
+}

-    println!("Parsing from file '{}' using {:?}", &filename, dialect);
+fn parse<D: Dialect>(filename: String) {
+    println!(
+        "Parsing from file '{}' using {:?}",
+        &filename,
+        std::any::type_name::<D>()
+    );
     let contents = fs::read_to_string(&filename)
         .unwrap_or_else(|_| panic!("Unable to read the file {}", &filename));
     let without_bom = if contents.chars().next().unwrap() as u64 != 0xfeff {
@@ -57,7 +63,7 @@ $ cargo run --features json_example --example cli FILENAME.sql [--dialectname]
         chars.next();
         chars.as_str()
     };
-    let parse_result = Parser::parse_sql(&*dialect, without_bom);
+    let parse_result = Parser::<D>::parse_sql(without_bom);
     match parse_result {
         Ok(statements) => {
             println!(
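The CLI can no longer choose a dialect through a `Box<dyn Dialect>` at runtime, so each flag arm now monomorphizes its own copy of `parse::<D>`. A minimal, self-contained sketch of that dispatch pattern; the `ADialect`/`BDialect` types here are illustrative stand-ins, not the crate's dialects:

```rust
use std::any::type_name;

trait Dialect {}

struct ADialect;
struct BDialect;
impl Dialect for ADialect {}
impl Dialect for BDialect {}

// One generic function; the compiler emits a separate monomorphized copy
// for every dialect type it is called with.
fn parse<D: Dialect>(input: &str) {
    println!("parsing {:?} with {}", input, type_name::<D>());
}

fn main() {
    // A runtime flag selects a compile-time type via an explicit match arm.
    match std::env::args().nth(1).unwrap_or_default().as_str() {
        "--a" => parse::<ADialect>("SELECT 1"),
        _ => parse::<BDialect>("SELECT 1"),
    }
}
```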
diff --git a/examples/parse_select.rs b/examples/parse_select.rs
index e7aa16307..85134a8a9 100644
--- a/examples/parse_select.rs
+++ b/examples/parse_select.rs
@@ -21,9 +21,7 @@ fn main() {
                WHERE a > b AND b < 100 \
                ORDER BY a DESC, b";

-    let dialect = GenericDialect {};
-
-    let ast = Parser::parse_sql(&dialect, sql).unwrap();
+    let ast = Parser::<GenericDialect>::parse_sql(sql).unwrap();

     println!("AST: {:?}", ast);
 }
diff --git a/src/dialect/ansi.rs b/src/dialect/ansi.rs
index 1015ca2d3..5752175dc 100644
--- a/src/dialect/ansi.rs
+++ b/src/dialect/ansi.rs
@@ -16,11 +16,11 @@ use crate::dialect::Dialect;
 pub struct AnsiDialect {}

 impl Dialect for AnsiDialect {
-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch)
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
+    fn is_identifier_part(ch: char) -> bool {
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
             || ('0'..='9').contains(&ch)
diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs
index 24ec5e49f..074914137 100644
--- a/src/dialect/clickhouse.rs
+++ b/src/dialect/clickhouse.rs
@@ -16,12 +16,12 @@ use crate::dialect::Dialect;
 pub struct ClickHouseDialect {}

 impl Dialect for ClickHouseDialect {
-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         // See https://clickhouse.com/docs/en/sql-reference/syntax/#syntax-identifiers
         ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_'
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
-        self.is_identifier_start(ch) || ('0'..='9').contains(&ch)
+    fn is_identifier_part(ch: char) -> bool {
+        Self::is_identifier_start(ch) || ('0'..='9').contains(&ch)
     }
 }
diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs
index 818fa0d0a..6a161e9d9 100644
--- a/src/dialect/generic.rs
+++ b/src/dialect/generic.rs
@@ -16,7 +16,7 @@ use crate::dialect::Dialect;
 pub struct GenericDialect;

 impl Dialect for GenericDialect {
-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
             || ch == '_'
@@ -24,7 +24,7 @@ impl Dialect for GenericDialect {
             || ch == '@'
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
+    fn is_identifier_part(ch: char) -> bool {
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
             || ('0'..='9').contains(&ch)
diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs
index 9b42857ec..75b9abe10 100644
--- a/src/dialect/hive.rs
+++ b/src/dialect/hive.rs
@@ -16,18 +16,18 @@ use crate::dialect::Dialect;
 pub struct HiveDialect {}

 impl Dialect for HiveDialect {
-    fn is_delimited_identifier_start(&self, ch: char) -> bool {
+    fn is_delimited_identifier_start(ch: char) -> bool {
         (ch == '"') || (ch == '`')
     }

-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
             || ('0'..='9').contains(&ch)
             || ch == '$'
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
+    fn is_identifier_part(ch: char) -> bool {
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
             || ('0'..='9').contains(&ch)
diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs
index 008b099d2..974c3d472 100644
--- a/src/dialect/mod.rs
+++ b/src/dialect/mod.rs
@@ -34,63 +34,21 @@ pub use self::snowflake::SnowflakeDialect;
 pub use self::sqlite::SQLiteDialect;
 pub use crate::keywords;

-/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
-/// to `true` if `parser.dialect` is one of the `Dialect`s specified.
-macro_rules! dialect_of {
-    ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
-        ($($parsed_dialect.dialect.is::<$dialect_type>())||+)
-    };
-}
-
 pub trait Dialect: Debug + Any {
     /// Determine if a character starts a quoted identifier. The default
     /// implementation, accepting "double quoted" ids is both ANSI-compliant
     /// and appropriate for most dialects (with the notable exception of
     /// MySQL, MS SQL, and sqlite). You can accept one of characters listed
     /// in `Word::matching_end_quote` here
-    fn is_delimited_identifier_start(&self, ch: char) -> bool {
+    fn is_delimited_identifier_start(ch: char) -> bool {
         ch == '"'
     }
     /// Determine if a character is a valid start character for an unquoted identifier
-    fn is_identifier_start(&self, ch: char) -> bool;
+    fn is_identifier_start(ch: char) -> bool;
     /// Determine if a character is a valid unquoted identifier character
-    fn is_identifier_part(&self, ch: char) -> bool;
-}
-
-impl dyn Dialect {
-    #[inline]
-    pub fn is<T: Dialect>(&self) -> bool {
-        // borrowed from `Any` implementation
-        TypeId::of::<T>() == self.type_id()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::ansi::AnsiDialect;
-    use super::generic::GenericDialect;
-    use super::*;
-
-    struct DialectHolder<'a> {
-        dialect: &'a dyn Dialect,
-    }
-
-    #[test]
-    fn test_is_dialect() {
-        let generic_dialect: &dyn Dialect = &GenericDialect {};
-        let ansi_dialect: &dyn Dialect = &AnsiDialect {};
-
-        let generic_holder = DialectHolder {
-            dialect: generic_dialect,
-        };
-        let ansi_holder = DialectHolder {
-            dialect: ansi_dialect,
-        };
+    fn is_identifier_part(ch: char) -> bool;

-        assert!(dialect_of!(generic_holder is GenericDialect | AnsiDialect),);
-        assert!(!dialect_of!(generic_holder is AnsiDialect));
-        assert!(dialect_of!(ansi_holder is AnsiDialect));
-        assert!(dialect_of!(ansi_holder is GenericDialect | AnsiDialect));
-        assert!(!dialect_of!(ansi_holder is GenericDialect | MsSqlDialect));
+    fn is<D: Dialect>() -> bool {
+        TypeId::of::<Self>() == TypeId::of::<D>()
     }
 }
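The new `is` default method replaces both the `dialect_of!` macro and the `impl dyn Dialect` downcast helper: every implementor is now a concrete type, so two `TypeId`s can be compared without any `&self` receiver. A self-contained sketch of the idea, with illustrative types rather than the crate's API:

```rust
use std::any::{Any, TypeId};
use std::fmt::Debug;

trait Dialect: Debug + Any {
    // Compare the implementing type against D; both TypeIds are known at
    // compile time, so the optimizer can fold this check to a constant.
    fn is<D: Dialect>() -> bool {
        TypeId::of::<Self>() == TypeId::of::<D>()
    }
}

#[derive(Debug)]
struct PostgreSqlDialect;
#[derive(Debug)]
struct GenericDialect;
impl Dialect for PostgreSqlDialect {}
impl Dialect for GenericDialect {}

fn main() {
    assert!(PostgreSqlDialect::is::<PostgreSqlDialect>());
    assert!(!GenericDialect::is::<PostgreSqlDialect>());
}
```

Note that an associated function without a receiver makes the trait non-object-safe, which is exactly the trade this diff makes: `&dyn Dialect` goes away everywhere in favor of the type parameter.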
diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs
index 539a17a9f..585809848 100644
--- a/src/dialect/mssql.rs
+++ b/src/dialect/mssql.rs
@@ -16,11 +16,11 @@ use crate::dialect::Dialect;
 pub struct MsSqlDialect {}

 impl Dialect for MsSqlDialect {
-    fn is_delimited_identifier_start(&self, ch: char) -> bool {
+    fn is_delimited_identifier_start(ch: char) -> bool {
         ch == '"' || ch == '['
     }

-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         // See https://docs.microsoft.com/en-us/sql/relational-databases/databases/database-identifiers?view=sql-server-2017#rules-for-regular-identifiers
         // We don't support non-latin "letters" currently.
         ('a'..='z').contains(&ch)
@@ -30,7 +30,7 @@ impl Dialect for MsSqlDialect {
             || ch == '@'
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
+    fn is_identifier_part(ch: char) -> bool {
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
             || ('0'..='9').contains(&ch)
diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs
index 6581195b8..d3010e2a3 100644
--- a/src/dialect/mysql.rs
+++ b/src/dialect/mysql.rs
@@ -16,7 +16,7 @@ use crate::dialect::Dialect;
 pub struct MySqlDialect {}

 impl Dialect for MySqlDialect {
-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         // See https://dev.mysql.com/doc/refman/8.0/en/identifiers.html.
         // We don't yet support identifiers beginning with numbers, as that
         // makes it hard to distinguish numeric literals.
@@ -27,11 +27,11 @@ impl Dialect for MySqlDialect {
             || ('\u{0080}'..='\u{ffff}').contains(&ch)
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
-        self.is_identifier_start(ch) || ('0'..='9').contains(&ch)
+    fn is_identifier_part(ch: char) -> bool {
+        Self::is_identifier_start(ch) || ('0'..='9').contains(&ch)
     }

-    fn is_delimited_identifier_start(&self, ch: char) -> bool {
+    fn is_delimited_identifier_start(ch: char) -> bool {
         ch == '`'
     }
 }
diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs
index 0c2eb99f0..8bc51ea6c 100644
--- a/src/dialect/postgresql.rs
+++ b/src/dialect/postgresql.rs
@@ -13,17 +13,17 @@ use crate::dialect::Dialect;

 #[derive(Debug)]
-pub struct PostgreSqlDialect {}
+pub struct PostgreSqlDialect;

 impl Dialect for PostgreSqlDialect {
-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         // See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
         // We don't yet support identifiers beginning with "letters with
         // diacritical marks and non-Latin letters"
         ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_'
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
+    fn is_identifier_part(ch: char) -> bool {
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
             || ('0'..='9').contains(&ch)
diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs
index 93db95692..0df7160a3 100644
--- a/src/dialect/snowflake.rs
+++ b/src/dialect/snowflake.rs
@@ -17,11 +17,11 @@ pub struct SnowflakeDialect;

 impl Dialect for SnowflakeDialect {
     // see https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html
-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_'
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
+    fn is_identifier_part(ch: char) -> bool {
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
             || ('0'..='9').contains(&ch)
diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs
index 4ce2f834b..2f46224ef 100644
--- a/src/dialect/sqlite.rs
+++ b/src/dialect/sqlite.rs
@@ -19,11 +19,11 @@ impl Dialect for SQLiteDialect {
     // see https://www.sqlite.org/lang_keywords.html
     // parse `...`, [...] and "..." as identifier
     // TODO: support depending on the context treating '...' as identifier too.
-    fn is_delimited_identifier_start(&self, ch: char) -> bool {
+    fn is_delimited_identifier_start(ch: char) -> bool {
         ch == '`' || ch == '"' || ch == '['
     }

-    fn is_identifier_start(&self, ch: char) -> bool {
+    fn is_identifier_start(ch: char) -> bool {
         // See https://www.sqlite.org/draft/tokenreq.html
         ('a'..='z').contains(&ch)
             || ('A'..='Z').contains(&ch)
@@ -32,7 +32,7 @@ impl Dialect for SQLiteDialect {
             || ('\u{007f}'..='\u{ffff}').contains(&ch)
     }

-    fn is_identifier_part(&self, ch: char) -> bool {
-        self.is_identifier_start(ch) || ('0'..='9').contains(&ch)
+    fn is_identifier_part(ch: char) -> bool {
+        Self::is_identifier_start(ch) || ('0'..='9').contains(&ch)
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index f04ae07a9..3e4100c3f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -21,14 +21,12 @@
 //! use sqlparser::dialect::GenericDialect;
 //! use sqlparser::parser::Parser;
 //!
-//! let dialect = GenericDialect {}; // or AnsiDialect
-//!
 //! let sql = "SELECT a, b, 123, myfunc(b) \
 //!            FROM table_1 \
 //!            WHERE a > b AND b < 100 \
 //!            ORDER BY a DESC, b";
 //!
-//! let ast = Parser::parse_sql(&dialect, sql).unwrap();
+//! let ast = Parser::<GenericDialect>::parse_sql(sql).unwrap();
 //!
 //! println!("AST: {:?}", ast);
 //! ```
@@ -45,9 +43,3 @@ pub mod dialect;
 pub mod keywords;
 pub mod parser;
 pub mod tokenizer;
-
-#[doc(hidden)]
-// This is required to make utilities accessible by both the crate-internal
-// unit-tests and by the integration tests
-// External users are not supposed to rely on this module.
-pub mod test_utils;
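With the dialect folded into the parser's type, library-level usage looks like the updated lib.rs docs above. A hedged sketch against the post-refactor API as shown in this diff:

```rust
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect};
use sqlparser::parser::Parser;

fn main() {
    // The dialect is now part of the parser's type, picked with a turbofish;
    // no dialect value is constructed or passed around at runtime.
    let pg_ast = Parser::<PostgreSqlDialect>::parse_sql("SELECT 1").unwrap();
    let generic_ast = Parser::<GenericDialect>::parse_sql("SELECT 1").unwrap();
    println!("{:?}\n{:?}", pg_ast, generic_ast);
}
```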
println!("AST: {:?}", ast); //! ``` @@ -45,9 +43,3 @@ pub mod dialect; pub mod keywords; pub mod parser; pub mod tokenizer; - -#[doc(hidden)] -// This is required to make utilities accessible by both the crate-internal -// unit-tests and by the integration tests -// External users are not supposed to rely on this module. -pub mod test_utils; diff --git a/src/parser.rs b/src/parser.rs index 6d917f027..1242401cc 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -21,6 +21,7 @@ use alloc::{ vec::Vec, }; use core::fmt; +use core::marker::PhantomData; use log::debug; @@ -104,45 +105,49 @@ impl fmt::Display for ParserError { #[cfg(feature = "std")] impl std::error::Error for ParserError {} -pub struct Parser<'a> { +pub struct Parser { tokens: Vec, /// The index of the first unprocessed token in `self.tokens` index: usize, - dialect: &'a dyn Dialect, + dialect: PhantomData, } -impl<'a> Parser<'a> { +impl Parser { /// Parse the specified tokens - pub fn new(tokens: Vec, dialect: &'a dyn Dialect) -> Self { + pub fn new(tokens: Vec) -> Self { Parser { tokens, index: 0, - dialect, + dialect: PhantomData, } } /// Parse a SQL statement and produce an Abstract Syntax Tree (AST) - pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result, ParserError> { - let mut tokenizer = Tokenizer::new(dialect, sql); + pub fn parse_sql(sql: &str) -> Result, ParserError> { + let mut tokenizer = Tokenizer::::new(sql); let tokens = tokenizer.tokenize()?; - let mut parser = Parser::new(tokens, dialect); + let mut parser = Parser::::new(tokens); + debug!("Parsing sql '{}'...", sql); + parser.parse_statements() + } + + pub fn parse_statements(&mut self) -> Result, ParserError> { let mut stmts = Vec::new(); let mut expecting_statement_delimiter = false; - debug!("Parsing sql '{}'...", sql); loop { // ignore empty statements (between successive statement delimiters) - while parser.consume_token(&Token::SemiColon) { + while self.consume_token(&Token::SemiColon) { expecting_statement_delimiter = false; } - if parser.peek_token() == Token::EOF { + if self.peek_token() == Token::EOF { break; } if expecting_statement_delimiter { - return parser.expected("end of statement", parser.peek_token()); + return self.expected("end of statement", self.peek_token()); } - let statement = parser.parse_statement()?; + let statement = self.parse_statement()?; stmts.push(statement); expecting_statement_delimiter = true; } @@ -187,13 +192,11 @@ impl<'a> Parser<'a> { Keyword::DEALLOCATE => Ok(self.parse_deallocate()?), Keyword::EXECUTE => Ok(self.parse_execute()?), Keyword::PREPARE => Ok(self.parse_prepare()?), - Keyword::REPLACE if dialect_of!(self is SQLiteDialect ) => { + Keyword::REPLACE if D::is::() => { self.prev_token(); Ok(self.parse_insert()?) } - Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => { - Ok(self.parse_comment()?) - } + Keyword::COMMENT if D::is::() => Ok(self.parse_comment()?), _ => self.expected("an SQL statement", Token::Word(w)), }, Token::LParen => { @@ -462,7 +465,7 @@ impl<'a> Parser<'a> { | tok @ Token::PGCubeRoot | tok @ Token::AtSign | tok @ Token::Tilde - if dialect_of!(self is PostgreSqlDialect) => + if D::is::() => { let op = match tok { Token::DoubleExclamationMark => UnaryOperator::PGPrefixFactorial, @@ -604,7 +607,7 @@ impl<'a> Parser<'a> { /// parse a group by expr. a group by expr can be one of group sets, roll up, cube, or simple /// expr. 
@@ -187,13 +192,11 @@ impl<D: Dialect> Parser<D> {
             Keyword::DEALLOCATE => Ok(self.parse_deallocate()?),
             Keyword::EXECUTE => Ok(self.parse_execute()?),
             Keyword::PREPARE => Ok(self.parse_prepare()?),
-            Keyword::REPLACE if dialect_of!(self is SQLiteDialect) => {
+            Keyword::REPLACE if D::is::<SQLiteDialect>() => {
                 self.prev_token();
                 Ok(self.parse_insert()?)
             }
-            Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => {
-                Ok(self.parse_comment()?)
-            }
+            Keyword::COMMENT if D::is::<PostgreSqlDialect>() => Ok(self.parse_comment()?),
             _ => self.expected("an SQL statement", Token::Word(w)),
         },
         Token::LParen => {
@@ -462,7 +465,7 @@ impl<D: Dialect> Parser<D> {
             | tok @ Token::PGCubeRoot
             | tok @ Token::AtSign
             | tok @ Token::Tilde
-                if dialect_of!(self is PostgreSqlDialect) =>
+                if D::is::<PostgreSqlDialect>() =>
             {
                 let op = match tok {
                     Token::DoubleExclamationMark => UnaryOperator::PGPrefixFactorial,
@@ -604,7 +607,7 @@ impl<D: Dialect> Parser<D> {
     /// parse a group by expr. a group by expr can be one of group sets, roll up, cube, or simple
     /// expr.
     fn parse_group_by_expr(&mut self) -> Result<Expr, ParserError> {
-        if dialect_of!(self is PostgreSqlDialect) {
+        if D::is::<PostgreSqlDialect>() {
             if self.parse_keywords(&[Keyword::GROUPING, Keyword::SETS]) {
                 self.expect_token(&Token::LParen)?;
                 let result = self.parse_comma_separated(|p| p.parse_tuple(false, true))?;
@@ -978,15 +981,13 @@ impl<D: Dialect> Parser<D> {
             Token::Caret => Some(BinaryOperator::BitwiseXor),
             Token::Ampersand => Some(BinaryOperator::BitwiseAnd),
             Token::Div => Some(BinaryOperator::Divide),
-            Token::ShiftLeft if dialect_of!(self is PostgreSqlDialect) => {
+            Token::ShiftLeft if D::is::<PostgreSqlDialect>() => {
                 Some(BinaryOperator::PGBitwiseShiftLeft)
             }
-            Token::ShiftRight if dialect_of!(self is PostgreSqlDialect) => {
+            Token::ShiftRight if D::is::<PostgreSqlDialect>() => {
                 Some(BinaryOperator::PGBitwiseShiftRight)
             }
-            Token::Sharp if dialect_of!(self is PostgreSqlDialect) => {
-                Some(BinaryOperator::PGBitwiseXor)
-            }
+            Token::Sharp if D::is::<PostgreSqlDialect>() => Some(BinaryOperator::PGBitwiseXor),
             Token::Tilde => Some(BinaryOperator::PGRegexMatch),
             Token::TildeAsterisk => Some(BinaryOperator::PGRegexIMatch),
             Token::ExclamationMarkTilde => Some(BinaryOperator::PGRegexNotMatch),
@@ -1344,7 +1345,7 @@ impl<D: Dialect> Parser<D> {
     /// Parse a comma-separated list of 1+ items accepted by `F`
     pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
     where
-        F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
+        F: FnMut(&mut Parser<D>) -> Result<T, ParserError>,
     {
         let mut values = vec![];
         loop {
@@ -1361,7 +1362,7 @@ impl<D: Dialect> Parser<D> {
     #[must_use]
     fn maybe_parse<T, F>(&mut self, mut f: F) -> Option<T>
     where
-        F: FnMut(&mut Parser) -> Result<T, ParserError>,
+        F: FnMut(&mut Parser<D>) -> Result<T, ParserError>,
     {
         let index = self.index;
         if let Ok(t) = f(self) {
@@ -1801,14 +1802,14 @@ impl<D: Dialect> Parser<D> {
             self.expect_token(&Token::RParen)?;
             Ok(Some(ColumnOption::Check(expr)))
         } else if self.parse_keyword(Keyword::AUTO_INCREMENT)
-            && dialect_of!(self is MySqlDialect | GenericDialect)
+            && (D::is::<MySqlDialect>() || D::is::<GenericDialect>())
         {
             // Support AUTO_INCREMENT for MySQL
             Ok(Some(ColumnOption::DialectSpecific(vec![
                 Token::make_keyword("AUTO_INCREMENT"),
             ])))
         } else if self.parse_keyword(Keyword::AUTOINCREMENT)
-            && dialect_of!(self is SQLiteDialect | GenericDialect)
+            && (D::is::<SQLiteDialect>() || D::is::<GenericDialect>())
         {
             // Support AUTOINCREMENT for SQLite
             Ok(Some(ColumnOption::DialectSpecific(vec![
@@ -1947,7 +1948,7 @@ impl<D: Dialect> Parser<D> {
                 }
             }
         } else if self.parse_keyword(Keyword::RENAME) {
-            if dialect_of!(self is PostgreSqlDialect) && self.parse_keyword(Keyword::CONSTRAINT) {
+            if D::is::<PostgreSqlDialect>() && self.parse_keyword(Keyword::CONSTRAINT) {
                 let old_name = self.parse_identifier()?;
                 self.expect_keyword(Keyword::TO)?;
                 let new_name = self.parse_identifier()?;
@@ -2025,7 +2026,7 @@ impl<D: Dialect> Parser<D> {
         } else if self.parse_keyword(Keyword::ALTER) {
             let _ = self.parse_keyword(Keyword::COLUMN);
             let column_name = self.parse_identifier()?;
-            let is_postgresql = dialect_of!(self is PostgreSqlDialect);
+            let is_postgresql = D::is::<PostgreSqlDialect>();

             let op = if self.parse_keywords(&[Keyword::SET, Keyword::NOT, Keyword::NULL]) {
                 AlterColumnOperation::SetNotNull {}
@@ -2987,7 +2988,7 @@ impl<D: Dialect> Parser<D> {
             // is a nested join `(foo JOIN bar)`, not followed by other joins.
             self.expect_token(&Token::RParen)?;
             Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
-        } else if dialect_of!(self is SnowflakeDialect | GenericDialect) {
+        } else if D::is::<SnowflakeDialect>() || D::is::<GenericDialect>() {
             // Dialect-specific behavior: Snowflake diverges from the
             // standard and from most of the other implementations by
             // allowing extra parentheses not only around a join (B), but
@@ -3232,7 +3233,7 @@ impl<D: Dialect> Parser<D> {
     /// Parse an INSERT statement
     pub fn parse_insert(&mut self) -> Result<Statement, ParserError> {
-        let or = if !dialect_of!(self is SQLiteDialect) {
+        let or = if !D::is::<SQLiteDialect>() {
             None
         } else if self.parse_keywords(&[Keyword::OR, Keyword::REPLACE]) {
             Some(SqliteOnConflict::Replace)
@@ -3629,30 +3630,3 @@ impl Word {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::test_utils::all_dialects;
-
-    #[test]
-    fn test_prev_index() {
-        let sql = "SELECT version";
-        all_dialects().run_parser_method(sql, |parser| {
-            assert_eq!(parser.peek_token(), Token::make_keyword("SELECT"));
-            assert_eq!(parser.next_token(), Token::make_keyword("SELECT"));
-            parser.prev_token();
-            assert_eq!(parser.next_token(), Token::make_keyword("SELECT"));
-            assert_eq!(parser.next_token(), Token::make_word("version", None));
-            parser.prev_token();
-            assert_eq!(parser.peek_token(), Token::make_word("version", None));
-            assert_eq!(parser.next_token(), Token::make_word("version", None));
-            assert_eq!(parser.peek_token(), Token::EOF);
-            parser.prev_token();
-            assert_eq!(parser.next_token(), Token::make_word("version", None));
-            assert_eq!(parser.next_token(), Token::EOF);
-            assert_eq!(parser.next_token(), Token::EOF);
-            parser.prev_token();
-        });
-    }
-}
diff --git a/src/tokenizer.rs b/src/tokenizer.rs
index 296bcc64b..2bcbde552 100644
--- a/src/tokenizer.rs
+++ b/src/tokenizer.rs
@@ -26,13 +26,14 @@ use alloc::{
 };
 use core::fmt;
 use core::iter::Peekable;
+use core::marker::PhantomData;
 use core::str::Chars;

 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};

-use crate::dialect::SnowflakeDialect;
 use crate::dialect::{Dialect, MySqlDialect};
+use crate::dialect::{GenericDialect, SnowflakeDialect};
 use crate::keywords::{Keyword, ALL_KEYWORDS, ALL_KEYWORDS_INDEX};

 /// SQL Token enumeration
@@ -301,21 +302,21 @@ impl fmt::Display for TokenizerError {
 impl std::error::Error for TokenizerError {}

 /// SQL Tokenizer
-pub struct Tokenizer<'a> {
-    dialect: &'a dyn Dialect,
+pub struct Tokenizer<'a, D: Dialect = GenericDialect> {
     query: &'a str,
     line: u64,
     col: u64,
+    dialect: PhantomData<D>,
 }

-impl<'a> Tokenizer<'a> {
+impl<'a, D: Dialect> Tokenizer<'a, D> {
     /// Create a new SQL tokenizer for the specified SQL statement
-    pub fn new(dialect: &'a dyn Dialect, query: &'a str) -> Self {
+    pub fn new(query: &'a str) -> Self {
         Self {
-            dialect,
             query,
             line: 1,
             col: 1,
+            dialect: PhantomData,
         }
     }
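The tokenizer now reaches the dialect through associated functions such as `D::is_identifier_start`, so no dialect reference is threaded through the scanning loop. A sketch of driving the post-refactor tokenizer directly, mirroring the updated tests below:

```rust
use sqlparser::dialect::SnowflakeDialect;
use sqlparser::tokenizer::Tokenizer;

fn main() {
    // Under SnowflakeDialect, '#' begins a single-line comment; that dialect
    // decision is fixed by the type parameter instead of a runtime reference.
    let mut tokenizer = Tokenizer::<SnowflakeDialect>::new("SELECT 1 # trailing comment");
    let tokens = tokenizer.tokenize().unwrap();
    println!("{:?}", tokens);
}
```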
@@ -394,7 +395,7 @@ impl<'a, D: Dialect> Tokenizer<'a, D> {
                 }
             }
             // identifier or keyword
-            ch if self.dialect.is_identifier_start(ch) => {
+            ch if D::is_identifier_start(ch) => {
                 chars.next(); // consume the first char
                 let s = self.tokenize_word(ch, chars);
@@ -415,7 +416,7 @@ impl<'a, D: Dialect> Tokenizer<'a, D> {
                 Ok(Some(Token::SingleQuotedString(s)))
             }
             // delimited (quoted) identifier
-            quote_start if self.dialect.is_delimited_identifier_start(quote_start) => {
+            quote_start if D::is_delimited_identifier_start(quote_start) => {
                 chars.next(); // consume the opening quote
                 let quote_end = Word::matching_end_quote(quote_start);
                 let s = peeking_take_while(chars, |ch| ch != quote_end);
@@ -489,7 +490,7 @@ impl<'a, D: Dialect> Tokenizer<'a, D> {
                         chars.next(); // consume the '*', starting a multi-line comment
                         self.tokenize_multiline_comment(chars)
                     }
-                    Some('/') if dialect_of!(self is SnowflakeDialect) => {
+                    Some('/') if D::is::<SnowflakeDialect>() => {
                         chars.next(); // consume the second '/', starting a snowflake single-line comment
                         let comment = self.tokenize_single_line_comment(chars);
                         Ok(Some(Token::Whitespace(Whitespace::SingleLineComment {
@@ -580,7 +581,7 @@ impl<'a, D: Dialect> Tokenizer<'a, D> {
             '^' => self.consume_and_return(chars, Token::Caret),
             '{' => self.consume_and_return(chars, Token::LBrace),
             '}' => self.consume_and_return(chars, Token::RBrace),
-            '#' if dialect_of!(self is SnowflakeDialect) => {
+            '#' if D::is::<SnowflakeDialect>() => {
                 chars.next(); // consume the '#', starting a snowflake single-line comment
                 let comment = self.tokenize_single_line_comment(chars);
                 Ok(Some(Token::Whitespace(Whitespace::SingleLineComment {
@@ -624,9 +625,7 @@ impl<'a, D: Dialect> Tokenizer<'a, D> {
     /// Tokenize an identifier or keyword, after the first char is already consumed.
     fn tokenize_word(&self, first_char: char, chars: &mut Peekable<Chars<'_>>) -> String {
         let mut s = first_char.to_string();
-        s.push_str(&peeking_take_while(chars, |ch| {
-            self.dialect.is_identifier_part(ch)
-        }));
+        s.push_str(&peeking_take_while(chars, |ch| D::is_identifier_part(ch)));
         s
     }
@@ -655,7 +654,7 @@ impl<'a, D: Dialect> Tokenizer<'a, D> {
                 }
             }
             '\\' => {
-                if dialect_of!(self is MySqlDialect) {
+                if D::is::<MySqlDialect>() {
                     is_escaped = !is_escaped;
                 } else {
                     s.push(ch);
@@ -751,8 +750,7 @@ mod tests {
     #[test]
     fn tokenize_select_1() {
         let sql = String::from("SELECT 1");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -767,8 +765,7 @@ mod tests {
     #[test]
     fn tokenize_select_float() {
         let sql = String::from("SELECT .1");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -783,8 +780,7 @@ mod tests {
     #[test]
     fn tokenize_scalar_function() {
         let sql = String::from("SELECT sqrt(1)");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -802,8 +798,7 @@ mod tests {
     #[test]
     fn tokenize_string_string_concat() {
         let sql = String::from("SELECT 'a' || 'b'");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -821,8 +816,7 @@ mod tests {
     #[test]
     fn tokenize_bitwise_op() {
         let sql = String::from("SELECT one | two ^ three");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -845,8 +839,7 @@ mod tests {
     fn tokenize_logical_xor() {
         let sql =
             String::from("SELECT true XOR true, false XOR false, true XOR false, false XOR true");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -885,8 +878,7 @@ mod tests {
     #[test]
     fn tokenize_simple_select() {
         let sql = String::from("SELECT * FROM customer WHERE id = 1 LIMIT 5");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -917,8 +909,7 @@ mod tests {
     #[test]
     fn tokenize_explain_select() {
         let sql = String::from("EXPLAIN SELECT * FROM customer WHERE id = 1");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -947,8 +938,7 @@ mod tests {
     #[test]
     fn tokenize_explain_analyze_select() {
         let sql = String::from("EXPLAIN ANALYZE SELECT * FROM customer WHERE id = 1");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -979,8 +969,7 @@ mod tests {
     #[test]
     fn tokenize_string_predicate() {
         let sql = String::from("SELECT * FROM customer WHERE salary != 'Not Provided'");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -1008,8 +997,7 @@ mod tests {
     fn tokenize_invalid_string() {
         let sql = String::from("\nمصطفىh");

-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();
         // println!("tokens: {:#?}", tokens);
         let expected = vec![
@@ -1028,8 +1016,7 @@ mod tests {
     fn tokenize_newline_in_string_literal() {
         let sql = String::from("'foo\r\nbar\nbaz'");

-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();
         let expected = vec![Token::SingleQuotedString("foo\r\nbar\nbaz".to_string())];
         compare(expected, tokens);
@@ -1039,8 +1026,7 @@ mod tests {
     fn tokenize_unterminated_string_literal() {
         let sql = String::from("select 'foo");

-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         assert_eq!(
             tokenizer.tokenize(),
             Err(TokenizerError {
@@ -1055,8 +1041,7 @@ mod tests {
     fn tokenize_invalid_string_cols() {
         let sql = String::from("\n\nSELECT * FROM table\tمصطفىh");

-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();
         // println!("tokens: {:#?}", tokens);
         let expected = vec![
@@ -1083,8 +1068,7 @@ mod tests {
     #[test]
     fn tokenize_right_arrow() {
         let sql = String::from("FUNCTION(key=>value)");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();
         let expected = vec![
             Token::make_word("FUNCTION", None),
@@ -1100,8 +1084,7 @@ mod tests {
     #[test]
     fn tokenize_is_null() {
         let sql = String::from("a IS NULL");
-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();

         let expected = vec![
@@ -1119,8 +1102,7 @@ mod tests {
     fn tokenize_comment() {
         let sql = String::from("0--this is a comment\n1");

-        let dialect = GenericDialect {};
-        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let mut tokenizer = Tokenizer::<GenericDialect>::new(&sql);
         let tokens = tokenizer.tokenize().unwrap();
         let expected = vec![
Token::Number("0".to_string(), false), @@ -1137,8 +1119,7 @@ mod tests { fn tokenize_comment_at_eof() { let sql = String::from("--this is a comment"); - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, &sql); + let mut tokenizer = Tokenizer::::new(&sql); let tokens = tokenizer.tokenize().unwrap(); let expected = vec![Token::Whitespace(Whitespace::SingleLineComment { prefix: "--".to_string(), @@ -1151,8 +1132,7 @@ mod tests { fn tokenize_multiline_comment() { let sql = String::from("0/*multi-line\n* /comment*/1"); - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, &sql); + let mut tokenizer = Tokenizer::::new(&sql); let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ Token::Number("0".to_string(), false), @@ -1168,8 +1148,7 @@ mod tests { fn tokenize_multiline_comment_with_even_asterisks() { let sql = String::from("\n/** Comment **/\n"); - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, &sql); + let mut tokenizer = Tokenizer::::new(&sql); let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ Token::Whitespace(Whitespace::Newline), @@ -1183,8 +1162,7 @@ mod tests { fn tokenize_mismatched_quotes() { let sql = String::from("\"foo"); - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, &sql); + let mut tokenizer = Tokenizer::::new(&sql); assert_eq!( tokenizer.tokenize(), Err(TokenizerError { @@ -1199,8 +1177,7 @@ mod tests { fn tokenize_newlines() { let sql = String::from("line1\nline2\rline3\r\nline4\r"); - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, &sql); + let mut tokenizer = Tokenizer::::new(&sql); let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ Token::make_word("line1", None), @@ -1218,8 +1195,7 @@ mod tests { #[test] fn tokenize_mssql_top() { let sql = "SELECT TOP 5 [bar] FROM foo"; - let dialect = MsSqlDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, sql); + let mut tokenizer = Tokenizer::::new(sql); let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ Token::make_keyword("SELECT"), @@ -1240,8 +1216,7 @@ mod tests { #[test] fn tokenize_pg_regex_match() { let sql = "SELECT col ~ '^a', col ~* '^a', col !~ '^a', col !~* '^a'"; - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, sql); + let mut tokenizer = Tokenizer::::new(sql); let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ Token::make_keyword("SELECT"), diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 07a0db524..153a6bb7b 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -25,10 +25,32 @@ use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, PostgreSqlDialect, SQLiteDialect}; use sqlparser::keywords::ALL_KEYWORDS; use sqlparser::parser::{Parser, ParserError}; +use sqlparser::tokenizer::Token; use test_utils::{ all_dialects, expr_from_projection, join, number, only, table, table_alias, TestedDialects, }; +#[test] +fn test_prev_index() { + let sql = "SELECT version"; + all_dialects().run_parser_method(sql, |parser| { + assert_eq!(parser.peek_token(), Token::make_keyword("SELECT")); + assert_eq!(parser.next_token(), Token::make_keyword("SELECT")); + parser.prev_token(); + assert_eq!(parser.next_token(), Token::make_keyword("SELECT")); + assert_eq!(parser.next_token(), Token::make_word("version", None)); + parser.prev_token(); + assert_eq!(parser.peek_token(), Token::make_word("version", None)); + 
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index 07a0db524..153a6bb7b 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -25,10 +25,32 @@ use sqlparser::ast::*;
 use sqlparser::dialect::{GenericDialect, PostgreSqlDialect, SQLiteDialect};
 use sqlparser::keywords::ALL_KEYWORDS;
 use sqlparser::parser::{Parser, ParserError};
+use sqlparser::tokenizer::Token;
 use test_utils::{
     all_dialects, expr_from_projection, join, number, only, table, table_alias, TestedDialects,
 };

+#[test]
+fn test_prev_index() {
+    let sql = "SELECT version";
+    all_dialects().run_parser_method(sql, |parser| {
+        assert_eq!(parser.peek_token(), Token::make_keyword("SELECT"));
+        assert_eq!(parser.next_token(), Token::make_keyword("SELECT"));
+        parser.prev_token();
+        assert_eq!(parser.next_token(), Token::make_keyword("SELECT"));
+        assert_eq!(parser.next_token(), Token::make_word("version", None));
+        parser.prev_token();
+        assert_eq!(parser.peek_token(), Token::make_word("version", None));
+        assert_eq!(parser.next_token(), Token::make_word("version", None));
+        assert_eq!(parser.peek_token(), Token::EOF);
+        parser.prev_token();
+        assert_eq!(parser.next_token(), Token::make_word("version", None));
+        assert_eq!(parser.next_token(), Token::EOF);
+        assert_eq!(parser.next_token(), Token::EOF);
+        parser.prev_token();
+    });
+}
+
 #[test]
 fn parse_insert_values() {
     let row = vec![
@@ -101,17 +123,15 @@ fn parse_insert_invalid() {
 #[test]
 fn parse_insert_sqlite() {
-    let dialect = SQLiteDialect {};
-
-    let check = |sql: &str, expected_action: Option<SqliteOnConflict>| match Parser::parse_sql(
-        &dialect, sql,
-    )
-    .unwrap()
-    .pop()
-    .unwrap()
-    {
-        Statement::Insert { or, .. } => assert_eq!(or, expected_action),
-        _ => panic!("{}", sql),
+    let check = |sql: &str, expected_action: Option<SqliteOnConflict>| {
+        match Parser::<SQLiteDialect>::parse_sql(sql)
+            .unwrap()
+            .pop()
+            .unwrap()
+        {
+            Statement::Insert { or, .. } => assert_eq!(or, expected_action),
+            _ => panic!("{}", sql),
+        }
     };

     let sql = "INSERT INTO test_table(id) VALUES(1)";
@@ -420,7 +440,7 @@ fn test_eof_after_as() {
 #[test]
 fn test_no_infix_error() {
-    let res = Parser::parse_sql(&GenericDialect {}, "ASSERT-URA<<");
+    let res = Parser::<GenericDialect>::parse_sql("ASSERT-URA<<");
     assert_eq!(
         ParserError::ParserError("No infix parser for token ShiftLeft".to_string()),
         res.unwrap_err()
@@ -1100,9 +1120,7 @@ fn parse_select_group_by() {
 #[test]
 fn parse_select_group_by_grouping_sets() {
-    let dialects = TestedDialects {
-        dialects: vec![Box::new(PostgreSqlDialect {})],
-    };
+    let dialects = tested_dialects!(PostgreSqlDialect);
     let sql =
         "SELECT brand, size, sum(sales) FROM items_sold GROUP BY size, GROUPING SETS ((brand), (size), ())";
     let select = dialects.verified_only_select(sql);
@@ -1121,9 +1139,7 @@ fn parse_select_group_by_grouping_sets() {
 #[test]
 fn parse_select_group_by_rollup() {
-    let dialects = TestedDialects {
-        dialects: vec![Box::new(PostgreSqlDialect {})],
-    };
+    let dialects = tested_dialects!(PostgreSqlDialect);
     let sql = "SELECT brand, size, sum(sales) FROM items_sold GROUP BY size, ROLLUP (brand, size)";
     let select = dialects.verified_only_select(sql);
     assert_eq!(
@@ -1140,9 +1156,7 @@ fn parse_select_group_by_rollup() {
 #[test]
 fn parse_select_group_by_cube() {
-    let dialects = TestedDialects {
-        dialects: vec![Box::new(PostgreSqlDialect {})],
-    };
+    let dialects = tested_dialects!(PostgreSqlDialect);
     let sql = "SELECT brand, size, sum(sales) FROM items_sold GROUP BY size, CUBE (brand, size)";
     let select = dialects.verified_only_select(sql);
     assert_eq!(
@@ -2030,22 +2044,19 @@ fn parse_alter_table_alter_column_type() {
         _ => unreachable!(),
     }

-    let res = Parser::parse_sql(
-        &GenericDialect {},
-        &format!("{} ALTER COLUMN is_active TYPE TEXT", alter_stmt),
-    );
+    let res = Parser::<GenericDialect>::parse_sql(&format!(
+        "{} ALTER COLUMN is_active TYPE TEXT",
+        alter_stmt
+    ));
     assert_eq!(
         ParserError::ParserError("Expected SET/DROP NOT NULL, SET DEFAULT, SET DATA TYPE after ALTER COLUMN, found: TYPE".to_string()),
         res.unwrap_err()
     );

-    let res = Parser::parse_sql(
-        &GenericDialect {},
-        &format!(
-            "{} ALTER COLUMN is_active SET DATA TYPE TEXT USING 'text'",
-            alter_stmt
-        ),
-    );
+    let res = Parser::<GenericDialect>::parse_sql(&format!(
+        "{} ALTER COLUMN is_active SET DATA TYPE TEXT USING 'text'",
+        alter_stmt
+    ));
     assert_eq!(
         ParserError::ParserError("Expected end of statement, found: USING".to_string()),
         res.unwrap_err()
diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs
index d933f0f25..c42ad66da 100644
--- a/tests/sqlparser_hive.rs
+++ b/tests/sqlparser_hive.rs
@@ -15,8 +15,10 @@
 //! Test SQL syntax specific to Hive. The parser based on the generic dialect
 //! is also tested (on the inputs it can handle).

+mod test_utils;
+
 use sqlparser::dialect::HiveDialect;
-use sqlparser::test_utils::*;
+use test_utils::*;

 #[test]
 fn parse_table_create() {
@@ -206,7 +208,5 @@ fn from_cte() {
 }

 fn hive() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(HiveDialect {})],
-    }
+    tested_dialects!(HiveDialect)
 }
diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs
index c613b8b16..33b2ca545 100644
--- a/tests/sqlparser_mssql.rs
+++ b/tests/sqlparser_mssql.rs
@@ -119,12 +119,8 @@ fn parse_mssql_bin_literal() {
 }

 fn ms() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(MsSqlDialect {})],
-    }
+    tested_dialects!(MsSqlDialect)
 }

 fn ms_and_generic() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(MsSqlDialect {}), Box::new(GenericDialect {})],
-    }
+    tested_dialects!(MsSqlDialect, GenericDialect)
 }
diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs
index f67d05c34..655e76dd8 100644
--- a/tests/sqlparser_mysql.rs
+++ b/tests/sqlparser_mysql.rs
@@ -578,13 +578,9 @@ fn parse_substring_in_select() {
 }

 fn mysql() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(MySqlDialect {})],
-    }
+    tested_dialects!(MySqlDialect)
 }

 fn mysql_and_generic() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(MySqlDialect {}), Box::new(GenericDialect {})],
-    }
+    tested_dialects!(MySqlDialect, GenericDialect)
 }
diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs
index 60d9c1cb4..5cff57274 100644
--- a/tests/sqlparser_postgres.rs
+++ b/tests/sqlparser_postgres.rs
@@ -835,13 +835,9 @@ fn parse_comments() {
 }

 fn pg() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(PostgreSqlDialect {})],
-    }
+    tested_dialects!(PostgreSqlDialect)
 }

 fn pg_and_generic() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(PostgreSqlDialect {}), Box::new(GenericDialect {})],
-    }
+    tested_dialects!(PostgreSqlDialect, GenericDialect)
 }
diff --git a/tests/sqlparser_regression.rs b/tests/sqlparser_regression.rs
index e869e0932..fc119710e 100644
--- a/tests/sqlparser_regression.rs
+++ b/tests/sqlparser_regression.rs
@@ -24,8 +24,7 @@ macro_rules! tpch_tests {
             #[test]
             fn $name() {
-                let dialect = GenericDialect {};
-                let res = Parser::parse_sql(&dialect, QUERIES[$value - 1]);
+                let res = Parser::<GenericDialect>::parse_sql(QUERIES[$value - 1]);
                 assert!(res.is_ok());
             }
         )*
diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs
index c08632a15..b2e507eb6 100644
--- a/tests/sqlparser_snowflake.rs
+++ b/tests/sqlparser_snowflake.rs
@@ -37,8 +37,7 @@ fn test_snowflake_create_table() {
 #[test]
 fn test_snowflake_single_line_tokenize() {
     let sql = "CREATE TABLE# this is a comment \ntable_1";
-    let dialect = SnowflakeDialect {};
-    let mut tokenizer = Tokenizer::new(&dialect, sql);
+    let mut tokenizer = Tokenizer::<SnowflakeDialect>::new(sql);
     let tokens = tokenizer.tokenize().unwrap();

     let expected = vec![
@@ -55,7 +54,7 @@ fn test_snowflake_single_line_tokenize() {
     assert_eq!(expected, tokens);

     let sql = "CREATE TABLE// this is a comment \ntable_1";
-    let mut tokenizer = Tokenizer::new(&dialect, sql);
+    let mut tokenizer = Tokenizer::<SnowflakeDialect>::new(sql);
     let tokens = tokenizer.tokenize().unwrap();

     let expected = vec![
@@ -145,13 +144,9 @@ fn test_single_table_in_parenthesis_with_alias() {
 }

 fn snowflake() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(SnowflakeDialect {})],
-    }
+    tested_dialects!(SnowflakeDialect)
 }

 fn snowflake_and_generic() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(SnowflakeDialect {}), Box::new(GenericDialect {})],
-    }
+    tested_dialects!(SnowflakeDialect, GenericDialect)
 }
diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs
index 61436cb51..6ba42ab07 100644
--- a/tests/sqlparser_sqlite.rs
+++ b/tests/sqlparser_sqlite.rs
@@ -119,14 +119,9 @@ fn parse_create_sqlite_quote() {
 }

 fn sqlite() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(SQLiteDialect {})],
-    }
+    tested_dialects!(SQLiteDialect)
 }

 fn sqlite_and_generic() -> TestedDialects {
-    TestedDialects {
-        // we don't have a separate SQLite dialect, so test only the generic dialect for now
-        dialects: vec![Box::new(SQLiteDialect {}), Box::new(GenericDialect {})],
-    }
+    tested_dialects!(SQLiteDialect, GenericDialect)
 }
diff --git a/tests/sqpparser_clickhouse.rs b/tests/sqpparser_clickhouse.rs
index 39c8b230d..fe46b6ae0 100644
--- a/tests/sqpparser_clickhouse.rs
+++ b/tests/sqpparser_clickhouse.rs
@@ -51,7 +51,5 @@ fn parse_map_access_expr() {
 }

 fn clickhouse() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![Box::new(ClickHouseDialect {})],
-    }
+    tested_dialects!(ClickHouseDialect)
 }
diff --git a/src/test_utils.rs b/tests/test_utils.rs
similarity index 67%
rename from src/test_utils.rs
rename to tests/test_utils.rs
index 27eba1408..872eaa7dc 100644
--- a/src/test_utils.rs
+++ b/tests/test_utils.rs
@@ -10,12 +10,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-/// This module contains internal utilities used for testing the library.
-/// While technically public, the library's users are not supposed to rely
-/// on this module, as it will change without notice.
-//
-// Integration tests (i.e. everything under `tests/`) import this
-// via `tests/test_utils/mod.rs`.
+#![allow(dead_code)]

 #[cfg(not(feature = "std"))]
 use alloc::{
@@ -26,15 +21,60 @@ use alloc::{
 };
 use core::fmt::Debug;

-use crate::ast::*;
-use crate::dialect::*;
-use crate::parser::{Parser, ParserError};
-use crate::tokenizer::Tokenizer;
+use sqlparser::dialect::*;
+use sqlparser::parser::{Parser, ParserError};
+use sqlparser::{ast::*, tokenizer::Token};
+
+#[macro_export]
+macro_rules! nest {
+    ($base:expr $(, $join:expr)*) => {
+        TableFactor::NestedJoin(Box::new(TableWithJoins {
+            relation: $base,
+            joins: vec![$(join($join)),*]
+        }))
+    };
+}
+
+pub trait Parse {
+    fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError>;
+    fn parse_expr(&mut self) -> Result<Expr, ParserError>;
+    fn parse_object_name(&mut self) -> Result<ObjectName, ParserError>;
+    fn peek_token(&self) -> Token;
+    fn next_token(&mut self) -> Token;
+    fn prev_token(&mut self);
+}
+
+impl<D: Dialect> Parse for Parser<D> {
+    fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError> {
+        self.parse_statements()
+    }
+
+    fn parse_expr(&mut self) -> Result<Expr, ParserError> {
+        self.parse_expr()
+    }
+
+    fn parse_object_name(&mut self) -> Result<ObjectName, ParserError> {
+        self.parse_object_name()
+    }
+
+    fn peek_token(&self) -> Token {
+        self.peek_token()
+    }
+
+    fn next_token(&mut self) -> Token {
+        self.next_token()
+    }
+
+    fn prev_token(&mut self) {
+        self.prev_token()
+    }
+}
+
+type ParserConstructor = fn(tokens: &str) -> Box<dyn Parse>;

 /// Tests use the methods on this struct to invoke the parser on one or
 /// multiple dialects.
 pub struct TestedDialects {
-    pub dialects: Vec<Box<dyn Dialect>>,
+    pub dialects: Vec<(&'static str, ParserConstructor)>,
 }

 impl TestedDialects {
@@ -42,15 +82,18 @@ impl TestedDialects {
     /// return the same result, and return that result.
     pub fn one_of_identical_results<F, T: Debug + PartialEq>(&self, f: F) -> T
     where
-        F: Fn(&dyn Dialect) -> T,
+        F: Fn(&ParserConstructor) -> T,
     {
-        let parse_results = self.dialects.iter().map(|dialect| (dialect, f(&**dialect)));
+        let parse_results = self
+            .dialects
+            .iter()
+            .map(|(name, dialect)| (name, f(dialect)));
         parse_results
             .fold(None, |s, (dialect, parsed)| {
                 if let Some((prev_dialect, prev_parsed)) = s {
                     assert_eq!(
                         prev_parsed, parsed,
-                        "Parse results with {:?} are different from {:?}",
+                        "Parse results with {} are different from {}",
                         prev_dialect, dialect
                     );
                 }
@@ -62,17 +105,13 @@ impl TestedDialects {

     pub fn run_parser_method<F, T: Debug + PartialEq>(&self, sql: &str, f: F) -> T
     where
-        F: Fn(&mut Parser) -> T,
+        F: Fn(&mut dyn Parse) -> T,
     {
-        self.one_of_identical_results(|dialect| {
-            let mut tokenizer = Tokenizer::new(dialect, sql);
-            let tokens = tokenizer.tokenize().unwrap();
-            f(&mut Parser::new(tokens, dialect))
-        })
+        self.one_of_identical_results(|constructor| f(constructor(sql).as_mut()))
     }

     pub fn parse_sql_statements(&self, sql: &str) -> Result<Vec<Statement>, ParserError> {
-        self.one_of_identical_results(|dialect| Parser::parse_sql(dialect, sql))
+        self.one_of_identical_results(|constructor| constructor(sql).parse_statements())
         // To fail the `ensure_multiple_dialects_are_tested` test:
         // Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
     }
@@ -132,17 +171,32 @@ impl TestedDialects {
     }
 }
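Because `Parser<D>` is now a distinct type per dialect, the test harness can no longer keep parsers in a single homogeneous `Vec`; the `Parse` trait above restores an object-safe surface just for tests, and `ParserConstructor` pairs each dialect's label with a boxed constructor. A condensed, self-contained sketch of that shape, with simplified signatures rather than the real ones:

```rust
// Each Parser<D> is a different concrete type; boxing them behind a small
// object-safe trait lets the harness iterate over dialects uniformly.
trait Parse {
    fn parse(&mut self) -> String;
}

struct ParserA;
struct ParserB;

impl Parse for ParserA {
    fn parse(&mut self) -> String { "parsed as A".into() }
}
impl Parse for ParserB {
    fn parse(&mut self) -> String { "parsed as B".into() }
}

// A plain fn pointer per dialect: label + constructor, as in TestedDialects.
type Constructor = fn(&str) -> Box<dyn Parse>;

fn main() {
    let dialects: Vec<(&'static str, Constructor)> = vec![
        ("A", |_sql| Box::new(ParserA)),
        ("B", |_sql| Box::new(ParserB)),
    ];
    for (name, construct) in &dialects {
        println!("{}: {}", name, construct("SELECT 1").parse());
    }
}
```

Carrying the label alongside the constructor is what lets `one_of_identical_results` name the offending dialect in its mismatch assertion, since a bare fn pointer has no useful `Debug` output.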
+#[macro_export]
+macro_rules! tested_dialects {
+    ($($dialect:ident),+) => {
+        TestedDialects {
+            dialects: vec![
+                $(
+                    (stringify!($dialect), |input| {
+                        Box::new(sqlparser::parser::Parser::<$dialect>::new(
+                            sqlparser::tokenizer::Tokenizer::<$dialect>::new(input)
+                                .tokenize()
+                                .unwrap(),
+                        ))
+                    })
+                ),+
+            ],
+        }
+    };
+}
+
 pub fn all_dialects() -> TestedDialects {
-    TestedDialects {
-        dialects: vec![
-            Box::new(GenericDialect {}),
-            Box::new(PostgreSqlDialect {}),
-            Box::new(MsSqlDialect {}),
-            Box::new(AnsiDialect {}),
-            Box::new(SnowflakeDialect {}),
-            Box::new(HiveDialect {}),
-        ],
-    }
+    tested_dialects!(
+        GenericDialect,
+        PostgreSqlDialect,
+        MsSqlDialect,
+        AnsiDialect,
+        SnowflakeDialect,
+        HiveDialect
+    )
 }

 pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
diff --git a/tests/test_utils/mod.rs b/tests/test_utils/mod.rs
deleted file mode 100644
index f224314b9..000000000
--- a/tests/test_utils/mod.rs
+++ /dev/null
@@ -1,34 +0,0 @@
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Re-export everything from `src/test_utils.rs`.
-pub use sqlparser::test_utils::*;
-
-// For the test-only macros we take a different approach of keeping them here
-// rather than in the library crate.
-//
-// This is because we don't need any of them to be shared between the
-// integration tests (i.e. `tests/*`) and the unit tests (i.e. `src/*`),
-// but also because Rust doesn't scope macros to a particular module
-// (and while we export internal helpers as sqlparser::test_utils::<...>,
-// expecting our users to abstain from relying on them, exporting internal
-// macros at the top level, like `sqlparser::nest` was deemed too confusing).
-
-#[macro_export]
-macro_rules! nest {
-    ($base:expr $(, $join:expr)*) => {
-        TableFactor::NestedJoin(Box::new(TableWithJoins {
-            relation: $base,
-            joins: vec![$(join($join)),*]
-        }))
-    };
-}
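For reference, a hand-expanded form of `tested_dialects!(PostgreSqlDialect)`; this expansion is illustrative only and assumes the surrounding `tests/test_utils.rs` items (`TestedDialects`, `Parse`, the glob imports) are in scope:

```rust
// Roughly what tested_dialects!(PostgreSqlDialect) expands to: a label from
// stringify! plus a non-capturing closure coerced to the ParserConstructor
// fn pointer, which tokenizes and builds a Parser for that one dialect.
fn pg() -> TestedDialects {
    TestedDialects {
        dialects: vec![(
            "PostgreSqlDialect",
            |input| {
                Box::new(sqlparser::parser::Parser::<PostgreSqlDialect>::new(
                    sqlparser::tokenizer::Tokenizer::<PostgreSqlDialect>::new(input)
                        .tokenize()
                        .unwrap(),
                ))
            },
        )],
    }
}
```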