Skip to content

Commit 3fa1b1d

Browse files
committed
Properly handle mixed implicit and explicit joins
Parse a query like SELECT * FROM a NATURAL JOIN b, c NATURAL JOIN d as the SQL specification requires, i.e.: from: [ TableReference { relation: TableFactor::Table("a"), joins: [Join { relation: TableFactor::Table("b"), join_operator: JoinOperator::Natural, }] }, TableReference { relation: TableFactor::Table("c"), joins: [Join { relation: TableFactor::Table("d"), join_operator: JoinOperator::Natural, }] } ] Previously we were parsing such queries as relation: TableFactor::Table("a"), joins: [ Join { relation: TableFactor::Table("b"), join_operator: JoinOperator::Natural, }, Join { relation: TableFactor::Table("c"), join_operator: JoinOperator::Implicit, }, Join { relation: TableFactor::Table("d"), join_operator: JoinOperator::Natural, }, ] which did not make the join hierarchy clear.
1 parent 518c883 commit 3fa1b1d

File tree

6 files changed

+205
-147
lines changed

6 files changed

+205
-147
lines changed

src/sqlast/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub use self::ddl::{
2727
};
2828
pub use self::query::{
2929
Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect,
30-
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor,
30+
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableAndJoins,
3131
};
3232
pub use self::sqltype::SQLType;
3333
pub use self::value::{SQLDateTimeField, Value};

