Skip to content

Properly handle mixed implicit and explicit joins #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sqlast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub use self::ddl::{
};
pub use self::query::{
Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect,
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor,
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableWithJoins,
};
pub use self::sqltype::SQLType;
pub use self::value::{SQLDateTimeField, Value};
Expand Down
42 changes: 22 additions & 20 deletions src/sqlast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ pub struct SQLSelect {
/// projection expressions
pub projection: Vec<SQLSelectItem>,
/// FROM
pub relation: Option<TableFactor>,
/// JOIN
pub joins: Vec<Join>,
pub from: Vec<TableWithJoins>,
/// WHERE
pub selection: Option<ASTNode>,
/// GROUP BY
Expand All @@ -131,11 +129,8 @@ impl ToString for SQLSelect {
if self.distinct { " DISTINCT" } else { "" },
comma_separated_string(&self.projection)
);
if let Some(ref relation) = self.relation {
s += &format!(" FROM {}", relation.to_string());
}
for join in &self.joins {
s += &join.to_string();
if !self.from.is_empty() {
s += &format!(" FROM {}", comma_separated_string(&self.from));
}
if let Some(ref selection) = self.selection {
s += &format!(" WHERE {}", selection.to_string());
Expand Down Expand Up @@ -197,6 +192,22 @@ impl ToString for SQLSelectItem {
}
}

#[derive(Debug, Clone, PartialEq, Hash)]
pub struct TableWithJoins {
pub relation: TableFactor,
pub joins: Vec<Join>,
}

impl ToString for TableWithJoins {
fn to_string(&self) -> String {
let mut s = self.relation.to_string();
for join in &self.joins {
s += &join.to_string();
}
s
}
}

/// A table name or a parenthesized subquery with an optional alias
#[derive(Debug, Clone, PartialEq, Hash)]
pub enum TableFactor {
Expand All @@ -215,10 +226,7 @@ pub enum TableFactor {
subquery: Box<SQLQuery>,
alias: Option<TableAlias>,
},
NestedJoin {
base: Box<TableFactor>,
joins: Vec<Join>,
},
NestedJoin(Box<TableWithJoins>),
}

impl ToString for TableFactor {
Expand Down Expand Up @@ -257,12 +265,8 @@ impl ToString for TableFactor {
}
s
}
TableFactor::NestedJoin { base, joins } => {
let mut s = base.to_string();
for join in joins {
s += &join.to_string();
}
format!("({})", s)
TableFactor::NestedJoin(table_reference) => {
format!("({})", table_reference.to_string())
}
}
}
Expand Down Expand Up @@ -313,7 +317,6 @@ impl ToString for Join {
suffix(constraint)
),
JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()),
JoinOperator::Implicit => format!(", {}", self.relation.to_string()),
JoinOperator::LeftOuter(constraint) => format!(
" {}LEFT JOIN {}{}",
prefix(constraint),
Expand Down Expand Up @@ -342,7 +345,6 @@ pub enum JoinOperator {
LeftOuter(JoinConstraint),
RightOuter(JoinConstraint),
FullOuter(JoinConstraint),
Implicit,
Cross,
}

Expand Down
148 changes: 73 additions & 75 deletions src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl Parser {
match self.next_token() {
Some(t) => match t {
Token::SQLWord(ref w) if w.keyword != "" => match w.keyword.as_ref() {
"SELECT" | "WITH" => {
"SELECT" | "WITH" | "VALUES" => {
self.prev_token();
Ok(SQLStatement::SQLQuery(Box::new(self.parse_query()?)))
}
Expand All @@ -133,6 +133,10 @@ impl Parser {
w.to_string()
)),
},
Token::LParen => {
self.prev_token();
Ok(SQLStatement::SQLQuery(Box::new(self.parse_query()?)))
}
unexpected => self.expected(
"a keyword at the beginning of a statement",
Some(unexpected),
Expand Down Expand Up @@ -1570,13 +1574,15 @@ impl Parser {
}
let projection = self.parse_select_list()?;

let (relation, joins) = if self.parse_keyword("FROM") {
let relation = Some(self.parse_table_factor()?);
let joins = self.parse_joins()?;
(relation, joins)
} else {
(None, vec![])
};
let mut from = vec![];
if self.parse_keyword("FROM") {
loop {
from.push(self.parse_table_and_joins()?);
if !self.consume_token(&Token::Comma) {
break;
}
}
}

let selection = if self.parse_keyword("WHERE") {
Some(self.parse_expr()?)
Expand All @@ -1599,14 +1605,69 @@ impl Parser {
Ok(SQLSelect {
distinct,
projection,
from,
selection,
relation,
joins,
group_by,
having,
})
}

pub fn parse_table_and_joins(&mut self) -> Result<TableWithJoins, ParserError> {
let relation = self.parse_table_factor()?;
let mut joins = vec![];
loop {
let join = match &self.peek_token() {
Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => {
self.next_token();
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_table_factor()?,
join_operator: JoinOperator::Cross,
}
}
_ => {
let natural = self.parse_keyword("NATURAL");
let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() {
kw.keyword
} else {
String::default()
};

let join_operator_type = match peek_keyword.as_ref() {
"INNER" | "JOIN" => {
let _ = self.parse_keyword("INNER");
self.expect_keyword("JOIN")?;
JoinOperator::Inner
}
kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => {
let _ = self.next_token();
let _ = self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?;
match kw {
"LEFT" => JoinOperator::LeftOuter,
"RIGHT" => JoinOperator::RightOuter,
"FULL" => JoinOperator::FullOuter,
_ => unreachable!(),
}
}
_ if natural => {
return self.expected("a join type after NATURAL", self.peek_token());
}
_ => break,
};
let relation = self.parse_table_factor()?;
let join_constraint = self.parse_join_constraint(natural)?;
Join {
relation,
join_operator: join_operator_type(join_constraint),
}
}
};
joins.push(join);
}
Ok(TableWithJoins { relation, joins })
}

/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
let lateral = self.parse_keyword("LATERAL");
Expand All @@ -1627,10 +1688,9 @@ impl Parser {
} else if lateral {
parser_err!("Expected subquery after LATERAL, found nested join".to_string())
} else {
let base = Box::new(self.parse_table_factor()?);
let joins = self.parse_joins()?;
let table_reference = self.parse_table_and_joins()?;
self.expect_token(&Token::RParen)?;
Ok(TableFactor::NestedJoin { base, joins })
Ok(TableFactor::NestedJoin(Box::new(table_reference)))
}
} else if lateral {
self.expected("subquery after LATERAL", self.peek_token())
Expand Down Expand Up @@ -1677,68 +1737,6 @@ impl Parser {
}
}

fn parse_joins(&mut self) -> Result<Vec<Join>, ParserError> {
let mut joins = vec![];
loop {
let join = match &self.peek_token() {
Some(Token::Comma) => {
self.next_token();
Join {
relation: self.parse_table_factor()?,
join_operator: JoinOperator::Implicit,
}
}
Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => {
self.next_token();
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_table_factor()?,
join_operator: JoinOperator::Cross,
}
}
_ => {
let natural = self.parse_keyword("NATURAL");
let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() {
kw.keyword
} else {
String::default()
};

let join_operator_type = match peek_keyword.as_ref() {
"INNER" | "JOIN" => {
let _ = self.parse_keyword("INNER");
self.expect_keyword("JOIN")?;
JoinOperator::Inner
}
kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => {
let _ = self.next_token();
let _ = self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?;
match kw {
"LEFT" => JoinOperator::LeftOuter,
"RIGHT" => JoinOperator::RightOuter,
"FULL" => JoinOperator::FullOuter,
_ => unreachable!(),
}
}
_ if natural => {
return self.expected("a join type after NATURAL", self.peek_token());
}
_ => break,
};
let relation = self.parse_table_factor()?;
let join_constraint = self.parse_join_constraint(natural)?;
Join {
relation,
join_operator: join_operator_type(join_constraint),
}
}
};
joins.push(join);
}
Ok(joins)
}

/// Parse an INSERT statement
pub fn parse_insert(&mut self) -> Result<SQLStatement, ParserError> {
self.expect_keyword("INTO")?;
Expand Down
10 changes: 7 additions & 3 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,13 @@ pub fn all_dialects() -> TestedDialects {
}
}

pub fn only<T>(v: &[T]) -> &T {
assert_eq!(1, v.len());
v.first().unwrap()
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
let mut iter = v.into_iter();
if let (Some(item), None) = (iter.next(), iter.next()) {
item
} else {
panic!("only called on collection without exactly one item")
}
}

pub fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode {
Expand Down
Loading