diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 53bb891de..9fe579a40 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -89,6 +89,15 @@ macro_rules! dialect_of { /// /// [module level documentation]: crate pub trait Dialect: Debug + Any { + /// Determine the [`TypeId`] of this dialect. + /// + /// By default, return the same [`TypeId`] as [`Any::type_id`]. Can be overriden + /// by dialects that behave like other dialects + /// (for example when wrapping a dialect). + fn dialect(&self) -> TypeId { + self.type_id() + } + /// Determine if a character starts a quoted identifier. The default /// implementation, accepting "double quoted" ids is both ANSI-compliant /// and appropriate for most dialects (with the notable exception of @@ -164,7 +173,7 @@ impl dyn Dialect { #[inline] pub fn is(&self) -> bool { // borrowed from `Any` implementation - TypeId::of::() == self.type_id() + TypeId::of::() == self.dialect() } } @@ -248,4 +257,98 @@ mod tests { fn parse_dialect(v: &str) -> Box { dialect_from_str(v).unwrap() } + + #[test] + fn parse_with_wrapped_dialect() { + /// Wrapper for a dialect. In a real-world example, this wrapper + /// would tweak the behavior of the dialect. For the test case, + /// it wraps all methods unaltered. + #[derive(Debug)] + struct WrappedDialect(MySqlDialect); + + impl Dialect for WrappedDialect { + fn dialect(&self) -> std::any::TypeId { + self.0.dialect() + } + + fn is_identifier_start(&self, ch: char) -> bool { + self.0.is_identifier_start(ch) + } + + fn is_delimited_identifier_start(&self, ch: char) -> bool { + self.0.is_delimited_identifier_start(ch) + } + + fn is_proper_identifier_inside_quotes( + &self, + chars: std::iter::Peekable>, + ) -> bool { + self.0.is_proper_identifier_inside_quotes(chars) + } + + fn supports_filter_during_aggregation(&self) -> bool { + self.0.supports_filter_during_aggregation() + } + + fn supports_within_after_array_aggregation(&self) -> bool { + self.0.supports_within_after_array_aggregation() + } + + fn supports_group_by_expr(&self) -> bool { + self.0.supports_group_by_expr() + } + + fn supports_substring_from_for_expr(&self) -> bool { + self.0.supports_substring_from_for_expr() + } + + fn supports_in_empty_list(&self) -> bool { + self.0.supports_in_empty_list() + } + + fn convert_type_before_value(&self) -> bool { + self.0.convert_type_before_value() + } + + fn parse_prefix( + &self, + parser: &mut sqlparser::parser::Parser, + ) -> Option> { + self.0.parse_prefix(parser) + } + + fn parse_infix( + &self, + parser: &mut sqlparser::parser::Parser, + expr: &Expr, + precedence: u8, + ) -> Option> { + self.0.parse_infix(parser, expr, precedence) + } + + fn get_next_precedence( + &self, + parser: &sqlparser::parser::Parser, + ) -> Option> { + self.0.get_next_precedence(parser) + } + + fn parse_statement( + &self, + parser: &mut sqlparser::parser::Parser, + ) -> Option> { + self.0.parse_statement(parser) + } + + fn is_identifier_part(&self, ch: char) -> bool { + self.0.is_identifier_part(ch) + } + } + + let statement = r#"SELECT 'Wayne\'s World'"#; + let res1 = Parser::parse_sql(&MySqlDialect {}, statement); + let res2 = Parser::parse_sql(&WrappedDialect(MySqlDialect {}), statement); + assert!(res1.is_ok()); + assert_eq!(res1, res2); + } }