Skip to content

Commit 4fcaaeb

Browse files
SuperBomcheshkov
authored andcommitted
Parse ARRAY_AGG for Bigquery and Snowflake (apache#662)
Can drop this after rebase on commit 87b4a16 "Parse ARRAY_AGG for Bigquery and Snowflake (apache#662)", first released in 0.27.0
1 parent 05bf229 commit 4fcaaeb

File tree

6 files changed

+143
-2
lines changed

6 files changed

+143
-2
lines changed

src/ast/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,8 @@ pub enum Expr {
386386
Subquery(Box<Query>),
387387
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
388388
ListAgg(ListAgg),
389+
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
390+
ArrayAgg(ArrayAgg),
389391
/// The `GROUPING SETS` expr.
390392
GroupingSets(Vec<Vec<Expr>>),
391393
/// The `CUBE` expr.
@@ -580,6 +582,7 @@ impl fmt::Display for Expr {
580582
Expr::Subquery(s) => write!(f, "({})", s),
581583
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
582584
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
585+
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
583586
Expr::GroupingSets(sets) => {
584587
write!(f, "GROUPING SETS (")?;
585588
let mut sep = "";
@@ -2491,6 +2494,45 @@ impl fmt::Display for ListAggOnOverflow {
24912494
}
24922495
}
24932496

2497+
/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
2498+
/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
2499+
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
2500+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2501+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2502+
pub struct ArrayAgg {
2503+
pub distinct: bool,
2504+
pub expr: Box<Expr>,
2505+
pub order_by: Option<Box<OrderByExpr>>,
2506+
pub limit: Option<Box<Expr>>,
2507+
pub within_group: bool, // order by is used inside a within group or not
2508+
}
2509+
2510+
impl fmt::Display for ArrayAgg {
2511+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2512+
write!(
2513+
f,
2514+
"ARRAY_AGG({}{}",
2515+
if self.distinct { "DISTINCT " } else { "" },
2516+
self.expr
2517+
)?;
2518+
if !self.within_group {
2519+
if let Some(order_by) = &self.order_by {
2520+
write!(f, " ORDER BY {}", order_by)?;
2521+
}
2522+
if let Some(limit) = &self.limit {
2523+
write!(f, " LIMIT {}", limit)?;
2524+
}
2525+
}
2526+
write!(f, ")")?;
2527+
if self.within_group {
2528+
if let Some(order_by) = &self.order_by {
2529+
write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
2530+
}
2531+
}
2532+
Ok(())
2533+
}
2534+
}
2535+
24942536
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24952537
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
24962538
pub enum ObjectType {

src/dialect/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ pub trait Dialect: Debug + Any {
6363
fn is_identifier_start(&self, ch: char) -> bool;
6464
/// Determine if a character is a valid unquoted identifier character
6565
fn is_identifier_part(&self, ch: char) -> bool;
66+
/// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
67+
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
68+
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
69+
fn supports_within_after_array_aggregation(&self) -> bool {
70+
false
71+
}
6672
}
6773

6874
impl dyn Dialect {

src/dialect/snowflake.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
2828
|| ch == '$'
2929
|| ch == '_'
3030
}
31+
32+
fn supports_within_after_array_aggregation(&self) -> bool {
33+
true
34+
}
3135
}

src/parser.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ impl<'a> Parser<'a> {
464464
self.expect_token(&Token::LParen)?;
465465
self.parse_array_subquery()
466466
}
467+
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
467468
Keyword::NOT => Ok(Expr::UnaryOp {
468469
op: UnaryOperator::Not,
469470
expr: Box::new(self.parse_subexpr(Self::UNARY_NOT_PREC)?),
@@ -991,6 +992,54 @@ impl<'a> Parser<'a> {
991992
}))
992993
}
993994

995+
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
996+
self.expect_token(&Token::LParen)?;
997+
let distinct = self.parse_keyword(Keyword::DISTINCT);
998+
let expr = Box::new(self.parse_expr()?);
999+
// ANSI SQL and BigQuery define ORDER BY inside function.
1000+
if !self.dialect.supports_within_after_array_aggregation() {
1001+
let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
1002+
let order_by_expr = self.parse_order_by_expr()?;
1003+
Some(Box::new(order_by_expr))
1004+
} else {
1005+
None
1006+
};
1007+
let limit = if self.parse_keyword(Keyword::LIMIT) {
1008+
self.parse_limit()?.map(Box::new)
1009+
} else {
1010+
None
1011+
};
1012+
self.expect_token(&Token::RParen)?;
1013+
return Ok(Expr::ArrayAgg(ArrayAgg {
1014+
distinct,
1015+
expr,
1016+
order_by,
1017+
limit,
1018+
within_group: false,
1019+
}));
1020+
}
1021+
// Snowflake defines ORDERY BY in within group instead of inside the function like
1022+
// ANSI SQL.
1023+
self.expect_token(&Token::RParen)?;
1024+
let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
1025+
self.expect_token(&Token::LParen)?;
1026+
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
1027+
let order_by_expr = self.parse_order_by_expr()?;
1028+
self.expect_token(&Token::RParen)?;
1029+
Some(Box::new(order_by_expr))
1030+
} else {
1031+
None
1032+
};
1033+
1034+
Ok(Expr::ArrayAgg(ArrayAgg {
1035+
distinct,
1036+
expr,
1037+
order_by: within_group,
1038+
limit: None,
1039+
within_group: true,
1040+
}))
1041+
}
1042+
9941043
// This function parses date/time fields for both the EXTRACT function-like
9951044
// operator and interval qualifiers. EXTRACT supports a wider set of
9961045
// date/time fields than interval qualifiers, so this function may need to

