diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index 663e0e508..df07d48ae 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -215,6 +215,10 @@ pub enum TableFactor { subquery: Box, alias: Option, }, + NestedJoin { + base: Box, + joins: Vec, + }, } impl ToString for TableFactor { @@ -253,6 +257,13 @@ impl ToString for TableFactor { } s } + TableFactor::NestedJoin { base, joins } => { + let mut s = base.to_string(); + for join in joins { + s += &join.to_string(); + } + format!("({})", s) + } } } } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 1b1b83b61..9629fc57a 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -1537,14 +1537,27 @@ impl Parser { pub fn parse_table_factor(&mut self) -> Result { let lateral = self.parse_keyword("LATERAL"); if self.consume_token(&Token::LParen) { - let subquery = Box::new(self.parse_query()?); - self.expect_token(&Token::RParen)?; - let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; - Ok(TableFactor::Derived { - lateral, - subquery, - alias, - }) + if self.parse_keyword("SELECT") + || self.parse_keyword("WITH") + || self.parse_keyword("VALUES") + { + self.prev_token(); + let subquery = Box::new(self.parse_query()?); + self.expect_token(&Token::RParen)?; + let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + Ok(TableFactor::Derived { + lateral, + subquery, + alias, + }) + } else if lateral { + parser_err!("Expected subquery after LATERAL, found nested join".to_string()) + } else { + let base = Box::new(self.parse_table_factor()?); + let joins = self.parse_joins()?; + self.expect_token(&Token::RParen)?; + Ok(TableFactor::NestedJoin { base, joins }) + } } else if lateral { self.expected("subquery after LATERAL", self.peek_token()) } else { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 1c051ca00..612da93a1 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1447,6 +1447,49 @@ fn parse_complex_join() { verified_only_select(sql); } +#[test] +fn parse_join_nesting() { + fn table(name: impl Into) -> TableFactor { + TableFactor::Table { + name: SQLObjectName(vec![name.into()]), + alias: None, + args: vec![], + with_hints: vec![], + } + } + + fn join(relation: TableFactor) -> Join { + Join { + relation, + join_operator: JoinOperator::Inner(JoinConstraint::Natural), + } + } + + macro_rules! nest { + ($base:expr, $($join:expr),*) => { + TableFactor::NestedJoin { + base: Box::new($base), + joins: vec![$(join($join)),*] + } + }; + } + + let sql = "SELECT * FROM a NATURAL JOIN (b NATURAL JOIN (c NATURAL JOIN d NATURAL JOIN e)) \ + NATURAL JOIN (f NATURAL JOIN (g NATURAL JOIN h))"; + assert_eq!( + verified_only_select(sql).joins, + vec![ + join(nest!(table("b"), nest!(table("c"), table("d"), table("e")))), + join(nest!(table("f"), nest!(table("g"), table("h")))) + ], + ); + + let sql = "SELECT * FROM (a NATURAL JOIN b) NATURAL JOIN c"; + let select = verified_only_select(sql); + assert_eq!(select.relation.unwrap(), nest!(table("a"), table("b")),); + assert_eq!(select.joins, vec![join(table("c"))],) +} + #[test] fn parse_join_syntax_variants() { one_statement_parses_to( @@ -2049,6 +2092,13 @@ fn lateral_derived() { ), res.unwrap_err() ); + + let sql = "SELECT * FROM a LEFT JOIN LATERAL (b CROSS JOIN c)"; + let res = parse_sql_statements(sql); + assert_eq!( + ParserError::ParserError("Expected subquery after LATERAL, found nested join".to_string()), + res.unwrap_err() + ); } #[test]