Skip to content

Commit 87b4a16

Browse files
authored
Parse ARRAY_AGG for Bigquery and Snowflake (apache#662)
1 parent 0428ac7 commit 87b4a16

File tree

8 files changed

+156
-4
lines changed

8 files changed

+156
-4
lines changed

src/ast/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,8 @@ pub enum Expr {
416416
ArraySubquery(Box<Query>),
417417
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
418418
ListAgg(ListAgg),
419+
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
420+
ArrayAgg(ArrayAgg),
419421
/// The `GROUPING SETS` expr.
420422
GroupingSets(Vec<Vec<Expr>>),
421423
/// The `CUBE` expr.
@@ -655,6 +657,7 @@ impl fmt::Display for Expr {
655657
Expr::Subquery(s) => write!(f, "({})", s),
656658
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
657659
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
660+
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
658661
Expr::GroupingSets(sets) => {
659662
write!(f, "GROUPING SETS (")?;
660663
let mut sep = "";
@@ -3036,6 +3039,45 @@ impl fmt::Display for ListAggOnOverflow {
30363039
}
30373040
}
30383041

3042+
/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
3043+
/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
3044+
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
3045+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
3046+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
3047+
pub struct ArrayAgg {
3048+
pub distinct: bool,
3049+
pub expr: Box<Expr>,
3050+
pub order_by: Option<Box<OrderByExpr>>,
3051+
pub limit: Option<Box<Expr>>,
3052+
pub within_group: bool, // order by is used inside a within group or not
3053+
}
3054+
3055+
impl fmt::Display for ArrayAgg {
3056+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3057+
write!(
3058+
f,
3059+
"ARRAY_AGG({}{}",
3060+
if self.distinct { "DISTINCT " } else { "" },
3061+
self.expr
3062+
)?;
3063+
if !self.within_group {
3064+
if let Some(order_by) = &self.order_by {
3065+
write!(f, " ORDER BY {}", order_by)?;
3066+
}
3067+
if let Some(limit) = &self.limit {
3068+
write!(f, " LIMIT {}", limit)?;
3069+
}
3070+
}
3071+
write!(f, ")")?;
3072+
if self.within_group {
3073+
if let Some(order_by) = &self.order_by {
3074+
write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
3075+
}
3076+
}
3077+
Ok(())
3078+
}
3079+
}
3080+
30393081
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
30403082
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
30413083
pub enum ObjectType {

src/dialect/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ pub trait Dialect: Debug + Any {
7171
fn supports_filter_during_aggregation(&self) -> bool {
7272
false
7373
}
74+
/// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
75+
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
76+
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
77+
fn supports_within_after_array_aggregation(&self) -> bool {
78+
false
79+
}
7480
/// Dialect-specific prefix parser override
7581
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
7682
// return None to fall back to the default behavior

src/dialect/snowflake.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
2828
|| ch == '$'
2929
|| ch == '_'
3030
}
31+
32+
fn supports_within_after_array_aggregation(&self) -> bool {
33+
true
34+
}
3135
}

src/parser.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ impl<'a> Parser<'a> {
473473
self.expect_token(&Token::LParen)?;
474474
self.parse_array_subquery()
475475
}
476+
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
476477
Keyword::NOT => self.parse_not(),
477478
// Here `w` is a word, check if it's a part of a multi-part
478479
// identifier, a function call, or a simple identifier:
@@ -1071,6 +1072,54 @@ impl<'a> Parser<'a> {
10711072
}))
10721073
}
10731074

