Skip to content

Commit 72559e9

Browse files
authored
Add ability for dialects to override prefix, infix, and statement parsing (#581)
1 parent 7c02477 commit 72559e9

File tree

5 files changed

+239
-37
lines changed

5 files changed

+239
-37
lines changed

src/dialect/mod.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod redshift;
2222
mod snowflake;
2323
mod sqlite;
2424

25+
use crate::ast::{Expr, Statement};
2526
use core::any::{Any, TypeId};
2627
use core::fmt::Debug;
2728
use core::iter::Peekable;
@@ -39,6 +40,7 @@ pub use self::redshift::RedshiftSqlDialect;
3940
pub use self::snowflake::SnowflakeDialect;
4041
pub use self::sqlite::SQLiteDialect;
4142
pub use crate::keywords;
43+
use crate::parser::{Parser, ParserError};
4244

4345
/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
4446
/// to `true` if `parser.dialect` is one of the `Dialect`s specified.
@@ -65,6 +67,31 @@ pub trait Dialect: Debug + Any {
6567
fn is_identifier_start(&self, ch: char) -> bool;
6668
/// Determine if a character is a valid unquoted identifier character
6769
fn is_identifier_part(&self, ch: char) -> bool;
70+
/// Dialect-specific prefix parser override
71+
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
72+
// return None to fall back to the default behavior
73+
None
74+
}
75+
/// Dialect-specific infix parser override
76+
fn parse_infix(
77+
&self,
78+
_parser: &mut Parser,
79+
_expr: &Expr,
80+
_precendence: u8,
81+
) -> Option<Result<Expr, ParserError>> {
82+
// return None to fall back to the default behavior
83+
None
84+
}
85+
/// Dialect-specific precedence override
86+
fn get_next_precedence(&self, _parser: &Parser) -> Option<Result<u8, ParserError>> {
87+
// return None to fall back to the default behavior
88+
None
89+
}
90+
/// Dialect-specific statement parser override
91+
fn parse_statement(&self, _parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
92+
// return None to fall back to the default behavior
93+
None
94+
}
6895
}
6996

7097
impl dyn Dialect {

src/dialect/postgresql.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
// See the License for the specific language governing permissions and
1111
// limitations under the License.
1212

13+
use crate::ast::{CommentObject, Statement};
1314
use crate::dialect::Dialect;
15+
use crate::keywords::Keyword;
16+
use crate::parser::{Parser, ParserError};
17+
use crate::tokenizer::Token;
1418

1519
#[derive(Debug)]
1620
pub struct PostgreSqlDialect {}
@@ -30,4 +34,41 @@ impl Dialect for PostgreSqlDialect {
3034
|| ch == '$'
3135
|| ch == '_'
3236
}
37+
38+
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
39+
if parser.parse_keyword(Keyword::COMMENT) {
40+
Some(parse_comment(parser))
41+
} else {
42+
None
43+
}
44+
}
45+
}
46+
47+
pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {
48+
parser.expect_keyword(Keyword::ON)?;
49+
let token = parser.next_token();
50+
51+
let (object_type, object_name) = match token {
52+
Token::Word(w) if w.keyword == Keyword::COLUMN => {
53+
let object_name = parser.parse_object_name()?;
54+
(CommentObject::Column, object_name)
55+
}
56+
Token::Word(w) if w.keyword == Keyword::TABLE => {
57+
let object_name = parser.parse_object_name()?;
58+
(CommentObject::Table, object_name)
59+
}
60+
_ => parser.expected("comment object_type", token)?,
61+
};
62+
63+
parser.expect_keyword(Keyword::IS)?;
64+
let comment = if parser.parse_keyword(Keyword::NULL) {
65+
None
66+
} else {
67+
Some(parser.parse_literal_string()?)
68+
};
69+
Ok(Statement::Comment {
70+
object_type,
71+
object_name,
72+
comment,
73+
})
3374
}

src/dialect/sqlite.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
// See the License for the specific language governing permissions and
1111
// limitations under the License.
1212

13+
use crate::ast::Statement;
1314
use crate::dialect::Dialect;
15+
use crate::keywords::Keyword;
16+
use crate::parser::{Parser, ParserError};
1417

1518
#[derive(Debug)]
1619
pub struct SQLiteDialect {}
@@ -35,4 +38,13 @@ impl Dialect for SQLiteDialect {
3538
fn is_identifier_part(&self, ch: char) -> bool {
3639
self.is_identifier_start(ch) || ('0'..='9').contains(&ch)
3740
}
41+
42+
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
43+
if parser.parse_keyword(Keyword::REPLACE) {
44+
parser.prev_token();
45+
Some(parser.parse_insert())
46+
} else {
47+
None
48+
}
49+
}
3850
}