tests/sqlparser_common.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ mod test_utils;
2424
use matches::assert_matches;
2525
use sqlparser::ast::*;
2626
use sqlparser::dialect::{
27-
AnsiDialect, ClickHouseDialect, GenericDialect, MsSqlDialect, PostgreSqlDialect, SQLiteDialect,
28-
SnowflakeDialect,
27+
AnsiDialect, ClickHouseDialect, GenericDialect, HiveDialect, MsSqlDialect, PostgreSqlDialect,
28+
SQLiteDialect, SnowflakeDialect,
2929
};
3030
use sqlparser::keywords::ALL_KEYWORDS;
3131
use sqlparser::parser::{Parser, ParserError};
@@ -1695,6 +1695,27 @@ fn parse_listagg() {
16951695
);
16961696
}
16971697

1698+
#[test]
1699+
fn parse_array_agg_func() {
1700+
let supported_dialects = TestedDialects {
1701+
dialects: vec![
1702+
Box::new(GenericDialect {}),
1703+
Box::new(PostgreSqlDialect {}),
1704+
Box::new(MsSqlDialect {}),
1705+
Box::new(AnsiDialect {}),
1706+
Box::new(HiveDialect {}),
1707+
],
1708+
};
1709+
1710+
for sql in [
1711+
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
1712+
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
1713+
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
1714+
] {
1715+
supported_dialects.verified_stmt(sql);
1716+
}
1717+
}
1718+
16981719
#[test]
16991720
fn parse_create_table() {
17001721
let sql = "CREATE TABLE uk_cities (\

tests/sqlparser_snowflake.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,25 @@ fn test_single_table_in_parenthesis_with_alias() {
144144
);
145145
}
146146

147+
#[test]
148+
fn test_array_agg_func() {
149+
for sql in [
150+
"SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
151+
"SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
152+
] {
153+
snowflake().verified_stmt(sql);
154+
}
155+
156+
let sql = "select array_agg(x order by x) as a from T";
157+
let result = snowflake().parse_sql_statements(sql);
158+
assert_eq!(
159+
result,
160+
Err(ParserError::ParserError(String::from(
161+
"Expected ), found: order"
162+
)))
163+
)
164+
}
165+
147166
fn snowflake() -> TestedDialects {
148167
TestedDialects {
149168
dialects: vec![Box::new(SnowflakeDialect {})],

0 commit comments

Comments
 (0)