Skip to content

Commit 64efa11

Browse files
committed
Parse ARRAY_AGG for Bigquery and Snowflake
1 parent 1b3778e commit 64efa11

File tree

8 files changed

+154
-4
lines changed

8 files changed

+154
-4
lines changed

src/ast/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,8 @@ pub enum Expr {
411411
ArraySubquery(Box<Query>),
412412
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
413413
ListAgg(ListAgg),
414+
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
415+
ArrayAgg(ArrayAgg),
414416
/// The `GROUPING SETS` expr.
415417
GroupingSets(Vec<Vec<Expr>>),
416418
/// The `CUBE` expr.
@@ -650,6 +652,7 @@ impl fmt::Display for Expr {
650652
Expr::Subquery(s) => write!(f, "({})", s),
651653
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
652654
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
655+
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
653656
Expr::GroupingSets(sets) => {
654657
write!(f, "GROUPING SETS (")?;
655658
let mut sep = "";
@@ -2927,6 +2930,45 @@ impl fmt::Display for ListAggOnOverflow {
29272930
}
29282931
}
29292932

2933+
/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
2934+
/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
2935+
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
2936+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2937+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2938+
pub struct ArrayAgg {
2939+
pub distinct: bool,
2940+
pub expr: Box<Expr>,
2941+
pub order_by: Option<Box<OrderByExpr>>,
2942+
pub limit: Option<Box<Expr>>,
2943+
pub within_group: bool, // order by is used inside a within group or not
2944+
}
2945+
2946+
impl fmt::Display for ArrayAgg {
2947+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2948+
write!(
2949+
f,
2950+
"ARRAY_AGG({}{}",
2951+
if self.distinct { "DISTINCT " } else { "" },
2952+
self.expr
2953+
)?;
2954+
if !self.within_group {
2955+
if let Some(order_by) = &self.order_by {
2956+
write!(f, " ORDER BY {}", order_by)?;
2957+
}
2958+
if let Some(limit) = &self.limit {
2959+
write!(f, " LIMIT {}", limit)?;
2960+
}
2961+
}
2962+
write!(f, ")")?;
2963+
if self.within_group {
2964+
if let Some(order_by) = &self.order_by {
2965+
write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
2966+
}
2967+
}
2968+
Ok(())
2969+
}
2970+
}
2971+
29302972
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
29312973
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
29322974
pub enum ObjectType {

src/dialect/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ pub trait Dialect: Debug + Any {
7171
fn supports_filter_during_aggregation(&self) -> bool {
7272
false
7373
}
74+
/// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
75+
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
76+
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
77+
fn supports_within_after_array_aggregation(&self) -> bool {
78+
false
79+
}
7480
/// Dialect-specific prefix parser override
7581
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
7682
// return None to fall back to the default behavior

src/dialect/snowflake.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
2828
|| ch == '$'
2929
|| ch == '_'
3030
}
31+
32+
/// Snowflake opts in: its `ARRAY_AGG` ordering is written in a trailing
/// `WITHIN GROUP (ORDER BY ...)` clause rather than inside the parentheses.
fn supports_within_after_array_aggregation(&self) -> bool {
33+
true
34+
}
3135
}

src/parser.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ impl<'a> Parser<'a> {
473473
self.expect_token(&Token::LParen)?;
474474
self.parse_array_subquery()
475475
}
476+
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
476477
Keyword::NOT => self.parse_not(),
477478
// Here `w` is a word, check if it's a part of a multi-part
478479
// identifier, a function call, or a simple identifier:
@@ -1071,6 +1072,57 @@ impl<'a> Parser<'a> {
10711072
}))
10721073
}
10731074

