diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index d6c112bc9..967f23b63 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -238,6 +238,10 @@ pub enum TableFactor { subquery: Box, alias: Option, }, + /// Represents a parenthesized join expression, such as + /// `(foo bar [ baz ... ])`. + /// The inner `TableWithJoins` can have no joins only if its + /// `relation` is itself a `TableFactor::NestedJoin`. NestedJoin(Box), } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 2836b8d15..f3e2e5e05 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -40,6 +40,12 @@ pub enum IsOptional { } use IsOptional::*; +pub enum IsLateral { + Lateral, + NotLateral, +} +use IsLateral::*; + impl From for ParserError { fn from(e: TokenizerError) -> Self { ParserError::TokenizerError(format!("{:?}", e)) @@ -1523,7 +1529,10 @@ impl Parser { } else if self.parse_keyword("VALUES") { SQLSetExpr::Values(self.parse_values()?) } else { - return self.expected("SELECT or a subquery in the query body", self.peek_token()); + return self.expected( + "SELECT, VALUES, or a subquery in the query body", + self.peek_token(), + ); }; loop { @@ -1668,30 +1677,65 @@ impl Parser { /// A table name or a parenthesized subquery, followed by optional `[AS] alias` pub fn parse_table_factor(&mut self) -> Result { - let lateral = self.parse_keyword("LATERAL"); + if self.parse_keyword("LATERAL") { + // LATERAL must always be followed by a subquery. + if !self.consume_token(&Token::LParen) { + self.expected("subquery after LATERAL", self.peek_token())?; + } + return self.parse_derived_table_factor(Lateral); + } + if self.consume_token(&Token::LParen) { - if self.parse_keyword("SELECT") - || self.parse_keyword("WITH") - || self.parse_keyword("VALUES") - { - self.prev_token(); - let subquery = Box::new(self.parse_query()?); - self.expect_token(&Token::RParen)?; - let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; - Ok(TableFactor::Derived { - lateral, - subquery, - alias, - }) - } else if lateral { - parser_err!("Expected subquery after LATERAL, found nested join".to_string()) - } else { - let table_reference = self.parse_table_and_joins()?; - self.expect_token(&Token::RParen)?; - Ok(TableFactor::NestedJoin(Box::new(table_reference))) + let index = self.index; + // A left paren introduces either a derived table (i.e., a subquery) + // or a nested join. It's nearly impossible to determine ahead of + // time which it is... so we just try to parse both. + // + // Here's an example that demonstrates the complexity: + // /-------------------------------------------------------\ + // | /-----------------------------------\ | + // SELECT * FROM ( ( ( (SELECT 1) UNION (SELECT 2) ) AS t1 NATURAL JOIN t2 ) ) + // ^ ^ ^ ^ + // | | | | + // | | | | + // | | | (4) belongs to a SQLSetExpr::Query inside the subquery + // | | (3) starts a derived table (subquery) + // | (2) starts a nested join + // (1) an additional set of parens around a nested join + // + match self.parse_derived_table_factor(NotLateral) { + // The recently consumed '(' started a derived table, and we've + // parsed the subquery, followed by the closing ')', and the + // alias of the derived table. In the example above this is + // case (3), and the next token would be `NATURAL`. + Ok(table_factor) => Ok(table_factor), + Err(_) => { + // The '(' we've recently consumed does not start a derived + // table. For valid input this can happen either when the + // token following the paren can't start a query (e.g. `foo` + // in `FROM (foo NATURAL JOIN bar)`, or when the '(' we've + // consumed is followed by another '(' that starts a + // derived table, like (3), or another nested join (2). + // + // Ignore the error and back up to where we were before. + // Either we'll be able to parse a valid nested join, or + // we won't, and we'll return that error instead. + self.index = index; + let table_and_joins = self.parse_table_and_joins()?; + match table_and_joins.relation { + TableFactor::NestedJoin { .. } => (), + _ => { + if table_and_joins.joins.is_empty() { + // The SQL spec prohibits derived tables and bare + // tables from appearing alone in parentheses. + self.expected("joined table", self.peek_token())? + } + } + } + self.expect_token(&Token::RParen)?; + Ok(TableFactor::NestedJoin(Box::new(table_and_joins))) + } } - } else if lateral { - self.expected("subquery after LATERAL", self.peek_token()) } else { let name = self.parse_object_name()?; // Postgres, MSSQL: table-valued functions: @@ -1721,6 +1765,23 @@ impl Parser { } } + pub fn parse_derived_table_factor( + &mut self, + lateral: IsLateral, + ) -> Result { + let subquery = Box::new(self.parse_query()?); + self.expect_token(&Token::RParen)?; + let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + Ok(TableFactor::Derived { + lateral: match lateral { + Lateral => true, + NotLateral => false, + }, + subquery, + alias, + }) + } + fn parse_join_constraint(&mut self, natural: bool) -> Result { if natural { Ok(JoinConstraint::Natural) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 4b2ab958c..a013d022b 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1742,6 +1742,12 @@ fn parse_join_nesting() { from.joins, vec![join(nest!(nest!(nest!(table("b"), table("c")))))] ); + + let res = parse_sql_statements("SELECT * FROM (a NATURAL JOIN (b))"); + assert_eq!( + ParserError::ParserError("Expected joined table, found: )".to_string()), + res.unwrap_err() + ); } #[test] @@ -1848,6 +1854,38 @@ fn parse_derived_tables() { let sql = "SELECT * FROM t NATURAL JOIN (((SELECT 1)))"; let _ = verified_only_select(sql); // TODO: add assertions + + let sql = "SELECT * FROM (((SELECT 1) UNION (SELECT 2)) AS t1 NATURAL JOIN t2)"; + let select = verified_only_select(sql); + let from = only(select.from); + assert_eq!( + from.relation, + TableFactor::NestedJoin(Box::new(TableWithJoins { + relation: TableFactor::Derived { + lateral: false, + subquery: Box::new(verified_query("(SELECT 1) UNION (SELECT 2)")), + alias: Some(TableAlias { + name: "t1".into(), + columns: vec![], + }) + }, + joins: vec![Join { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + join_operator: JoinOperator::Inner(JoinConstraint::Natural), + }], + })) + ); + + let res = parse_sql_statements("SELECT * FROM ((SELECT 1) AS t)"); + assert_eq!( + ParserError::ParserError("Expected joined table, found: )".to_string()), + res.unwrap_err() + ); } #[test] @@ -1952,7 +1990,7 @@ fn parse_exists_subquery() { let res = parse_sql_statements("SELECT EXISTS ("); assert_eq!( ParserError::ParserError( - "Expected SELECT or a subquery in the query body, found: EOF".to_string() + "Expected SELECT, VALUES, or a subquery in the query body, found: EOF".to_string() ), res.unwrap_err(), ); @@ -1960,7 +1998,7 @@ fn parse_exists_subquery() { let res = parse_sql_statements("SELECT EXISTS (NULL)"); assert_eq!( ParserError::ParserError( - "Expected SELECT or a subquery in the query body, found: NULL".to_string() + "Expected SELECT, VALUES, or a subquery in the query body, found: NULL".to_string() ), res.unwrap_err(), ); @@ -2360,7 +2398,9 @@ fn lateral_derived() { let sql = "SELECT * FROM a LEFT JOIN LATERAL (b CROSS JOIN c)"; let res = parse_sql_statements(sql); assert_eq!( - ParserError::ParserError("Expected subquery after LATERAL, found nested join".to_string()), + ParserError::ParserError( + "Expected SELECT, VALUES, or a subquery in the query body, found: b".to_string() + ), res.unwrap_err() ); }