diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index bbf359887..297d6e511 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -271,14 +271,14 @@ impl ToString for Join { } fn suffix(constraint: &JoinConstraint) -> String { match constraint { - JoinConstraint::On(expr) => format!("ON {}", expr.to_string()), - JoinConstraint::Using(attrs) => format!("USING({})", attrs.join(", ")), + JoinConstraint::On(expr) => format!(" ON {}", expr.to_string()), + JoinConstraint::Using(attrs) => format!(" USING({})", attrs.join(", ")), _ => "".to_string(), } } match &self.join_operator { JoinOperator::Inner(constraint) => format!( - " {}JOIN {} {}", + " {}JOIN {}{}", prefix(constraint), self.relation.to_string(), suffix(constraint) @@ -286,19 +286,19 @@ impl ToString for Join { JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()), JoinOperator::Implicit => format!(", {}", self.relation.to_string()), JoinOperator::LeftOuter(constraint) => format!( - " {}LEFT JOIN {} {}", + " {}LEFT JOIN {}{}", prefix(constraint), self.relation.to_string(), suffix(constraint) ), JoinOperator::RightOuter(constraint) => format!( - " {}RIGHT JOIN {} {}", + " {}RIGHT JOIN {}{}", prefix(constraint), self.relation.to_string(), suffix(constraint) ), JoinOperator::FullOuter(constraint) => format!( - " {}FULL JOIN {} {}", + " {}FULL JOIN {}{}", prefix(constraint), self.relation.to_string(), suffix(constraint) diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 85121eef4..41ca25617 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -702,14 +702,10 @@ impl Parser { /// Consume the next token if it matches the expected token, otherwise return false #[must_use] pub fn consume_token(&mut self, expected: &Token) -> bool { - match self.peek_token() { - Some(ref t) => { - if *t == *expected { - self.next_token(); - true - } else { - false - } + match &self.peek_token() { + Some(t) if *t == *expected => { + self.next_token(); + true } _ => false, } @@ -1503,90 +1499,62 @@ impl Parser { fn parse_joins(&mut self) -> Result, ParserError> { let mut joins = vec![]; loop { - let natural = match &self.peek_token() { - Some(Token::Comma) => { - self.next_token(); - let relation = self.parse_table_factor()?; - let join = Join { - relation, - join_operator: JoinOperator::Implicit, - }; - joins.push(join); - continue; - } - Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => { - self.next_token(); - self.expect_keyword("JOIN")?; - let relation = self.parse_table_factor()?; - let join = Join { - relation, - join_operator: JoinOperator::Cross, - }; - joins.push(join); - continue; - } - Some(Token::SQLWord(kw)) if kw.keyword == "NATURAL" => { - self.next_token(); - true - } - Some(_) => false, - None => return Ok(joins), - }; - let join = match &self.peek_token() { - Some(Token::SQLWord(kw)) if kw.keyword == "INNER" => { - self.next_token(); - self.expect_keyword("JOIN")?; - Join { - relation: self.parse_table_factor()?, - join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?), - } - } - Some(Token::SQLWord(kw)) if kw.keyword == "JOIN" => { - self.next_token(); - Join { - relation: self.parse_table_factor()?, - join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?), - } - } - Some(Token::SQLWord(kw)) if kw.keyword == "LEFT" => { + Some(Token::Comma) => { self.next_token(); - let _ = self.parse_keyword("OUTER"); - self.expect_keyword("JOIN")?; Join { relation: self.parse_table_factor()?, - join_operator: JoinOperator::LeftOuter( - self.parse_join_constraint(natural)?, - ), + join_operator: JoinOperator::Implicit, } } - Some(Token::SQLWord(kw)) if kw.keyword == "RIGHT" => { + Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => { self.next_token(); - let _ = self.parse_keyword("OUTER"); self.expect_keyword("JOIN")?; Join { relation: self.parse_table_factor()?, - join_operator: JoinOperator::RightOuter( - self.parse_join_constraint(natural)?, - ), + join_operator: JoinOperator::Cross, } } - Some(Token::SQLWord(kw)) if kw.keyword == "FULL" => { - self.next_token(); - let _ = self.parse_keyword("OUTER"); - self.expect_keyword("JOIN")?; + _ => { + let natural = self.parse_keyword("NATURAL"); + let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() { + kw.keyword + } else { + String::default() + }; + + let join_operator_type = match peek_keyword.as_ref() { + "INNER" | "JOIN" => { + let _ = self.parse_keyword("INNER"); + self.expect_keyword("JOIN")?; + JoinOperator::Inner + } + kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => { + let _ = self.next_token(); + let _ = self.parse_keyword("OUTER"); + self.expect_keyword("JOIN")?; + match kw { + "LEFT" => JoinOperator::LeftOuter, + "RIGHT" => JoinOperator::RightOuter, + "FULL" => JoinOperator::FullOuter, + _ => unreachable!(), + } + } + _ if natural => { + return self.expected("a join type after NATURAL", self.peek_token()); + } + _ => break, + }; + let relation = self.parse_table_factor()?; + let join_constraint = self.parse_join_constraint(natural)?; Join { - relation: self.parse_table_factor()?, - join_operator: JoinOperator::FullOuter( - self.parse_join_constraint(natural)?, - ), + relation, + join_operator: join_operator_type(join_constraint), } } - _ => break, }; joins.push(join); } - Ok(joins) } @@ -1611,10 +1579,9 @@ impl Parser { let mut expr_list: Vec = vec![]; loop { expr_list.push(self.parse_expr()?); - match self.peek_token() { - Some(Token::Comma) => self.next_token(), - _ => break, - }; + if !self.consume_token(&Token::Comma) { + break; + } } Ok(expr_list) } @@ -1649,10 +1616,9 @@ impl Parser { } } - match self.peek_token() { - Some(Token::Comma) => self.next_token(), - _ => break, - }; + if !self.consume_token(&Token::Comma) { + break; + } } Ok(projections) } @@ -1672,10 +1638,7 @@ impl Parser { }; expr_list.push(SQLOrderByExpr { expr, asc }); - - if let Some(Token::Comma) = self.peek_token() { - self.next_token(); - } else { + if !self.consume_token(&Token::Comma) { break; } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index ab5ce4d2a..56a5bff35 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1218,6 +1218,43 @@ fn parse_joins_using() { ); } +#[test] +fn parse_natural_join() { + fn natural_join(f: impl Fn(JoinConstraint) -> JoinOperator) -> Join { + Join { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2".to_string()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + join_operator: f(JoinConstraint::Natural), + } + } + assert_eq!( + verified_only_select("SELECT * FROM t1 NATURAL JOIN t2").joins, + vec![natural_join(JoinOperator::Inner)] + ); + assert_eq!( + verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2").joins, + vec![natural_join(JoinOperator::LeftOuter)] + ); + assert_eq!( + verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2").joins, + vec![natural_join(JoinOperator::RightOuter)] + ); + assert_eq!( + verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2").joins, + vec![natural_join(JoinOperator::FullOuter)] + ); + + let sql = "SELECT * FROM t1 natural"; + assert_eq!( + ParserError::ParserError("Expected a join type after NATURAL, found: EOF".to_string()), + parse_sql_statements(sql).unwrap_err(), + ); +} + #[test] fn parse_complex_join() { let sql = "SELECT c1, c2 FROM t1, t4 JOIN t2 ON t2.c = t1.c LEFT JOIN t3 USING(q, c) WHERE t4.c = t1.c";