1075+
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
1076+
self.expect_token(&Token::LParen)?;
1077+
let distinct = self.parse_keyword(Keyword::DISTINCT);
1078+
let expr = Box::new(self.parse_expr()?);
1079+
// ANSI SQL and BigQuery define ORDER BY inside function.
1080+
if !self.dialect.supports_within_after_array_aggregation() {
1081+
let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
1082+
let order_by_expr = self.parse_order_by_expr()?;
1083+
Some(Box::new(order_by_expr))
1084+
} else {
1085+
None
1086+
};
1087+
let limit = if self.parse_keyword(Keyword::LIMIT) {
1088+
self.parse_limit()?.map(Box::new)
1089+
} else {
1090+
None
1091+
};
1092+
self.expect_token(&Token::RParen)?;
1093+
return Ok(Expr::ArrayAgg(ArrayAgg {
1094+
distinct,
1095+
expr,
1096+
order_by,
1097+
limit,
1098+
within_group: false,
1099+
}));
1100+
}
1101+
// Snowflake defines ORDERY BY in within group instead of inside the function like
1102+
// ANSI SQL.
1103+
self.expect_token(&Token::RParen)?;
1104+
let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
1105+
self.expect_token(&Token::LParen)?;
1106+
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
1107+
let order_by_expr = self.parse_order_by_expr()?;
1108+
self.expect_token(&Token::RParen)?;
1109+
Some(Box::new(order_by_expr))
1110+
} else {
1111+
None
1112+
};
1113+
1114+
Ok(Expr::ArrayAgg(ArrayAgg {
1115+
distinct,
1116+
expr,
1117+
order_by: within_group,
1118+
limit: None,
1119+
within_group: true,
1120+
}))
1121+
}
1122+
10741123
// This function parses date/time fields for the EXTRACT function-like
10751124
// operator, interval qualifiers, and the ceil/floor operations.
10761125
// EXTRACT supports a wider set of date/time fields than interval qualifiers,

tests/sqlparser_bigquery.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,17 @@ fn parse_similar_to() {
224224
chk(true);
225225
}
226226

227+
#[test]
228+
fn parse_array_agg_func() {
229+
for sql in [
230+
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
231+
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
232+
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
233+
] {
234+
bigquery().verified_stmt(sql);
235+
}
236+
}
237+
227238
fn bigquery() -> TestedDialects {
228239
TestedDialects {
229240
dialects: vec![Box::new(BigQueryDialect {})],

tests/sqlparser_common.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,6 +1777,27 @@ fn parse_listagg() {
17771777
);
17781778
}
17791779

1780+
#[test]
1781+
fn parse_array_agg_func() {
1782+
let supported_dialects = TestedDialects {
1783+
dialects: vec![
1784+
Box::new(GenericDialect {}),
1785+
Box::new(PostgreSqlDialect {}),
1786+
Box::new(MsSqlDialect {}),
1787+
Box::new(AnsiDialect {}),
1788+
Box::new(HiveDialect {}),
1789+
],
1790+
};
1791+
1792+
for sql in [
1793+
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
1794+
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
1795+
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
1796+
] {
1797+
supported_dialects.verified_stmt(sql);
1798+
}
1799+
}
1800+
17801801
#[test]
17811802
fn parse_create_table() {
17821803
let sql = "CREATE TABLE uk_cities (\

tests/sqlparser_hive.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,17 +281,17 @@ fn parse_create_function() {
281281
#[test]
282282
fn filtering_during_aggregation() {
283283
let rename = "SELECT \
284-
array_agg(name) FILTER (WHERE name IS NOT NULL), \
285-
array_agg(name) FILTER (WHERE name LIKE 'a%') \
284+
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \
285+
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \
286286
FROM region";
287287
println!("{}", hive().verified_stmt(rename));
288288
}
289289

290290
#[test]
291291
fn filtering_during_aggregation_aliased() {
292292
let rename = "SELECT \
293-
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
294-
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
293+
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
294+
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
295295
FROM region";
296296
println!("{}", hive().verified_stmt(rename));
297297
}

tests/sqlparser_snowflake.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,25 @@ fn parse_similar_to() {
334334
chk(true);
335335
}
336336

337+
#[test]
338+
fn test_array_agg_func() {
339+
for sql in [
340+
"SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
341+
"SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
342+
] {
343+
snowflake().verified_stmt(sql);
344+
}
345+
346+
let sql = "select array_agg(x order by x) as a from T";
347+
let result = snowflake().parse_sql_statements(sql);
348+
assert_eq!(
349+
result,
350+
Err(ParserError::ParserError(String::from(
351+
"Expected ), found: order"
352+
)))
353+
)
354+
}
355+
337356
fn snowflake() -> TestedDialects {
338357
TestedDialects {
339358
dialects: vec![Box::new(SnowflakeDialect {})],

0 commit comments

Comments
 (0)