@@ -2448,6 +2448,92 @@ fn parse_select_group_by_all() {
2448
2448
) ;
2449
2449
}
2450
2450
2451
+ #[ test]
2452
+ fn parse_group_by_with_modifier ( ) {
2453
+ let clauses = [ "x" , "a, b" , "ALL" ] ;
2454
+ let modifiers = [
2455
+ "WITH ROLLUP" ,
2456
+ "WITH CUBE" ,
2457
+ "WITH TOTALS" ,
2458
+ "WITH ROLLUP WITH CUBE" ,
2459
+ ] ;
2460
+ let expected_modifiers = [
2461
+ vec ! [ GroupByWithModifier :: Rollup ] ,
2462
+ vec ! [ GroupByWithModifier :: Cube ] ,
2463
+ vec ! [ GroupByWithModifier :: Totals ] ,
2464
+ vec ! [ GroupByWithModifier :: Rollup , GroupByWithModifier :: Cube ] ,
2465
+ ] ;
2466
+ let dialects = all_dialects_where ( |d| d. supports_group_by_with_modifier ( ) ) ;
2467
+
2468
+ for clause in & clauses {
2469
+ for ( modifier, expected_modifier) in modifiers. iter ( ) . zip ( expected_modifiers. iter ( ) ) {
2470
+ let sql = format ! ( "SELECT * FROM t GROUP BY {clause} {modifier}" ) ;
2471
+ match dialects. verified_stmt ( & sql) {
2472
+ Statement :: Query ( query) => {
2473
+ let group_by = & query. body . as_select ( ) . unwrap ( ) . group_by ;
2474
+ if clause == & "ALL" {
2475
+ assert_eq ! ( group_by, & GroupByExpr :: All ( expected_modifier. to_vec( ) ) ) ;
2476
+ } else {
2477
+ assert_eq ! (
2478
+ group_by,
2479
+ & GroupByExpr :: Expressions (
2480
+ clause
2481
+ . split( ", " )
2482
+ . map( |c| Identifier ( Ident :: new( c) ) )
2483
+ . collect( ) ,
2484
+ expected_modifier. to_vec( )
2485
+ )
2486
+ ) ;
2487
+ }
2488
+ }
2489
+ _ => unreachable ! ( ) ,
2490
+ }
2491
+ }
2492
+ }
2493
+
2494
+ // invalid cases
2495
+ let invalid_cases = [
2496
+ "SELECT * FROM t GROUP BY x WITH" ,
2497
+ "SELECT * FROM t GROUP BY x WITH ROLLUP CUBE" ,
2498
+ "SELECT * FROM t GROUP BY x WITH WITH ROLLUP" ,
2499
+ "SELECT * FROM t GROUP BY WITH ROLLUP" ,
2500
+ ] ;
2501
+ for sql in invalid_cases {
2502
+ dialects
2503
+ . parse_sql_statements ( sql)
2504
+ . expect_err ( "Expected: one of ROLLUP or CUBE or TOTALS, found: WITH" ) ;
2505
+ }
2506
+ }
2507
+
2508
+ #[ test]
2509
+ fn parse_group_by_special_grouping_sets ( ) {
2510
+ let sql = "SELECT a, b, SUM(c) FROM tab1 GROUP BY a, b GROUPING SETS ((a, b), (a), (b), ())" ;
2511
+ match all_dialects ( ) . verified_stmt ( sql) {
2512
+ Statement :: Query ( query) => {
2513
+ let group_by = & query. body . as_select ( ) . unwrap ( ) . group_by ;
2514
+ assert_eq ! (
2515
+ group_by,
2516
+ & GroupByExpr :: Expressions (
2517
+ vec![
2518
+ Expr :: Identifier ( Ident :: new( "a" ) ) ,
2519
+ Expr :: Identifier ( Ident :: new( "b" ) )
2520
+ ] ,
2521
+ vec![ GroupByWithModifier :: GroupingSets ( Expr :: GroupingSets ( vec![
2522
+ vec![
2523
+ Expr :: Identifier ( Ident :: new( "a" ) ) ,
2524
+ Expr :: Identifier ( Ident :: new( "b" ) )
2525
+ ] ,
2526
+ vec![ Expr :: Identifier ( Ident :: new( "a" ) ) , ] ,
2527
+ vec![ Expr :: Identifier ( Ident :: new( "b" ) ) ] ,
2528
+ vec![ ]
2529
+ ] ) ) ]
2530
+ )
2531
+ ) ;
2532
+ }
2533
+ _ => unreachable ! ( ) ,
2534
+ }
2535
+ }
2536
+
2451
2537
#[ test]
2452
2538
fn parse_select_having ( ) {
2453
2539
let sql = "SELECT foo FROM bar GROUP BY foo HAVING COUNT(*) > 1" ;
0 commit comments