Skip to content

Commit 707c58a

Browse files
committed
Support parsing of multiple statements (5/5)
Parser::parse_sql() can now parse a semicolon-separated list of statements, returning them in a Vec<SQLStatement>. To support this we:
- Move the handling of inter-statement tokens from the end of the individual statement parsers (`parse_select` and `parse_delete`; this was not implemented for other top-level statements) to the common statement-list parsing code (`parse_sql`);
- Replace the "Unexpected token at end of ..." error, which had no tests and prevented us from parsing successive statements, with "Expected end of statement" (i.e. after each statement the parser now expects a delimiter - currently only ";" - or EOF);
- Derive PartialEq on ParserError so that tests can assert_eq!() that parsing statements which do not terminate properly returns the expected error.
1 parent 5a0e0ec commit 707c58a

File tree

3 files changed

+91
-71
lines changed

3 files changed

+91
-71
lines changed

src/sqlparser.rs

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use super::sqlast::*;
2020
use super::sqltokenizer::*;
2121
use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime};
2222

23-
#[derive(Debug, Clone)]
23+
#[derive(Debug, Clone, PartialEq)]
2424
pub enum ParserError {
2525
TokenizerError(String),
2626
ParserError(String),
@@ -54,14 +54,36 @@ impl Parser {
5454
}
5555

5656
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
57-
pub fn parse_sql(dialect: &Dialect, sql: String) -> Result<SQLStatement, ParserError> {
57+
pub fn parse_sql(dialect: &Dialect, sql: String) -> Result<Vec<SQLStatement>, ParserError> {
5858
let mut tokenizer = Tokenizer::new(dialect, &sql);
5959
let tokens = tokenizer.tokenize()?;
6060
let mut parser = Parser::new(tokens);
61-
parser.parse_statement()
61+
let mut stmts = Vec::new();
62+
let mut expecting_statement_delimiter = false;
63+
loop {
64+
// ignore empty statements (between successive statement delimiters)
65+
while parser.consume_token(&Token::SemiColon) {
66+
expecting_statement_delimiter = false;
67+
}
68+
69+
if parser.peek_token().is_none() {
70+
break;
71+
} else if expecting_statement_delimiter {
72+
return parser_err!(format!(
73+
"Expected end of statement, found: {}",
74+
parser.peek_token().unwrap().to_string()
75+
));
76+
}
77+
78+
let statement = parser.parse_statement()?;
79+
stmts.push(statement);
80+
expecting_statement_delimiter = true;
81+
}
82+
Ok(stmts)
6283
}
6384

64-
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.)
85+
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
86+
/// stopping before the statement separator, if any.
6587
pub fn parse_statement(&mut self) -> Result<SQLStatement, ParserError> {
6688
match self.next_token() {
6789
Some(t) => match t {
@@ -1095,20 +1117,10 @@ impl Parser {
10951117
None
10961118
};
10971119

1098-
let _ = self.consume_token(&Token::SemiColon);
1099-
1100-
// parse next token
1101-
if let Some(next_token) = self.peek_token() {
1102-
parser_err!(format!(
1103-
"Unexpected token at end of DELETE: {:?}",
1104-
next_token
1105-
))
1106-
} else {
1107-
Ok(SQLStatement::SQLDelete {
1108-
relation,
1109-
selection,
1110-
})
1111-
}
1120+
Ok(SQLStatement::SQLDelete {
1121+
relation,
1122+
selection,
1123+
})
11121124
}
11131125

11141126
/// Parse a SELECT statement
@@ -1154,25 +1166,16 @@ impl Parser {
11541166
None
11551167
};
11561168

1157-
let _ = self.consume_token(&Token::SemiColon);
1158-
1159-
if let Some(next_token) = self.peek_token() {
1160-
parser_err!(format!(
1161-
"Unexpected token at end of SELECT: {:?}",
1162-
next_token
1163-
))
1164-
} else {
1165-
Ok(SQLSelect {
1166-
projection,
1167-
selection,
1168-
relation,
1169-
joins,
1170-
limit,
1171-
order_by,
1172-
group_by,
1173-
having,
1174-
})
1175-
}
1169+
Ok(SQLSelect {
1170+
projection,
1171+
selection,
1172+
relation,
1173+
joins,
1174+
limit,
1175+
order_by,
1176+
group_by,
1177+
having,
1178+
})
11761179
}
11771180

11781181
/// A table name or a parenthesized subquery, followed by optional `[AS] alias`

tests/sqlparser_generic.rs

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -473,34 +473,6 @@ fn parse_case_expression() {
473473
);
474474
}
475475

476-
#[test]
477-
fn parse_select_with_semi_colon() {
478-
let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1;");
479-
match one_statement_parses_to(&sql, "") {
480-
SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
481-
assert_eq!(3, projection.len());
482-
}
483-
_ => assert!(false),
484-
}
485-
}
486-
487-
#[test]
488-
fn parse_delete_with_semi_colon() {
489-
let sql: &str = "DELETE FROM 'table';";
490-
491-
match one_statement_parses_to(&sql, "") {
492-
SQLStatement::SQLDelete { relation, .. } => {
493-
assert_eq!(
494-
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
495-
"table".to_string()
496-
)))),
497-
relation
498-
);
499-
}
500-
_ => assert!(false),
501-
}
502-
}
503-
504476
#[test]
505477
fn parse_implicit_join() {
506478
let sql = "SELECT * FROM t1, t2";
@@ -669,6 +641,37 @@ fn parse_join_syntax_variants() {
669641
);
670642
}
671643

