Skip to content

Refine join parsing #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sqlast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ pub enum TableFactor {
subquery: Box<SQLQuery>,
alias: Option<TableAlias>,
},
/// Represents a parenthesized join expression, such as
/// `(foo <JOIN> bar [ <JOIN> baz ... ])`.
/// The inner `TableWithJoins` can have no joins only if its
/// `relation` is itself a `TableFactor::NestedJoin`.
NestedJoin(Box<TableWithJoins>),
}

Expand Down
107 changes: 84 additions & 23 deletions src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ pub enum IsOptional {
}
use IsOptional::*;

pub enum IsLateral {
Lateral,
NotLateral,
}
use IsLateral::*;

impl From<TokenizerError> for ParserError {
fn from(e: TokenizerError) -> Self {
ParserError::TokenizerError(format!("{:?}", e))
Expand Down Expand Up @@ -1523,7 +1529,10 @@ impl Parser {
} else if self.parse_keyword("VALUES") {
SQLSetExpr::Values(self.parse_values()?)
} else {
return self.expected("SELECT or a subquery in the query body", self.peek_token());
return self.expected(
"SELECT, VALUES, or a subquery in the query body",
self.peek_token(),
);
};

loop {
Expand Down Expand Up @@ -1668,30 +1677,65 @@ impl Parser {

/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
let lateral = self.parse_keyword("LATERAL");
if self.parse_keyword("LATERAL") {
// LATERAL must always be followed by a subquery.
if !self.consume_token(&Token::LParen) {
self.expected("subquery after LATERAL", self.peek_token())?;
}
return self.parse_derived_table_factor(Lateral);
}

if self.consume_token(&Token::LParen) {
if self.parse_keyword("SELECT")
|| self.parse_keyword("WITH")
|| self.parse_keyword("VALUES")
{
self.prev_token();
let subquery = Box::new(self.parse_query()?);
self.expect_token(&Token::RParen)?;
let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
Ok(TableFactor::Derived {
lateral,
subquery,
alias,
})
} else if lateral {
parser_err!("Expected subquery after LATERAL, found nested join".to_string())
} else {
let table_reference = self.parse_table_and_joins()?;
self.expect_token(&Token::RParen)?;
Ok(TableFactor::NestedJoin(Box::new(table_reference)))
let index = self.index;
// A left paren introduces either a derived table (i.e., a subquery)
// or a nested join. It's nearly impossible to determine ahead of
// time which it is... so we just try to parse both.
//
// Here's an example that demonstrates the complexity:
// /-------------------------------------------------------\
// | /-----------------------------------\ |
// SELECT * FROM ( ( ( (SELECT 1) UNION (SELECT 2) ) AS t1 NATURAL JOIN t2 ) )
// ^ ^ ^ ^
// | | | |
// | | | |
// | | | (4) belongs to a SQLSetExpr::Query inside the subquery
// | | (3) starts a derived table (subquery)
// | (2) starts a nested join
// (1) an additional set of parens around a nested join
//
match self.parse_derived_table_factor(NotLateral) {
// The recently consumed '(' started a derived table, and we've
// parsed the subquery, followed by the closing ')', and the
// alias of the derived table. In the example above this is
// case (3), and the next token would be `NATURAL`.
Ok(table_factor) => Ok(table_factor),
Err(_) => {
// The '(' we've recently consumed does not start a derived
// table. For valid input this can happen either when the
// token following the paren can't start a query (e.g. `foo`
// in `FROM (foo NATURAL JOIN bar)`, or when the '(' we've
// consumed is followed by another '(' that starts a
// derived table, like (3), or another nested join (2).
//
// Ignore the error and back up to where we were before.
// Either we'll be able to parse a valid nested join, or
// we won't, and we'll return that error instead.
self.index = index;
let table_and_joins = self.parse_table_and_joins()?;
match table_and_joins.relation {
TableFactor::NestedJoin { .. } => (),
_ => {
if table_and_joins.joins.is_empty() {
// The SQL spec prohibits derived tables and bare
// tables from appearing alone in parentheses.
self.expected("joined table", self.peek_token())?
}
}
}
self.expect_token(&Token::RParen)?;
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
}
}
} else if lateral {
self.expected("subquery after LATERAL", self.peek_token())
} else {
let name = self.parse_object_name()?;
// Postgres, MSSQL: table-valued functions:
Expand Down Expand Up @@ -1721,6 +1765,23 @@ impl Parser {
}
}

pub fn parse_derived_table_factor(
&mut self,
lateral: IsLateral,
) -> Result<TableFactor, ParserError> {
let subquery = Box::new(self.parse_query()?);
self.expect_token(&Token::RParen)?;
let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
Ok(TableFactor::Derived {
lateral: match lateral {
Lateral => true,
NotLateral => false,
},
subquery,
alias,
})
}

fn parse_join_constraint(&mut self, natural: bool) -> Result<JoinConstraint, ParserError> {
if natural {
Ok(JoinConstraint::Natural)
Expand Down
46 changes: 43 additions & 3 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1742,6 +1742,12 @@ fn parse_join_nesting() {
from.joins,
vec![join(nest!(nest!(nest!(table("b"), table("c")))))]
);

let res = parse_sql_statements("SELECT * FROM (a NATURAL JOIN (b))");
assert_eq!(
ParserError::ParserError("Expected joined table, found: )".to_string()),
res.unwrap_err()
);
}

#[test]
Expand Down Expand Up @@ -1848,6 +1854,38 @@ fn parse_derived_tables() {
let sql = "SELECT * FROM t NATURAL JOIN (((SELECT 1)))";
let _ = verified_only_select(sql);
// TODO: add assertions

let sql = "SELECT * FROM (((SELECT 1) UNION (SELECT 2)) AS t1 NATURAL JOIN t2)";
let select = verified_only_select(sql);
let from = only(select.from);
assert_eq!(
from.relation,
TableFactor::NestedJoin(Box::new(TableWithJoins {
relation: TableFactor::Derived {
lateral: false,
subquery: Box::new(verified_query("(SELECT 1) UNION (SELECT 2)")),
alias: Some(TableAlias {
name: "t1".into(),
columns: vec![],
})
},
joins: vec![Join {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t2".into()]),
alias: None,
args: vec![],
with_hints: vec![],
},
join_operator: JoinOperator::Inner(JoinConstraint::Natural),
}],
}))
);

let res = parse_sql_statements("SELECT * FROM ((SELECT 1) AS t)");
assert_eq!(
ParserError::ParserError("Expected joined table, found: )".to_string()),
res.unwrap_err()
);
}

#[test]
Expand Down Expand Up @@ -1952,15 +1990,15 @@ fn parse_exists_subquery() {
let res = parse_sql_statements("SELECT EXISTS (");
assert_eq!(
ParserError::ParserError(
"Expected SELECT or a subquery in the query body, found: EOF".to_string()
"Expected SELECT, VALUES, or a subquery in the query body, found: EOF".to_string()
),
res.unwrap_err(),
);

let res = parse_sql_statements("SELECT EXISTS (NULL)");
assert_eq!(
ParserError::ParserError(
"Expected SELECT or a subquery in the query body, found: NULL".to_string()
"Expected SELECT, VALUES, or a subquery in the query body, found: NULL".to_string()
),
res.unwrap_err(),
);
Expand Down Expand Up @@ -2360,7 +2398,9 @@ fn lateral_derived() {
let sql = "SELECT * FROM a LEFT JOIN LATERAL (b CROSS JOIN c)";
let res = parse_sql_statements(sql);
assert_eq!(
ParserError::ParserError("Expected subquery after LATERAL, found nested join".to_string()),
ParserError::ParserError(
"Expected SELECT, VALUES, or a subquery in the query body, found: b".to_string()
),
res.unwrap_err()
);
}
Expand Down