1075+
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
1076+
self.expect_token(&Token::LParen)?;
1077+
let distinct = self.parse_keyword(Keyword::DISTINCT);
1078+
let expr = Box::new(self.parse_expr()?);
1079+
// ANSI SQL and BigQuery define ORDER BY inside function.
1080+
if !self.dialect.supports_within_after_array_aggregation() {
1081+
let order_by = if self.parse_keywords(&[Keyword:: ORDER, Keyword::BY]) {
1082+
let order_by_expr = self.parse_order_by_expr()?;
1083+
Some(Box::new(order_by_expr))
1084+
} else {
1085+
None
1086+
};
1087+
let limit = if self.parse_keyword(Keyword::LIMIT) {
1088+
match self.parse_limit()? {
1089+
Some(expr) => Some(Box::new(expr)),
1090+
None => None
1091+
}
1092+
} else {
1093+
None
1094+
};
1095+
self.expect_token(&Token::RParen)?;
1096+
return Ok(Expr::ArrayAgg(ArrayAgg {
1097+
distinct,
1098+
expr,
1099+
order_by,
1100+
limit,
1101+
within_group: false,
1102+
}));
1103+
}
1104+
// Snowflake defines ORDERY BY in within group instead of inside the function like
1105+
// ANSI SQL.
1106+
self.expect_token(&Token::RParen)?;
1107+
let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
1108+
self.expect_token(&Token::LParen)?;
1109+
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
1110+
let order_by_expr = self.parse_order_by_expr()?;
1111+
self.expect_token(&Token::RParen)?;
1112+
Some(Box::new(order_by_expr))
1113+
} else {
1114+
None
1115+
};
1116+
1117+
Ok(Expr::ArrayAgg(ArrayAgg {
1118+
distinct,
1119+
expr,
1120+
order_by: within_group,
1121+
limit: None,
1122+
within_group: true,
1123+
}))
1124+
}
1125+
10741126
// This function parses date/time fields for the EXTRACT function-like
10751127
// operator, interval qualifiers, and the ceil/floor operations.
10761128
// EXTRACT supports a wider set of date/time fields than interval qualifiers,

tests/sqlparser_bigquery.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,17 @@ fn parse_cast_type() {
115115
bigquery().verified_only_select(sql);
116116
}
117117

118+
/// BigQuery writes `ORDER BY` / `LIMIT` inside the `ARRAY_AGG` parentheses;
/// each statement must survive a parse/print round trip unchanged.
#[test]
fn parse_array_agg_func() {
    let statements = [
        "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
        "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
        "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
    ];
    for statement in statements {
        bigquery().verified_stmt(statement);
    }
}
128+
118129
fn bigquery() -> TestedDialects {
119130
TestedDialects {
120131
dialects: vec![Box::new(BigQueryDialect {})],

tests/sqlparser_common.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1886,6 +1886,27 @@ fn parse_listagg() {
18861886
);
18871887
}
18881888

1889+
/// Checks the inline `ORDER BY` / `LIMIT` form of `ARRAY_AGG` against every
/// dialect listed below.
#[test]
fn parse_array_agg_func() {
    let statements = [
        "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
        "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
        "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
    ];

    let inline_order_by_dialects = TestedDialects {
        dialects: vec![
            Box::new(GenericDialect {}),
            Box::new(PostgreSqlDialect {}),
            Box::new(MsSqlDialect {}),
            Box::new(AnsiDialect {}),
            Box::new(HiveDialect {}),
        ],
    };

    for statement in statements {
        inline_order_by_dialects.verified_stmt(statement);
    }
}
1909+
18891910
#[test]
18901911
fn parse_create_table() {
18911912
let sql = "CREATE TABLE uk_cities (\

tests/sqlparser_hive.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,17 +278,17 @@ fn parse_create_function() {
278278
#[test]
279279
fn filtering_during_aggregation() {
280280
let rename = "SELECT \
281-
array_agg(name) FILTER (WHERE name IS NOT NULL), \
282-
array_agg(name) FILTER (WHERE name LIKE 'a%') \
281+
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \
282+
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \
283283
FROM region";
284284
println!("{}", hive().verified_stmt(rename));
285285
}
286286

287287
#[test]
288288
fn filtering_during_aggregation_aliased() {
289289
let rename = "SELECT \
290-
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
291-
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
290+
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
291+
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
292292
FROM region";
293293
println!("{}", hive().verified_stmt(rename));
294294
}

tests/sqlparser_snowflake.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,20 @@ fn test_single_table_in_parenthesis_with_alias() {
143143
);
144144
}
145145

146+
/// Snowflake accepts `ARRAY_AGG(...) WITHIN GROUP (ORDER BY ...)` and rejects
/// an `ORDER BY` placed inside the parentheses.
#[test]
fn test_array_agg_func() {
    for sql in [
        "SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
        "SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
    ] {
        snowflake().verified_stmt(sql);
    }

    // The inline form must fail: the parser expects `)` right after the argument.
    let sql = "select array_agg(x order by x) as a from T";
    let result = snowflake().parse_sql_statements(sql);
    assert_eq!(
        result,
        Err(ParserError::ParserError(String::from(
            "Expected ), found: order"
        )))
    );
}
159+
146160
fn snowflake() -> TestedDialects {
147161
TestedDialects {
148162
dialects: vec![Box::new(SnowflakeDialect {})],

0 commit comments

Comments
 (0)