src/parser.rs

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ impl<'a> Parser<'a> {
152152
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
153153
/// stopping before the statement separator, if any.
154154
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
155+
// allow the dialect to override statement parsing
156+
if let Some(statement) = self.dialect.parse_statement(self) {
157+
return statement;
158+
}
159+
155160
match self.next_token() {
156161
Token::Word(w) => match w.keyword {
157162
Keyword::KILL => Ok(self.parse_kill()?),
@@ -195,13 +200,6 @@ impl<'a> Parser<'a> {
195200
Keyword::EXECUTE => Ok(self.parse_execute()?),
196201
Keyword::PREPARE => Ok(self.parse_prepare()?),
197202
Keyword::MERGE => Ok(self.parse_merge()?),
198-
Keyword::REPLACE if dialect_of!(self is SQLiteDialect ) => {
199-
self.prev_token();
200-
Ok(self.parse_insert()?)
201-
}
202-
Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => {
203-
Ok(self.parse_comment()?)
204-
}
205203
_ => self.expected("an SQL statement", Token::Word(w)),
206204
},
207205
Token::LParen => {
@@ -381,6 +379,11 @@ impl<'a> Parser<'a> {
381379

382380
/// Parse an expression prefix
383381
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> {
382+
// allow the dialect to override prefix parsing
383+
if let Some(prefix) = self.dialect.parse_prefix(self) {
384+
return prefix;
385+
}
386+
384387
// PostgreSQL allows any string literal to be preceded by a type name, indicating that the
385388
// string literal represents a literal of that type. Some examples:
386389
//
@@ -1164,6 +1167,11 @@ impl<'a> Parser<'a> {
11641167

11651168
/// Parse an operator following an expression
11661169
pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result<Expr, ParserError> {
1170+
// allow the dialect to override infix parsing
1171+
if let Some(infix) = self.dialect.parse_infix(self, &expr, precedence) {
1172+
return infix;
1173+
}
1174+
11671175
let tok = self.next_token();
11681176

11691177
let regular_binary_operator = match &tok {
@@ -1491,6 +1499,11 @@ impl<'a> Parser<'a> {
14911499

14921500
/// Get the precedence of the next token
14931501
pub fn get_next_precedence(&self) -> Result<u8, ParserError> {
1502+
// allow the dialect to override precedence logic
1503+
if let Some(precedence) = self.dialect.get_next_precedence(self) {
1504+
return precedence;
1505+
}
1506+
14941507
let token = self.peek_token();
14951508
debug!("get_next_precedence() {:?}", token);
14961509
let token_0 = self.peek_nth_token(0);
@@ -1618,7 +1631,7 @@ impl<'a> Parser<'a> {
16181631
}
16191632

16201633
/// Report unexpected token
1621-
fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
1634+
pub fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
16221635
parser_err!(format!("Expected {}, found: {}", expected, found))
16231636
}
16241637

@@ -4735,35 +4748,6 @@ impl<'a> Parser<'a> {
47354748
})
47364749
}
47374750

4738-
pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
4739-
self.expect_keyword(Keyword::ON)?;
4740-
let token = self.next_token();
4741-
4742-
let (object_type, object_name) = match token {
4743-
Token::Word(w) if w.keyword == Keyword::COLUMN => {
4744-
let object_name = self.parse_object_name()?;
4745-
(CommentObject::Column, object_name)
4746-
}
4747-
Token::Word(w) if w.keyword == Keyword::TABLE => {
4748-
let object_name = self.parse_object_name()?;
4749-
(CommentObject::Table, object_name)
4750-
}
4751-
_ => self.expected("comment object_type", token)?,
4752-
};
4753-
4754-
self.expect_keyword(Keyword::IS)?;
4755-
let comment = if self.parse_keyword(Keyword::NULL) {
4756-
None
4757-
} else {
4758-
Some(self.parse_literal_string()?)
4759-
};
4760-
Ok(Statement::Comment {
4761-
object_type,
4762-
object_name,
4763-
comment,
4764-
})
4765-
}
4766-
47674751
pub fn parse_merge_clauses(&mut self) -> Result<Vec<MergeClause>, ParserError> {
47684752
let mut clauses: Vec<MergeClause> = vec![];
47694753
loop {

tests/sqlparser_custom_dialect.rs

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Licensed under the Apache License, Version 2.0 (the "License");
2+
// you may not use this file except in compliance with the License.
3+
// You may obtain a copy of the License at
4+
//
5+
// http://www.apache.org/licenses/LICENSE-2.0
6+
//
7+
// Unless required by applicable law or agreed to in writing, software
8+
// distributed under the License is distributed on an "AS IS" BASIS,
9+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
// See the License for the specific language governing permissions and
11+
// limitations under the License.
12+
13+
//! Test the ability for dialects to override parsing
14+
15+
use sqlparser::{
16+
ast::{BinaryOperator, Expr, Statement, Value},
17+
dialect::Dialect,
18+
keywords::Keyword,
19+
parser::{Parser, ParserError},
20+
tokenizer::Token,
21+
};
22+
23+
/// A dialect's `parse_prefix` override replaces the numeric literal `1`
/// with `NULL`; every other prefix is parsed by the default logic.
#[test]
fn custom_prefix_parser() -> Result<(), ParserError> {
    #[derive(Debug)]
    struct MyDialect {}

    impl Dialect for MyDialect {
        fn is_identifier_start(&self, ch: char) -> bool {
            is_identifier_start(ch)
        }

        fn is_identifier_part(&self, ch: char) -> bool {
            is_identifier_part(ch)
        }

        fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
            let literal_one = Token::Number("1".to_string(), false);
            if !parser.consume_token(&literal_one) {
                // Anything other than the literal `1`: use the default parser.
                return None;
            }
            Some(Ok(Expr::Value(Value::Null)))
        }
    }

    let statements = Parser::parse_sql(&MyDialect {}, "SELECT 1 + 2")?;
    let rendered = format!("{}", &statements[0]);
    assert_eq!("SELECT NULL + 2", &rendered);
    Ok(())
}
53+
54+
/// A dialect's `parse_infix` override rewrites `+` into `*`; every other
/// operator is parsed by the default logic.
#[test]
fn custom_infix_parser() -> Result<(), ParserError> {
    #[derive(Debug)]
    struct MyDialect {}

    impl Dialect for MyDialect {
        fn is_identifier_start(&self, ch: char) -> bool {
            is_identifier_start(ch)
        }

        fn is_identifier_part(&self, ch: char) -> bool {
            is_identifier_part(ch)
        }

        fn parse_infix(
            &self,
            parser: &mut Parser,
            expr: &Expr,
            _precedence: u8,
        ) -> Option<Result<Expr, ParserError>> {
            if !parser.consume_token(&Token::Plus) {
                // Not a `+`: let the default infix parser handle it.
                return None;
            }
            // Hand a failure to parse the right-hand side back to the caller
            // as a `ParserError` rather than panicking inside the hook.
            Some(parser.parse_expr().map(|right| Expr::BinaryOp {
                left: Box::new(expr.clone()),
                op: BinaryOperator::Multiply, // translate Plus to Multiply
                right: Box::new(right),
            }))
        }
    }

    let dialect = MyDialect {};
    let sql = "SELECT 1 + 2";
    let ast = Parser::parse_sql(&dialect, sql)?;
    let query = &ast[0];
    assert_eq!("SELECT 1 * 2", &format!("{}", query));
    Ok(())
}
93+
94+
/// A dialect's `parse_statement` override swallows an entire
/// `SELECT 1 + 2` and reports it as a `COMMIT` statement instead.
#[test]
fn custom_statement_parser() -> Result<(), ParserError> {
    #[derive(Debug)]
    struct MyDialect {}

    impl Dialect for MyDialect {
        fn is_identifier_start(&self, ch: char) -> bool {
            is_identifier_start(ch)
        }

        fn is_identifier_part(&self, ch: char) -> bool {
            is_identifier_part(ch)
        }

        fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
            if !parser.parse_keyword(Keyword::SELECT) {
                // Not a SELECT: let the default statement parser handle it.
                return None;
            }
            // Discard the three tokens that follow SELECT in the test input
            // (`1`, `+`, `2`).
            (0..3).for_each(|_| {
                let _ = parser.next_token();
            });
            Some(Ok(Statement::Commit { chain: false }))
        }
    }

    let statements = Parser::parse_sql(&MyDialect {}, "SELECT 1 + 2")?;
    let rendered = format!("{}", &statements[0]);
    assert_eq!("COMMIT", &rendered);
    Ok(())
}
127+
128+
/// Returns `true` if `ch` may start an unquoted identifier in the test
/// dialects: an ASCII letter (`a-z`, `A-Z`) or an underscore.
fn is_identifier_start(ch: char) -> bool {
    ch.is_ascii_alphabetic() || ch == '_'
}
131+
132+
/// Returns `true` if `ch` may appear after the first character of an
/// unquoted identifier in the test dialects: an ASCII letter or digit,
/// a dollar sign, or an underscore.
fn is_identifier_part(ch: char) -> bool {
    ch.is_ascii_alphanumeric() || ch == '$' || ch == '_'
}

0 commit comments

Comments
 (0)