src/sqlast/query.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ pub struct SQLSelect {
113113
/// projection expressions
114114
pub projection: Vec<SQLSelectItem>,
115115
/// FROM
116-
pub relation: Option<TableFactor>,
117-
/// JOIN
118-
pub joins: Vec<Join>,
116+
pub from: Vec<TableAndJoins>,
119117
/// WHERE
120118
pub selection: Option<ASTNode>,
121119
/// GROUP BY
@@ -131,11 +129,8 @@ impl ToString for SQLSelect {
131129
if self.distinct { " DISTINCT" } else { "" },
132130
comma_separated_string(&self.projection)
133131
);
134-
if let Some(ref relation) = self.relation {
135-
s += &format!(" FROM {}", relation.to_string());
136-
}
137-
for join in &self.joins {
138-
s += &join.to_string();
132+
if !self.from.is_empty() {
133+
s += &format!(" FROM {}", comma_separated_string(&self.from));
139134
}
140135
if let Some(ref selection) = self.selection {
141136
s += &format!(" WHERE {}", selection.to_string());
@@ -197,6 +192,22 @@ impl ToString for SQLSelectItem {
197192
}
198193
}
199194

195+
#[derive(Debug, Clone, PartialEq, Hash)]
196+
pub struct TableAndJoins {
197+
pub relation: TableFactor,
198+
pub joins: Vec<Join>,
199+
}
200+
201+
impl ToString for TableAndJoins {
202+
fn to_string(&self) -> String {
203+
let mut s = self.relation.to_string();
204+
for join in &self.joins {
205+
s += &join.to_string();
206+
}
207+
s
208+
}
209+
}
210+
200211
/// A table name or a parenthesized subquery with an optional alias
201212
#[derive(Debug, Clone, PartialEq, Hash)]
202213
pub enum TableFactor {
@@ -215,10 +226,7 @@ pub enum TableFactor {
215226
subquery: Box<SQLQuery>,
216227
alias: Option<TableAlias>,
217228
},
218-
NestedJoin {
219-
base: Box<TableFactor>,
220-
joins: Vec<Join>,
221-
},
229+
NestedJoin(Box<TableAndJoins>),
222230
}
223231

224232
impl ToString for TableFactor {
@@ -257,12 +265,8 @@ impl ToString for TableFactor {
257265
}
258266
s
259267
}
260-
TableFactor::NestedJoin { base, joins } => {
261-
let mut s = base.to_string();
262-
for join in joins {
263-
s += &join.to_string();
264-
}
265-
format!("({})", s)
268+
TableFactor::NestedJoin(table_reference) => {
269+
format!("({})", table_reference.to_string())
266270
}
267271
}
268272
}
@@ -313,7 +317,6 @@ impl ToString for Join {
313317
suffix(constraint)
314318
),
315319
JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()),
316-
JoinOperator::Implicit => format!(", {}", self.relation.to_string()),
317320
JoinOperator::LeftOuter(constraint) => format!(
318321
" {}LEFT JOIN {}{}",
319322
prefix(constraint),
@@ -342,7 +345,6 @@ pub enum JoinOperator {
342345
LeftOuter(JoinConstraint),
343346
RightOuter(JoinConstraint),
344347
FullOuter(JoinConstraint),
345-
Implicit,
346348
Cross,
347349
}
348350

src/sqlparser.rs

Lines changed: 68 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,13 +1570,15 @@ impl Parser {
15701570
}
15711571
let projection = self.parse_select_list()?;
15721572

1573-
let (relation, joins) = if self.parse_keyword("FROM") {
1574-
let relation = Some(self.parse_table_factor()?);
1575-
let joins = self.parse_joins()?;
1576-
(relation, joins)
1577-
} else {
1578-
(None, vec![])
1579-
};
1573+
let mut from = vec![];
1574+
if self.parse_keyword("FROM") {
1575+
loop {
1576+
from.push(self.parse_table_and_joins()?);
1577+
if !self.consume_token(&Token::Comma) {
1578+
break;
1579+
}
1580+
}
1581+
}
15801582

15811583
let selection = if self.parse_keyword("WHERE") {
15821584
Some(self.parse_expr()?)
@@ -1599,14 +1601,69 @@ impl Parser {
15991601
Ok(SQLSelect {
16001602
distinct,
16011603
projection,
1604+
from,
16021605
selection,
1603-
relation,
1604-
joins,
16051606
group_by,
16061607
having,
16071608
})
16081609
}
16091610

1611+
pub fn parse_table_and_joins(&mut self) -> Result<TableAndJoins, ParserError> {
1612+
let relation = self.parse_table_factor()?;
1613+
let mut joins = vec![];
1614+
loop {
1615+
let join = match &self.peek_token() {
1616+
Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => {
1617+
self.next_token();
1618+
self.expect_keyword("JOIN")?;
1619+
Join {
1620+
relation: self.parse_table_factor()?,
1621+
join_operator: JoinOperator::Cross,
1622+
}
1623+
}
1624+
_ => {
1625+
let natural = self.parse_keyword("NATURAL");
1626+
let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() {
1627+
kw.keyword
1628+
} else {
1629+
String::default()
1630+
};
1631+
1632+
let join_operator_type = match peek_keyword.as_ref() {
1633+
"INNER" | "JOIN" => {
1634+
let _ = self.parse_keyword("INNER");
1635+
self.expect_keyword("JOIN")?;
1636+
JoinOperator::Inner
1637+
}
1638+
kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => {
1639+
let _ = self.next_token();
1640+
let _ = self.parse_keyword("OUTER");
1641+
self.expect_keyword("JOIN")?;
1642+
match kw {
1643+
"LEFT" => JoinOperator::LeftOuter,
1644+
"RIGHT" => JoinOperator::RightOuter,
1645+
"FULL" => JoinOperator::FullOuter,
1646+
_ => unreachable!(),
1647+
}
1648+
}
1649+
_ if natural => {
1650+
return self.expected("a join type after NATURAL", self.peek_token());
1651+
}
1652+
_ => break,
1653+
};
1654+
let relation = self.parse_table_factor()?;
1655+
let join_constraint = self.parse_join_constraint(natural)?;
1656+
Join {
1657+
relation,
1658+
join_operator: join_operator_type(join_constraint),
1659+
}
1660+
}
1661+
};
1662+
joins.push(join);
1663+
}
1664+
Ok(TableAndJoins { relation, joins })
1665+
}
1666+
16101667
/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
16111668
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
16121669
let lateral = self.parse_keyword("LATERAL");
@@ -1627,10 +1684,9 @@ impl Parser {
16271684
} else if lateral {
16281685
parser_err!("Expected subquery after LATERAL, found nested join".to_string())
16291686
} else {
1630-
let base = Box::new(self.parse_table_factor()?);
1631-
let joins = self.parse_joins()?;
1687+
let table_reference = self.parse_table_and_joins()?;
16321688
self.expect_token(&Token::RParen)?;
1633-
Ok(TableFactor::NestedJoin { base, joins })
1689+
Ok(TableFactor::NestedJoin(Box::new(table_reference)))
16341690
}
16351691
} else if lateral {
16361692
self.expected("subquery after LATERAL", self.peek_token())
@@ -1677,68 +1733,6 @@ impl Parser {
16771733
}
16781734
}
16791735

1680-
fn parse_joins(&mut self) -> Result<Vec<Join>, ParserError> {
1681-
let mut joins = vec![];
1682-
loop {
1683-
let join = match &self.peek_token() {
1684-
Some(Token::Comma) => {
1685-
self.next_token();
1686-
Join {
1687-
relation: self.parse_table_factor()?,
1688-
join_operator: JoinOperator::Implicit,
1689-
}
1690-
}
1691-
Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => {
1692-
self.next_token();
1693-
self.expect_keyword("JOIN")?;
1694-
Join {
1695-
relation: self.parse_table_factor()?,
1696-
join_operator: JoinOperator::Cross,
1697-
}
1698-
}
1699-
_ => {
1700-
let natural = self.parse_keyword("NATURAL");
1701-
let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() {
1702-
kw.keyword
1703-
} else {
1704-
String::default()
1705-
};
1706-
1707-
let join_operator_type = match peek_keyword.as_ref() {
1708-
"INNER" | "JOIN" => {
1709-
let _ = self.parse_keyword("INNER");
1710-
self.expect_keyword("JOIN")?;
1711-
JoinOperator::Inner
1712-
}
1713-
kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => {
1714-
let _ = self.next_token();
1715-
let _ = self.parse_keyword("OUTER");
1716-
self.expect_keyword("JOIN")?;
1717-
match kw {
1718-
"LEFT" => JoinOperator::LeftOuter,
1719-
"RIGHT" => JoinOperator::RightOuter,
1720-
"FULL" => JoinOperator::FullOuter,
1721-
_ => unreachable!(),
1722-
}
1723-
}
1724-
_ if natural => {
1725-
return self.expected("a join type after NATURAL", self.peek_token());
1726-
}
1727-
_ => break,
1728-
};
1729-
let relation = self.parse_table_factor()?;
1730-
let join_constraint = self.parse_join_constraint(natural)?;
1731-
Join {
1732-
relation,
1733-
join_operator: join_operator_type(join_constraint),
1734-
}
1735-
}
1736-
};
1737-
joins.push(join);
1738-
}
1739-
Ok(joins)
1740-
}
1741-
17421736
/// Parse an INSERT statement
17431737
pub fn parse_insert(&mut self) -> Result<SQLStatement, ParserError> {
17441738
self.expect_keyword("INTO")?;

src/test_utils.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,13 @@ pub fn all_dialects() -> TestedDialects {
109109
}
110110
}
111111

112-
pub fn only<T>(v: &[T]) -> &T {
113-
assert_eq!(1, v.len());
114-
v.first().unwrap()
112+
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
113+
let mut iter = v.into_iter();
114+
if let (Some(item), None) = (iter.next(), iter.next()) {
115+
item
116+
} else {
117+
panic!("only called on collection without exactly one item")
118+
}
115119
}
116120

117121
pub fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode {

0 commit comments

Comments
 (0)