644+
#[test]
645+
fn parse_multiple_statements() {
646+
fn test_with(sql1: &str, sql2_kw: &str, sql2_rest: &str) {
647+
// Check that a string consisting of two statements delimited by a semicolon
648+
// parses the same as both statements individually:
649+
let res = parse_sql_statements(&(sql1.to_owned() + ";" + sql2_kw + sql2_rest));
650+
assert_eq!(
651+
vec![
652+
one_statement_parses_to(&sql1, ""),
653+
one_statement_parses_to(&(sql2_kw.to_owned() + sql2_rest), ""),
654+
],
655+
res.unwrap()
656+
);
657+
// Check that extra semicolon at the end is stripped by normalization:
658+
one_statement_parses_to(&(sql1.to_owned() + ";"), sql1);
659+
// Check that forgetting the semicolon results in an error:
660+
let res = parse_sql_statements(&(sql1.to_owned() + " " + sql2_kw + sql2_rest));
661+
assert_eq!(
662+
ParserError::ParserError("Expected end of statement, found: ".to_string() + sql2_kw),
663+
res.unwrap_err()
664+
);
665+
}
666+
test_with("SELECT foo", "SELECT", " bar");
667+
test_with("DELETE FROM foo", "SELECT", " bar");
668+
test_with("INSERT INTO foo VALUES(1)", "SELECT", " bar");
669+
test_with("CREATE TABLE foo (baz int)", "SELECT", " bar");
670+
// Make sure that empty statements do not cause an error:
671+
let res = parse_sql_statements(";;");
672+
assert_eq!(0, res.unwrap().len());
673+
}
674+
672675
fn only<'a, T>(v: &'a Vec<T>) -> &'a T {
673676
assert_eq!(1, v.len());
674677
v.first().unwrap()
@@ -699,17 +702,24 @@ fn verified_expr(query: &str) -> ASTNode {
699702
ast
700703
}
701704

702-
/// Ensures that `sql` parses as a statement, optionally checking that
705+
/// Ensures that `sql` parses as a single statement, optionally checking that
703706
/// converting AST back to string equals to `canonical` (unless an empty string
704707
/// is provided).
705708
fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement {
706-
let generic_ast = Parser::parse_sql(&GenericSqlDialect {}, sql.to_string()).unwrap();
707-
let pg_ast = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()).unwrap();
708-
assert_eq!(generic_ast, pg_ast);
709+
let mut statements = parse_sql_statements(&sql).unwrap();
710+
assert_eq!(statements.len(), 1);
709711

712+
let only_statement = statements.pop().unwrap();
710713
if !canonical.is_empty() {
711-
assert_eq!(canonical, generic_ast.to_string())
714+
assert_eq!(canonical, only_statement.to_string())
712715
}
716+
only_statement
717+
}
718+
719+
fn parse_sql_statements(sql: &str) -> Result<Vec<SQLStatement>, ParserError> {
720+
let generic_ast = Parser::parse_sql(&GenericSqlDialect {}, sql.to_string());
721+
let pg_ast = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string());
722+
assert_eq!(generic_ast, pg_ast);
713723
generic_ast
714724
}
715725

tests/sqlparser_postgres.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,13 +372,20 @@ fn verified_stmt(query: &str) -> SQLStatement {
372372
/// converting AST back to string equals to `canonical` (unless an empty string
373373
/// is provided).
374374
fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement {
375-
let only_statement = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()).unwrap();
375+
let mut statements = parse_sql_statements(&sql).unwrap();
376+
assert_eq!(statements.len(), 1);
377+
378+
let only_statement = statements.pop().unwrap();
376379
if !canonical.is_empty() {
377380
assert_eq!(canonical, only_statement.to_string())
378381
}
379382
only_statement
380383
}
381384

385+
fn parse_sql_statements(sql: &str) -> Result<Vec<SQLStatement>, ParserError> {
386+
Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string())
387+
}
388+
382389
fn parse_sql_expr(sql: &str) -> ASTNode {
383390
debug!("sql: {}", sql);
384391
let mut parser = parser(sql);

0 commit comments

Comments
 (0)