diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 719979663..f87bf4062 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -27,7 +27,7 @@ pub use self::ddl::{ }; pub use self::query::{ Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect, - SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, + SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableWithJoins, }; pub use self::sqltype::SQLType; pub use self::value::{SQLDateTimeField, Value}; diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index df07d48ae..150aeaa77 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -113,9 +113,7 @@ pub struct SQLSelect { /// projection expressions pub projection: Vec, /// FROM - pub relation: Option, - /// JOIN - pub joins: Vec, + pub from: Vec, /// WHERE pub selection: Option, /// GROUP BY @@ -131,11 +129,8 @@ impl ToString for SQLSelect { if self.distinct { " DISTINCT" } else { "" }, comma_separated_string(&self.projection) ); - if let Some(ref relation) = self.relation { - s += &format!(" FROM {}", relation.to_string()); - } - for join in &self.joins { - s += &join.to_string(); + if !self.from.is_empty() { + s += &format!(" FROM {}", comma_separated_string(&self.from)); } if let Some(ref selection) = self.selection { s += &format!(" WHERE {}", selection.to_string()); @@ -197,6 +192,22 @@ impl ToString for SQLSelectItem { } } +#[derive(Debug, Clone, PartialEq, Hash)] +pub struct TableWithJoins { + pub relation: TableFactor, + pub joins: Vec, +} + +impl ToString for TableWithJoins { + fn to_string(&self) -> String { + let mut s = self.relation.to_string(); + for join in &self.joins { + s += &join.to_string(); + } + s + } +} + /// A table name or a parenthesized subquery with an optional alias #[derive(Debug, Clone, PartialEq, Hash)] pub enum TableFactor { @@ -215,10 +226,7 @@ pub enum TableFactor { subquery: Box, alias: Option, }, - NestedJoin { - base: Box, - joins: Vec, - }, + NestedJoin(Box), } impl ToString for TableFactor { @@ -257,12 +265,8 @@ impl ToString for TableFactor { } s } - TableFactor::NestedJoin { base, joins } => { - let mut s = base.to_string(); - for join in joins { - s += &join.to_string(); - } - format!("({})", s) + TableFactor::NestedJoin(table_reference) => { + format!("({})", table_reference.to_string()) } } } @@ -313,7 +317,6 @@ impl ToString for Join { suffix(constraint) ), JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()), - JoinOperator::Implicit => format!(", {}", self.relation.to_string()), JoinOperator::LeftOuter(constraint) => format!( " {}LEFT JOIN {}{}", prefix(constraint), @@ -342,7 +345,6 @@ pub enum JoinOperator { LeftOuter(JoinConstraint), RightOuter(JoinConstraint), FullOuter(JoinConstraint), - Implicit, Cross, } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index be26265c4..4736f31ad 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -109,7 +109,7 @@ impl Parser { match self.next_token() { Some(t) => match t { Token::SQLWord(ref w) if w.keyword != "" => match w.keyword.as_ref() { - "SELECT" | "WITH" => { + "SELECT" | "WITH" | "VALUES" => { self.prev_token(); Ok(SQLStatement::SQLQuery(Box::new(self.parse_query()?))) } @@ -133,6 +133,10 @@ impl Parser { w.to_string() )), }, + Token::LParen => { + self.prev_token(); + Ok(SQLStatement::SQLQuery(Box::new(self.parse_query()?))) + } unexpected => self.expected( "a keyword at the beginning of a statement", Some(unexpected), @@ -1570,13 +1574,15 @@ impl Parser { } let projection = self.parse_select_list()?; - let (relation, joins) = if self.parse_keyword("FROM") { - let relation = Some(self.parse_table_factor()?); - let joins = self.parse_joins()?; - (relation, joins) - } else { - (None, vec![]) - }; + let mut from = vec![]; + if self.parse_keyword("FROM") { + loop { + from.push(self.parse_table_and_joins()?); + if !self.consume_token(&Token::Comma) { + break; + } + } + } let selection = if self.parse_keyword("WHERE") { Some(self.parse_expr()?) @@ -1599,14 +1605,69 @@ impl Parser { Ok(SQLSelect { distinct, projection, + from, selection, - relation, - joins, group_by, having, }) } + pub fn parse_table_and_joins(&mut self) -> Result { + let relation = self.parse_table_factor()?; + let mut joins = vec![]; + loop { + let join = match &self.peek_token() { + Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => { + self.next_token(); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_table_factor()?, + join_operator: JoinOperator::Cross, + } + } + _ => { + let natural = self.parse_keyword("NATURAL"); + let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() { + kw.keyword + } else { + String::default() + }; + + let join_operator_type = match peek_keyword.as_ref() { + "INNER" | "JOIN" => { + let _ = self.parse_keyword("INNER"); + self.expect_keyword("JOIN")?; + JoinOperator::Inner + } + kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => { + let _ = self.next_token(); + let _ = self.parse_keyword("OUTER"); + self.expect_keyword("JOIN")?; + match kw { + "LEFT" => JoinOperator::LeftOuter, + "RIGHT" => JoinOperator::RightOuter, + "FULL" => JoinOperator::FullOuter, + _ => unreachable!(), + } + } + _ if natural => { + return self.expected("a join type after NATURAL", self.peek_token()); + } + _ => break, + }; + let relation = self.parse_table_factor()?; + let join_constraint = self.parse_join_constraint(natural)?; + Join { + relation, + join_operator: join_operator_type(join_constraint), + } + } + }; + joins.push(join); + } + Ok(TableWithJoins { relation, joins }) + } + /// A table name or a parenthesized subquery, followed by optional `[AS] alias` pub fn parse_table_factor(&mut self) -> Result { let lateral = self.parse_keyword("LATERAL"); @@ -1627,10 +1688,9 @@ impl Parser { } else if lateral { parser_err!("Expected subquery after LATERAL, found nested join".to_string()) } else { - let base = Box::new(self.parse_table_factor()?); - let joins = self.parse_joins()?; + let table_reference = self.parse_table_and_joins()?; self.expect_token(&Token::RParen)?; - Ok(TableFactor::NestedJoin { base, joins }) + Ok(TableFactor::NestedJoin(Box::new(table_reference))) } } else if lateral { self.expected("subquery after LATERAL", self.peek_token()) @@ -1677,68 +1737,6 @@ impl Parser { } } - fn parse_joins(&mut self) -> Result, ParserError> { - let mut joins = vec![]; - loop { - let join = match &self.peek_token() { - Some(Token::Comma) => { - self.next_token(); - Join { - relation: self.parse_table_factor()?, - join_operator: JoinOperator::Implicit, - } - } - Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => { - self.next_token(); - self.expect_keyword("JOIN")?; - Join { - relation: self.parse_table_factor()?, - join_operator: JoinOperator::Cross, - } - } - _ => { - let natural = self.parse_keyword("NATURAL"); - let peek_keyword = if let Some(Token::SQLWord(kw)) = self.peek_token() { - kw.keyword - } else { - String::default() - }; - - let join_operator_type = match peek_keyword.as_ref() { - "INNER" | "JOIN" => { - let _ = self.parse_keyword("INNER"); - self.expect_keyword("JOIN")?; - JoinOperator::Inner - } - kw @ "LEFT" | kw @ "RIGHT" | kw @ "FULL" => { - let _ = self.next_token(); - let _ = self.parse_keyword("OUTER"); - self.expect_keyword("JOIN")?; - match kw { - "LEFT" => JoinOperator::LeftOuter, - "RIGHT" => JoinOperator::RightOuter, - "FULL" => JoinOperator::FullOuter, - _ => unreachable!(), - } - } - _ if natural => { - return self.expected("a join type after NATURAL", self.peek_token()); - } - _ => break, - }; - let relation = self.parse_table_factor()?; - let join_constraint = self.parse_join_constraint(natural)?; - Join { - relation, - join_operator: join_operator_type(join_constraint), - } - } - }; - joins.push(join); - } - Ok(joins) - } - /// Parse an INSERT statement pub fn parse_insert(&mut self) -> Result { self.expect_keyword("INTO")?; diff --git a/src/test_utils.rs b/src/test_utils.rs index 16216bfe8..7a1ce5e2f 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -109,9 +109,13 @@ pub fn all_dialects() -> TestedDialects { } } -pub fn only(v: &[T]) -> &T { - assert_eq!(1, v.len()); - v.first().unwrap() +pub fn only(v: impl IntoIterator) -> T { + let mut iter = v.into_iter(); + if let (Some(item), None) = (iter.next(), iter.next()) { + item + } else { + panic!("only called on collection without exactly one item") + } } pub fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index f1b9ed078..6da8ee3ee 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -181,6 +181,14 @@ fn parse_where_delete_statement() { } } +#[test] +fn parse_top_level() { + verified_stmt("SELECT 1"); + verified_stmt("(SELECT 1)"); + verified_stmt("((SELECT 1))"); + verified_stmt("VALUES (1)"); +} + #[test] fn parse_simple_select() { let sql = "SELECT id, fname, lname FROM customer WHERE id = 1 LIMIT 5"; @@ -1300,7 +1308,7 @@ fn parse_delimited_identifiers() { r#"SELECT "alias"."bar baz", "myfun"(), "simple id" AS "column alias" FROM "a table" AS "alias""# ); // check FROM - match select.relation.unwrap() { + match only(select.from).relation { TableFactor::Table { name, alias, @@ -1430,16 +1438,69 @@ fn parse_implicit_join() { let sql = "SELECT * FROM t1, t2"; let select = verified_only_select(sql); assert_eq!( - &Join { - relation: TableFactor::Table { - name: SQLObjectName(vec!["t2".to_string()]), - alias: None, - args: vec![], - with_hints: vec![], + vec![ + TableWithJoins { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t1".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![], }, - join_operator: JoinOperator::Implicit - }, - only(&select.joins), + TableWithJoins { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![], + } + ], + select.from, + ); + + let sql = "SELECT * FROM t1a NATURAL JOIN t1b, t2a NATURAL JOIN t2b"; + let select = verified_only_select(sql); + assert_eq!( + vec![ + TableWithJoins { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t1a".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![Join { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t1b".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + join_operator: JoinOperator::Inner(JoinConstraint::Natural), + }] + }, + TableWithJoins { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2a".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![Join { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2b".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + join_operator: JoinOperator::Inner(JoinConstraint::Natural), + }] + } + ], + select.from, ); } @@ -1448,7 +1509,7 @@ fn parse_cross_join() { let sql = "SELECT * FROM t1 CROSS JOIN t2"; let select = verified_only_select(sql); assert_eq!( - &Join { + Join { relation: TableFactor::Table { name: SQLObjectName(vec!["t2".to_string()]), alias: None, @@ -1457,7 +1518,7 @@ fn parse_cross_join() { }, join_operator: JoinOperator::Cross }, - only(&select.joins), + only(only(select.from).joins), ); } @@ -1491,7 +1552,7 @@ fn parse_joins_on() { } // Test parsing of aliases assert_eq!( - verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2").from).joins, vec![join_with_constraint( "t2", table_alias("foo"), @@ -1504,19 +1565,19 @@ fn parse_joins_on() { ); // Test parsing of different join operators assert_eq!( - verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2").from).joins, vec![join_with_constraint("t2", None, JoinOperator::Inner)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2").from).joins, vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2").from).joins, vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2").from).joins, vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] ); } @@ -1540,7 +1601,7 @@ fn parse_joins_using() { } // Test parsing of aliases assert_eq!( - verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)").from).joins, vec![join_with_constraint( "t2", table_alias("foo"), @@ -1553,19 +1614,19 @@ fn parse_joins_using() { ); // Test parsing of different join operators assert_eq!( - verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)").from).joins, vec![join_with_constraint("t2", None, JoinOperator::Inner)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)").from).joins, vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)").from).joins, vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)").from).joins, vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] ); } @@ -1584,19 +1645,19 @@ fn parse_natural_join() { } } assert_eq!( - verified_only_select("SELECT * FROM t1 NATURAL JOIN t2").joins, + only(&verified_only_select("SELECT * FROM t1 NATURAL JOIN t2").from).joins, vec![natural_join(JoinOperator::Inner)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2").joins, + only(&verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2").from).joins, vec![natural_join(JoinOperator::LeftOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2").joins, + only(&verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2").from).joins, vec![natural_join(JoinOperator::RightOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2").joins, + only(&verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2").from).joins, vec![natural_join(JoinOperator::FullOuter)] ); @@ -1633,17 +1694,17 @@ fn parse_join_nesting() { macro_rules! nest { ($base:expr $(, $join:expr)*) => { - TableFactor::NestedJoin { - base: Box::new($base), + TableFactor::NestedJoin(Box::new(TableWithJoins { + relation: $base, joins: vec![$(join($join)),*] - } + })) }; } let sql = "SELECT * FROM a NATURAL JOIN (b NATURAL JOIN (c NATURAL JOIN d NATURAL JOIN e)) \ NATURAL JOIN (f NATURAL JOIN (g NATURAL JOIN h))"; assert_eq!( - verified_only_select(sql).joins, + only(&verified_only_select(sql).from).joins, vec![ join(nest!(table("b"), nest!(table("c"), table("d"), table("e")))), join(nest!(table("f"), nest!(table("g"), table("h")))) @@ -1652,22 +1713,22 @@ fn parse_join_nesting() { let sql = "SELECT * FROM (a NATURAL JOIN b) NATURAL JOIN c"; let select = verified_only_select(sql); - assert_eq!(select.relation.unwrap(), nest!(table("a"), table("b")),); - assert_eq!(select.joins, vec![join(table("c"))]); + let from = only(select.from); + assert_eq!(from.relation, nest!(table("a"), table("b"))); + assert_eq!(from.joins, vec![join(table("c"))]); let sql = "SELECT * FROM (((a NATURAL JOIN b)))"; let select = verified_only_select(sql); - assert_eq!( - select.relation.unwrap(), - nest!(nest!(nest!(table("a"), table("b")))) - ); - assert_eq!(select.joins, vec![]); + let from = only(select.from); + assert_eq!(from.relation, nest!(nest!(nest!(table("a"), table("b"))))); + assert_eq!(from.joins, vec![]); let sql = "SELECT * FROM a NATURAL JOIN (((b NATURAL JOIN c)))"; let select = verified_only_select(sql); - assert_eq!(select.relation.unwrap(), table("a")); + let from = only(select.from); + assert_eq!(from.relation, table("a")); assert_eq!( - select.joins, + from.joins, vec![join(nest!(nest!(nest!(table("b"), table("c")))))] ); } @@ -1729,8 +1790,8 @@ fn parse_ctes() { // CTE in a derived table let sql = &format!("SELECT * FROM ({})", with); let select = verified_only_select(sql); - match select.relation { - Some(TableFactor::Derived { subquery, .. }) => { + match only(select.from).relation { + TableFactor::Derived { subquery, .. } => { assert_ctes_in_select(&cte_sqls, subquery.as_ref()) } _ => panic!("Expected derived table"), @@ -2072,8 +2133,8 @@ fn parse_offset() { let ast = verified_query("SELECT foo FROM (SELECT * FROM bar OFFSET 2 ROWS) OFFSET 2 ROWS"); assert_eq!(ast.offset, Some(ASTNode::SQLValue(Value::Long(2)))); match ast.body { - SQLSetExpr::Select(s) => match s.relation { - Some(TableFactor::Derived { subquery, .. }) => { + SQLSetExpr::Select(s) => match only(s.from).relation { + TableFactor::Derived { subquery, .. } => { assert_eq!(subquery.offset, Some(ASTNode::SQLValue(Value::Long(2)))); } _ => panic!("Test broke"), @@ -2172,8 +2233,8 @@ fn parse_fetch() { }) ); match ast.body { - SQLSetExpr::Select(s) => match s.relation { - Some(TableFactor::Derived { subquery, .. }) => { + SQLSetExpr::Select(s) => match only(s.from).relation { + TableFactor::Derived { subquery, .. } => { assert_eq!( subquery.fetch, Some(Fetch { @@ -2198,8 +2259,8 @@ fn parse_fetch() { }) ); match ast.body { - SQLSetExpr::Select(s) => match s.relation { - Some(TableFactor::Derived { subquery, .. }) => { + SQLSetExpr::Select(s) => match only(s.from).relation { + TableFactor::Derived { subquery, .. } => { assert_eq!(subquery.offset, Some(ASTNode::SQLValue(Value::Long(2)))); assert_eq!( subquery.fetch, @@ -2250,16 +2311,18 @@ fn lateral_derived() { lateral_str ); let select = verified_only_select(&sql); - assert_eq!(select.joins.len(), 1); + let from = only(select.from); + assert_eq!(from.joins.len(), 1); + let join = &from.joins[0]; assert_eq!( - select.joins[0].join_operator, + join.join_operator, JoinOperator::LeftOuter(JoinConstraint::On(ASTNode::SQLValue(Value::Boolean(true)))) ); if let TableFactor::Derived { lateral, ref subquery, alias: Some(ref alias), - } = select.joins[0].relation + } = join.relation { assert_eq!(lateral_in, lateral); assert_eq!("order".to_string(), alias.name); diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index b49cfd78f..d0b1f7ec4 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -19,8 +19,8 @@ fn parse_mssql_identifiers() { expr_from_projection(&select.projection[1]), ); assert_eq!(2, select.projection.len()); - match select.relation { - Some(TableFactor::Table { name, .. }) => { + match &only(&select.from).relation { + TableFactor::Table { name, .. } => { assert_eq!("##temp".to_string(), name.to_string()); } _ => unreachable!(),