Skip to content

Commit b3b4058

Browse files
SuperBo authored and MazterQyou committed
Parse ARRAY_AGG for Bigquery and Snowflake (apache#662)
1 parent 10782e5 commit b3b4058

File tree

6 files changed

+143
-2
lines changed

6 files changed

+143
-2
lines changed

src/ast/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ pub enum Expr {
343343
Subquery(Box<Query>),
344344
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
345345
ListAgg(ListAgg),
346+
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
347+
ArrayAgg(ArrayAgg),
346348
/// The `GROUPING SETS` expr.
347349
GroupingSets(Vec<Vec<Expr>>),
348350
/// The `CUBE` expr.
@@ -542,6 +544,7 @@ impl fmt::Display for Expr {
542544
Expr::Subquery(s) => write!(f, "({})", s),
543545
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
544546
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
547+
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
545548
Expr::GroupingSets(sets) => {
546549
write!(f, "GROUPING SETS (")?;
547550
let mut sep = "";
@@ -2448,6 +2451,45 @@ impl fmt::Display for ListAggOnOverflow {
24482451
}
24492452
}
24502453

2454+
/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
2455+
/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
2456+
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
///
/// The `within_group` flag records which of the two spellings applies, so that the
/// `Display` implementation can reproduce the same form.
2457+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2458+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2459+
pub struct ArrayAgg {
2460+
/// `true` when the `DISTINCT` keyword precedes the aggregated expression.
pub distinct: bool,
2461+
/// The expression being aggregated.
pub expr: Box<Expr>,
2462+
/// The `ORDER BY` expression, if one was given. Whether it is written inside the
/// parentheses or in a trailing `WITHIN GROUP (...)` clause is decided by `within_group`.
pub order_by: Option<Box<OrderByExpr>>,
2463+
/// The `LIMIT` expression, if one was given. It is only rendered for the
/// inside-the-parentheses form (`within_group == false`).
pub limit: Option<Box<Expr>>,
2464+
/// Whether the `ORDER BY` is written inside a trailing `WITHIN GROUP (...)` clause
/// (`true`) or inside the function's parentheses (`false`).
pub within_group: bool,
2465+
}
2466+
2467+
impl fmt::Display for ArrayAgg {
2468+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2469+
write!(
2470+
f,
2471+
"ARRAY_AGG({}{}",
2472+
if self.distinct { "DISTINCT " } else { "" },
2473+
self.expr
2474+
)?;
2475+
if !self.within_group {
2476+
if let Some(order_by) = &self.order_by {
2477+
write!(f, " ORDER BY {}", order_by)?;
2478+
}
2479+
if let Some(limit) = &self.limit {
2480+
write!(f, " LIMIT {}", limit)?;
2481+
}
2482+
}
2483+
write!(f, ")")?;
2484+
if self.within_group {
2485+
if let Some(order_by) = &self.order_by {
2486+
write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
2487+
}
2488+
}
2489+
Ok(())
2490+
}
2491+
}
2492+
24512493
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24522494
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
24532495
pub enum ObjectType {

src/dialect/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ pub trait Dialect: Debug + Any {
6363
fn is_identifier_start(&self, ch: char) -> bool;
6464
/// Determine if a character is a valid unquoted identifier character
6565
fn is_identifier_part(&self, ch: char) -> bool;
66+
/// Returns true if the dialect supports `ARRAY_AGG() [WITHIN GROUP (ORDER BY)]` expressions.
67+
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
///
/// The default implementation returns `false`.
///
68+
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
69+
fn supports_within_after_array_aggregation(&self) -> bool {
70+
false
71+
}
6672
}
6773

6874
impl dyn Dialect {

src/dialect/snowflake.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
2828
|| ch == '$'
2929
|| ch == '_'
3030
}
31+
32+
/// Snowflake writes the ordering of `ARRAY_AGG` in a trailing
/// `WITHIN GROUP (ORDER BY ...)` clause, so this returns `true`.
fn supports_within_after_array_aggregation(&self) -> bool {
33+
true
34+
}
3135
}

src/parser.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ impl<'a> Parser<'a> {
476476
self.expect_token(&Token::LParen)?;
477477
self.parse_array_subquery()
478478
}
479+
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
479480
Keyword::NOT => Ok(Expr::UnaryOp {
480481
op: UnaryOperator::Not,
481482
expr: Box::new(self.parse_subexpr(Self::UNARY_NOT_PREC)?),
@@ -1006,6 +1007,54 @@ impl<'a> Parser<'a> {
10061007
}))
10071008
}
10081009

1010+
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
1011+
self.expect_token(&Token::LParen)?;
1012+
let distinct = self.parse_keyword(Keyword::DISTINCT);
1013+
let expr = Box::new(self.parse_expr()?);
1014+
// ANSI SQL and BigQuery define ORDER BY inside function.
1015+
if !self.dialect.supports_within_after_array_aggregation() {
1016+
let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
1017+
let order_by_expr = self.parse_order_by_expr()?;
1018+
Some(Box::new(order_by_expr))
1019+
} else {
1020+
None
1021+
};
1022+
let limit = if self.parse_keyword(Keyword::LIMIT) {
1023+
self.parse_limit()?.map(Box::new)
1024+
} else {
1025+
None
1026+
};
1027+
self.expect_token(&Token::RParen)?;
1028+
return Ok(Expr::ArrayAgg(ArrayAgg {
1029+
distinct,
1030+
expr,
1031+
order_by,
1032+
limit,
1033+
within_group: false,
1034+
}));
1035+
}
1036+
// Snowflake defines ORDERY BY in within group instead of inside the function like
1037+
// ANSI SQL.
1038+
self.expect_token(&Token::RParen)?;
1039+
let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
1040+
self.expect_token(&Token::LParen)?;
1041+
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
1042+
let order_by_expr = self.parse_order_by_expr()?;
1043+
self.expect_token(&Token::RParen)?;
1044+
Some(Box::new(order_by_expr))
1045+
} else {
1046+
None
1047+
};
1048+
1049+
Ok(Expr::ArrayAgg(ArrayAgg {
1050+
distinct,
1051+
expr,
1052+
order_by: within_group,
1053+
limit: None,
1054+
within_group: true,
1055+
}))
1056+
}
1057+
10091058
// This function parses date/time fields for both the EXTRACT function-like
10101059
// operator and interval qualifiers. EXTRACT supports a wider set of
10111060
// date/time fields than interval qualifiers, so this function may need to

