Skip to content

Parse ARRAY_AGG for Bigquery and Snowflake #662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ pub enum Expr {
ArraySubquery(Box<Query>),
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
ListAgg(ListAgg),
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
ArrayAgg(ArrayAgg),
/// The `GROUPING SETS` expr.
GroupingSets(Vec<Vec<Expr>>),
/// The `CUBE` expr.
Expand Down Expand Up @@ -655,6 +657,7 @@ impl fmt::Display for Expr {
Expr::Subquery(s) => write!(f, "({})", s),
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
Expr::GroupingSets(sets) => {
write!(f, "GROUPING SETS (")?;
let mut sep = "";
Expand Down Expand Up @@ -3036,6 +3039,45 @@ impl fmt::Display for ListAggOnOverflow {
}
}

/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ArrayAgg {
pub distinct: bool,
pub expr: Box<Expr>,
pub order_by: Option<Box<OrderByExpr>>,
pub limit: Option<Box<Expr>>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SuperBo any reason why you did the implementation for all dialects the same time?

You didn't complete none, but added parts of different ones. Got me a little confused

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this. You can review again

pub within_group: bool, // order by is used inside a within group or not
}

impl fmt::Display for ArrayAgg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"ARRAY_AGG({}{}",
if self.distinct { "DISTINCT " } else { "" },
self.expr
)?;
if !self.within_group {
if let Some(order_by) = &self.order_by {
write!(f, " ORDER BY {}", order_by)?;
}
if let Some(limit) = &self.limit {
write!(f, " LIMIT {}", limit)?;
}
}
write!(f, ")")?;
if self.within_group {
if let Some(order_by) = &self.order_by {
write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
}
}
Ok(())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ObjectType {
Expand Down
6 changes: 6 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ pub trait Dialect: Debug + Any {
fn supports_filter_during_aggregation(&self) -> bool {
false
}
/// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
fn supports_within_after_array_aggregation(&self) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment is confusing.

The comment make it sounds like if any of both is valid, it will return true.

Could be something like

/// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
///
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function

Having the function here is not my usual approach, as it will require maintenance over other dialects individually, but I don't think it's wrong.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, @SuperBo, one thing:

  • Generic dialect should be the most permissive one so, if you actually can have both structures in a query, generic dialect should allow both. But considering my previous comment, I'd allow the most common one, maybe?

@alamb what do you think? This is a divergence from dialects. I'd go for the one with more dialects, so the generic is as generic as possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are right, the ANSI one should be default option.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AugustoFKL, are we good to go with this or should we change to more generics behavior?

false
}
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
|| ch == '$'
|| ch == '_'
}

fn supports_within_after_array_aggregation(&self) -> bool {
true
}
}
49 changes: 49 additions & 0 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::LParen)?;
self.parse_array_subquery()
}
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
Keyword::NOT => self.parse_not(),
// Here `w` is a word, check if it's a part of a multi-part
// identifier, a function call, or a simple identifier:
Expand Down Expand Up @@ -1071,6 +1072,54 @@ impl<'a> Parser<'a> {
}))
}

pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;
let distinct = self.parse_keyword(Keyword::DISTINCT);
let expr = Box::new(self.parse_expr()?);
// ANSI SQL and BigQuery define ORDER BY inside function.
if !self.dialect.supports_within_after_array_aggregation() {
let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
let order_by_expr = self.parse_order_by_expr()?;
Some(Box::new(order_by_expr))
} else {
None
};
let limit = if self.parse_keyword(Keyword::LIMIT) {
self.parse_limit()?.map(Box::new)
} else {
None
};
self.expect_token(&Token::RParen)?;
return Ok(Expr::ArrayAgg(ArrayAgg {
distinct,
expr,
order_by,
limit,
within_group: false,
}));
}
// Snowflake defines ORDERY BY in within group instead of inside the function like
// ANSI SQL.
self.expect_token(&Token::RParen)?;
let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
self.expect_token(&Token::LParen)?;
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
let order_by_expr = self.parse_order_by_expr()?;
self.expect_token(&Token::RParen)?;
Some(Box::new(order_by_expr))
} else {
None
};

Ok(Expr::ArrayAgg(ArrayAgg {
distinct,
expr,
order_by: within_group,
limit: None,
within_group: true,
}))
}

// This function parses date/time fields for the EXTRACT function-like
// operator, interval qualifiers, and the ceil/floor operations.
// EXTRACT supports a wider set of date/time fields than interval qualifiers,
Expand Down
11 changes: 11 additions & 0 deletions tests/sqlparser_bigquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,17 @@ fn parse_similar_to() {
chk(true);
}

#[test]
fn parse_array_agg_func() {
for sql in [
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
] {
bigquery().verified_stmt(sql);
}
}

fn bigquery() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(BigQueryDialect {})],
Expand Down
21 changes: 21 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1777,6 +1777,27 @@ fn parse_listagg() {
);
}

#[test]
fn parse_array_agg_func() {
let supported_dialects = TestedDialects {
dialects: vec![
Box::new(GenericDialect {}),
Box::new(PostgreSqlDialect {}),
Box::new(MsSqlDialect {}),
Box::new(AnsiDialect {}),
Box::new(HiveDialect {}),
],
};

for sql in [
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
] {
supported_dialects.verified_stmt(sql);
}
}

#[test]
fn parse_create_table() {
let sql = "CREATE TABLE uk_cities (\
Expand Down
8 changes: 4 additions & 4 deletions tests/sqlparser_hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,17 @@ fn parse_create_function() {
#[test]
fn filtering_during_aggregation() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL), \
array_agg(name) FILTER (WHERE name LIKE 'a%') \
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \
FROM region";
println!("{}", hive().verified_stmt(rename));
}

#[test]
fn filtering_during_aggregation_aliased() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
FROM region";
println!("{}", hive().verified_stmt(rename));
}
Expand Down
19 changes: 19 additions & 0 deletions tests/sqlparser_snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,25 @@ fn parse_similar_to() {
chk(true);
}

#[test]
fn test_array_agg_func() {
for sql in [
"SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
"SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
] {
snowflake().verified_stmt(sql);
}

let sql = "select array_agg(x order by x) as a from T";
let result = snowflake().parse_sql_statements(sql);
assert_eq!(
result,
Err(ParserError::ParserError(String::from(
"Expected ), found: order"
)))
)
}

fn snowflake() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(SnowflakeDialect {})],
Expand Down