Skip to content

Commit 5896f01

Browse files
authored
Merge pull request #109 from benesch/implicit-join-fix
Properly handle mixed implicit and explicit joins
2 parents 518c883 + ce171c2 commit 5896f01

File tree

6 files changed

+218
-151
lines changed

6 files changed

+218
-151
lines changed

src/sqlast/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub use self::ddl::{
2727
};
2828
pub use self::query::{
2929
Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect,
30-
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor,
30+
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableWithJoins,
3131
};
3232
pub use self::sqltype::SQLType;
3333
pub use self::value::{SQLDateTimeField, Value};

src/sqlast/query.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ pub struct SQLSelect {
113113
/// projection expressions
114114
pub projection: Vec<SQLSelectItem>,
115115
/// FROM
116-
pub relation: Option<TableFactor>,
117-
/// JOIN
118-
pub joins: Vec<Join>,
116+
pub from: Vec<TableWithJoins>,
119117
/// WHERE
120118
pub selection: Option<ASTNode>,
121119
/// GROUP BY
@@ -131,11 +129,8 @@ impl ToString for SQLSelect {
131129
if self.distinct { " DISTINCT" } else { "" },
132130
comma_separated_string(&self.projection)
133131
);
134-
if let Some(ref relation) = self.relation {
135-
s += &format!(" FROM {}", relation.to_string());
136-
}
137-
for join in &self.joins {
138-
s += &join.to_string();
132+
if !self.from.is_empty() {
133+
s += &format!(" FROM {}", comma_separated_string(&self.from));
139134
}
140135
if let Some(ref selection) = self.selection {
141136
s += &format!(" WHERE {}", selection.to_string());
@@ -197,6 +192,22 @@ impl ToString for SQLSelectItem {
197192
}
198193
}
199194

195+
#[derive(Debug, Clone, PartialEq, Hash)]
196+
pub struct TableWithJoins {
197+
pub relation: TableFactor,
198+
pub joins: Vec<Join>,
199+
}
200+
201+
impl ToString for TableWithJoins {
202+
fn to_string(&self) -> String {
203+
let mut s = self.relation.to_string();
204+
for join in &self.joins {
205+
s += &join.to_string();
206+
}
207+
s
208+
}
209+
}
210+
200211
/// A table name or a parenthesized subquery with an optional alias
201212
#[derive(Debug, Clone, PartialEq, Hash)]
202213
pub enum TableFactor {
@@ -215,10 +226,7 @@ pub enum TableFactor {
215226
subquery: Box<SQLQuery>,
216227
alias: Option<TableAlias>,
217228
},
218-
NestedJoin {
219-
base: Box<TableFactor>,
220-
joins: Vec<Join>,
221-
},
229+
NestedJoin(Box<TableWithJoins>),
222230
}
223231

224232
impl ToString for TableFactor {
@@ -257,12 +265,8 @@ impl ToString for TableFactor {
257265
}
258266
s
259267
}
260-
TableFactor::NestedJoin { base, joins } => {
261-
let mut s = base.to_string();
262-
for join in joins {
263-
s += &join.to_string();
264-
}
265-
format!("({})", s)
268+
TableFactor::NestedJoin(table_reference) => {
269+
format!("({})", table_reference.to_string())
266270
}
267271
}
268272
}
@@ -313,7 +317,6 @@ impl ToString for Join {
313317
suffix(constraint)
314318
),
315319
JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()),
316-
JoinOperator::Implicit => format!(", {}", self.relation.to_string()),
317320
JoinOperator::LeftOuter(constraint) => format!(
318321
" {}LEFT JOIN {}{}",
319322
prefix(constraint),
@@ -342,7 +345,6 @@ pub enum JoinOperator {
342345
LeftOuter(JoinConstraint),
343346
RightOuter(JoinConstraint),
344347
FullOuter(JoinConstraint),
345-
Implicit,
346348
Cross,
347349
}
348350

src/sqlparser.rs

Lines changed: 73 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ impl Parser {
109109
match self.next_token() {
110110
Some(t) => match t {
111111
Token::SQLWord(ref w) if w.keyword != "" => match w.keyword.as_ref() {
112-
"SELECT" | "WITH" => {
112+
"SELECT" | "WITH" | "VALUES" => {
113113
self.prev_token();
114114
Ok(SQLStatement::SQLQuery(Box::new(self.parse_query()?)))
115115
}
@@ -133,6 +133,10 @@ impl Parser {
133133
w.to_string()
134134
)),
135135
},
136+
Token::LParen => {
137+
self.prev_token();
138+
Ok(SQLStatement::SQLQuery(Box::new(self.parse_query()?)))
139+
}
136140
unexpected => self.expected(
137141
"a keyword at the beginning of a statement",
138142
Some(unexpected),
@@ -1570,13 +1574,15 @@ impl Parser {
15701574
}
15711575
let projection = self.parse_select_list()?;
15721576

1573-
let (relation, joins) = if self.parse_keyword("FROM") {
1574-
let relation = Some(self.parse_table_factor()?);
1575-
let joins = self.parse_joins()?;
1576-
(relation, joins)
1577-
} else {
1578-
(None, vec![])
1579-
};
1577+
let mut from = vec![];
1578+
if self.parse_keyword("FROM") {
1579+
loop {
1580+
from.push(self.parse_table_and_joins()?);
1581+
if !self.consume_token(&Token::Comma) {
1582+
break;
1583+
}
1584+
}
1585+
}
15801586

15811587
let selection = if self.parse_keyword("WHERE") {
15821588
Some(self.parse_expr()?)
@@ -1599,14 +1605,69 @@ impl Parser {
15991605
Ok(SQLSelect {
16001606
distinct,
16011607
projection,
1608+
from,
16021609
selection,
1603-
relation,
1604-
joins,
16051610
group_by,
16061611
having,
16071612
})
16081613
}
16091614

1615+
pub fn parse_table_and_joins(&mut self) -> Result<TableWithJoins, ParserError> {
1616+
let relation = self.parse_table_factor()?;
1617+
let mut joins = vec![];
1618+
loop {
1619+
let join = match &self.peek_token() {
1620+
Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => {
1621+
self.next_token();
1622+
self.expect_keyword("JOIN")?;
1623+
Join {
1624+
relation: self.parse_table_factor()?,
1625+
join_operator: JoinOperator::Cross,
1626+
}
1627+
}
1628+
_ => {
1629+
let natural = self.parse_keyword("NATURAL");
1630+
let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() {
1631+
kw.keyword
1632+
} else {
1633+
String::default()
1634+
};
1635+
1636+
let join_operator_type = match peek_keyword.as_ref() {
1637+
"INNER" | "JOIN" => {
1638+
let _ = self.parse_keyword("INNER");
1639+
self.expect_keyword("JOIN")?;
1640+
JoinOperator::Inner
1641+
}
1642+
kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => {
1643+
let _ = self.next_token();
1644+
let _ = self.parse_keyword("OUTER");
1645+
self.expect_keyword("JOIN")?;
1646+
match kw {
1647+
"LEFT" => JoinOperator::LeftOuter,
1648+
"RIGHT" => JoinOperator::RightOuter,
1649+
"FULL" => JoinOperator::FullOuter,
1650+
_ => unreachable!(),
1651+
}
1652+
}
1653+
_ if natural => {
1654+
return self.expected("a join type after NATURAL", self.peek_token());
1655+
}
1656+
_ => break,
1657+
};
1658+
let relation = self.parse_table_factor()?;
1659+
let join_constraint = self.parse_join_constraint(natural)?;
1660+
Join {
1661+
relation,
1662+
join_operator: join_operator_type(join_constraint),
1663+
}
1664+
}
1665+
};
1666+
joins.push(join);
1667+
}
1668+
Ok(TableWithJoins { relation, joins })
1669+
}
1670+
16101671
/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
16111672
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
16121673
let lateral = self.parse_keyword("LATERAL");
@@ -1627,10 +1688,9 @@ impl Parser {
16271688
} else if lateral {
16281689
parser_err!("Expected subquery after LATERAL, found nested join".to_string())
16291690
} else {
1630-
let base = Box::new(self.parse_table_factor()?);
1631-
let joins = self.parse_joins()?;
1691+
let table_reference = self.parse_table_and_joins()?;
16321692
self.expect_token(&Token::RParen)?;
1633-
Ok(TableFactor::NestedJoin { base, joins })
1693+
Ok(TableFactor::NestedJoin(Box::new(table_reference)))
16341694
}
16351695
} else if lateral {
16361696
self.expected("subquery after LATERAL", self.peek_token())
@@ -1677,68 +1737,6 @@ impl Parser {
16771737
}
16781738
}
16791739

1680-
fn parse_joins(&mut self) -> Result<Vec<Join>, ParserError> {
1681-
let mut joins = vec![];
1682-
loop {
1683-
let join = match &self.peek_token() {
1684-
Some(Token::Comma) => {
1685-
self.next_token();
1686-
Join {
1687-
relation: self.parse_table_factor()?,
1688-
join_operator: JoinOperator::Implicit,
1689-
}
1690-
}
1691-
Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => {
1692-
self.next_token();
1693-
self.expect_keyword("JOIN")?;
1694-
Join {
1695-
relation: self.parse_table_factor()?,
1696-
join_operator: JoinOperator::Cross,
1697-
}
1698-
}
1699-
_ => {
1700-
let natural = self.parse_keyword("NATURAL");
1701-
let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() {
1702-
kw.keyword
1703-
} else {
1704-
String::default()
1705-
};
1706-
1707-
let join_operator_type = match peek_keyword.as_ref() {
1708-
"INNER" | "JOIN" => {
1709-
let _ = self.parse_keyword("INNER");
1710-
self.expect_keyword("JOIN")?;
1711-
JoinOperator::Inner
1712-
}
1713-
kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => {
1714-
let _ = self.next_token();
1715-
let _ = self.parse_keyword("OUTER");
1716-
self.expect_keyword("JOIN")?;
1717-
match kw {
1718-
"LEFT" => JoinOperator::LeftOuter,
1719-
"RIGHT" => JoinOperator::RightOuter,
1720-
"FULL" => JoinOperator::FullOuter,
1721-
_ => unreachable!(),
1722-
}
1723-
}
1724-
_ if natural => {
1725-
return self.expected("a join type after NATURAL", self.peek_token());
1726-
}
1727-
_ => break,
1728-
};
1729-
let relation = self.parse_table_factor()?;
1730-
let join_constraint = self.parse_join_constraint(natural)?;
1731-
Join {
1732-
relation,
1733-
join_operator: join_operator_type(join_constraint),
1734-
}
1735-
}
1736-
};
1737-
joins.push(join);
1738-
}
1739-
Ok(joins)
1740-
}
1741-
17421740
/// Parse an INSERT statement
17431741
pub fn parse_insert(&mut self) -> Result<SQLStatement, ParserError> {
17441742
self.expect_keyword("INTO")?;

src/test_utils.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,13 @@ pub fn all_dialects() -> TestedDialects {
109109
}
110110
}
111111

112-
pub fn only<T>(v: &[T]) -> &T {
113-
assert_eq!(1, v.len());
114-
v.first().unwrap()
112+
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
113+
let mut iter = v.into_iter();
114+
if let (Some(item), None) = (iter.next(), iter.next()) {
115+
item
116+
} else {
117+
panic!("only called on collection without exactly one item")
118+
}
115119
}
116120

117121
pub fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode {

0 commit comments

Comments
 (0)