tests/sqlparser_common.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ mod test_utils;
2424
use matches::assert_matches;
2525
use sqlparser::ast::*;
2626
use sqlparser::dialect::{
27-
AnsiDialect, ClickHouseDialect, GenericDialect, MsSqlDialect, PostgreSqlDialect, SQLiteDialect,
28-
SnowflakeDialect,
27+
AnsiDialect, ClickHouseDialect, GenericDialect, HiveDialect, MsSqlDialect, PostgreSqlDialect,
28+
SQLiteDialect, SnowflakeDialect,
2929
};
3030
use sqlparser::keywords::ALL_KEYWORDS;
3131
use sqlparser::parser::{Parser, ParserError};
@@ -1644,6 +1644,27 @@ fn parse_listagg() {
16441644
);
16451645
}
16461646

1647+
/// `ARRAY_AGG` with `ORDER BY` / `LIMIT` inside the parentheses must round-trip
/// unchanged for the dialects listed below.
#[test]
fn parse_array_agg_func() {
    let statements = [
        "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
        "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
        "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
    ];

    let supported_dialects = TestedDialects {
        dialects: vec![
            Box::new(GenericDialect {}),
            Box::new(PostgreSqlDialect {}),
            Box::new(MsSqlDialect {}),
            Box::new(AnsiDialect {}),
            Box::new(HiveDialect {}),
        ],
    };

    for sql in statements {
        supported_dialects.verified_stmt(sql);
    }
}
1667+
16471668
#[test]
16481669
fn parse_create_table() {
16491670
let sql = "CREATE TABLE uk_cities (\

tests/sqlparser_snowflake.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,25 @@ fn test_single_table_in_parenthesis_with_alias() {
144144
);
145145
}
146146

147+
/// Snowflake accepts `ARRAY_AGG(...) WITHIN GROUP (ORDER BY ...)` and rejects an
/// `ORDER BY` placed inside the parentheses.
#[test]
fn test_array_agg_func() {
    // Accepted form: the ordering follows the call in a WITHIN GROUP clause.
    let round_trips = [
        "SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
        "SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
    ];
    for sql in round_trips {
        snowflake().verified_stmt(sql);
    }

    // Rejected form: the parser expects `)` where `order` appears.
    let outcome = snowflake().parse_sql_statements("select array_agg(x order by x) as a from T");
    let expected_message = String::from("Expected ), found: order");
    assert_eq!(outcome, Err(ParserError::ParserError(expected_message)));
}
165+
147166
fn snowflake() -> TestedDialects {
148167
TestedDialects {
149168
dialects: vec![Box::new(SnowflakeDialect {})],

0 commit comments

Comments
 (0)