diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 2b81f0b39..63c4c739f 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -416,6 +416,8 @@ pub enum Expr { ArraySubquery(Box), /// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)` ListAgg(ListAgg), + /// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)` + ArrayAgg(ArrayAgg), /// The `GROUPING SETS` expr. GroupingSets(Vec>), /// The `CUBE` expr. @@ -655,6 +657,7 @@ impl fmt::Display for Expr { Expr::Subquery(s) => write!(f, "({})", s), Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s), Expr::ListAgg(listagg) => write!(f, "{}", listagg), + Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg), Expr::GroupingSets(sets) => { write!(f, "GROUPING SETS (")?; let mut sep = ""; @@ -3036,6 +3039,45 @@ impl fmt::Display for ListAggOnOverflow { } } +/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] [ORDER BY ] [LIMIT ] )` +/// Or `ARRAY_AGG( [ DISTINCT ] ) [ WITHIN GROUP ( ORDER BY ) ]` +/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ArrayAgg { + pub distinct: bool, + pub expr: Box, + pub order_by: Option>, + pub limit: Option>, + pub within_group: bool, // order by is used inside a within group or not +} + +impl fmt::Display for ArrayAgg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "ARRAY_AGG({}{}", + if self.distinct { "DISTINCT " } else { "" }, + self.expr + )?; + if !self.within_group { + if let Some(order_by) = &self.order_by { + write!(f, " ORDER BY {}", order_by)?; + } + if let Some(limit) = &self.limit { + write!(f, " LIMIT {}", limit)?; + } + } + write!(f, ")")?; + if self.within_group { + if let Some(order_by) = &self.order_by { + write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?; + } + } + Ok(()) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ObjectType { diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 1d3c9cf5f..1eaa41aa7 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -71,6 +71,12 @@ pub trait Dialect: Debug + Any { fn supports_filter_during_aggregation(&self) -> bool { false } + /// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions. + /// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)]. + /// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function + fn supports_within_after_array_aggregation(&self) -> bool { + false + } /// Dialect-specific prefix parser override fn parse_prefix(&self, _parser: &mut Parser) -> Option> { // return None to fall back to the default behavior diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 93db95692..11108e973 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect { || ch == '$' || ch == '_' } + + fn supports_within_after_array_aggregation(&self) -> bool { + true + } } diff --git a/src/parser.rs b/src/parser.rs index 0753d263e..201b344de 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -473,6 +473,7 @@ impl<'a> Parser<'a> { self.expect_token(&Token::LParen)?; self.parse_array_subquery() } + Keyword::ARRAY_AGG => self.parse_array_agg_expr(), Keyword::NOT => self.parse_not(), // Here `w` is a word, check if it's a part of a multi-part // identifier, a function call, or a simple identifier: @@ -1071,6 +1072,54 @@ impl<'a> Parser<'a> { })) } + pub fn parse_array_agg_expr(&mut self) -> Result { + self.expect_token(&Token::LParen)?; + let distinct = self.parse_keyword(Keyword::DISTINCT); + let expr = Box::new(self.parse_expr()?); + // ANSI SQL and BigQuery define ORDER BY inside function. + if !self.dialect.supports_within_after_array_aggregation() { + let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) { + let order_by_expr = self.parse_order_by_expr()?; + Some(Box::new(order_by_expr)) + } else { + None + }; + let limit = if self.parse_keyword(Keyword::LIMIT) { + self.parse_limit()?.map(Box::new) + } else { + None + }; + self.expect_token(&Token::RParen)?; + return Ok(Expr::ArrayAgg(ArrayAgg { + distinct, + expr, + order_by, + limit, + within_group: false, + })); + } + // Snowflake defines ORDERY BY in within group instead of inside the function like + // ANSI SQL. + self.expect_token(&Token::RParen)?; + let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) { + self.expect_token(&Token::LParen)?; + self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?; + let order_by_expr = self.parse_order_by_expr()?; + self.expect_token(&Token::RParen)?; + Some(Box::new(order_by_expr)) + } else { + None + }; + + Ok(Expr::ArrayAgg(ArrayAgg { + distinct, + expr, + order_by: within_group, + limit: None, + within_group: true, + })) + } + // This function parses date/time fields for the EXTRACT function-like // operator, interval qualifiers, and the ceil/floor operations. // EXTRACT supports a wider set of date/time fields than interval qualifiers, diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 86b47ddad..8ada172cf 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -224,6 +224,17 @@ fn parse_similar_to() { chk(true); } +#[test] +fn parse_array_agg_func() { + for sql in [ + "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T", + "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl", + "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl", + ] { + bigquery().verified_stmt(sql); + } +} + fn bigquery() -> TestedDialects { TestedDialects { dialects: vec![Box::new(BigQueryDialect {})], diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 4efb8cc7c..e3390a479 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1777,6 +1777,27 @@ fn parse_listagg() { ); } +#[test] +fn parse_array_agg_func() { + let supported_dialects = TestedDialects { + dialects: vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(HiveDialect {}), + ], + }; + + for sql in [ + "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T", + "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl", + "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl", + ] { + supported_dialects.verified_stmt(sql); + } +} + #[test] fn parse_create_table() { let sql = "CREATE TABLE uk_cities (\ diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 695f63b54..070f55089 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -281,8 +281,8 @@ fn parse_create_function() { #[test] fn filtering_during_aggregation() { let rename = "SELECT \ - array_agg(name) FILTER (WHERE name IS NOT NULL), \ - array_agg(name) FILTER (WHERE name LIKE 'a%') \ + ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \ + ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \ FROM region"; println!("{}", hive().verified_stmt(rename)); } @@ -290,8 +290,8 @@ fn filtering_during_aggregation() { #[test] fn filtering_during_aggregation_aliased() { let rename = "SELECT \ - array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \ - array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \ + ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \ + ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \ FROM region"; println!("{}", hive().verified_stmt(rename)); } diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index 2a53b0840..a201c5db7 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -334,6 +334,25 @@ fn parse_similar_to() { chk(true); } +#[test] +fn test_array_agg_func() { + for sql in [ + "SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T", + "SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl", + ] { + snowflake().verified_stmt(sql); + } + + let sql = "select array_agg(x order by x) as a from T"; + let result = snowflake().parse_sql_statements(sql); + assert_eq!( + result, + Err(ParserError::ParserError(String::from( + "Expected ), found: order" + ))) + ) +} + fn snowflake() -> TestedDialects { TestedDialects { dialects: vec![Box::new(SnowflakeDialect {})],