@@ -2447,6 +2447,92 @@ fn parse_select_group_by_all() {
2447
2447
) ;
2448
2448
}
2449
2449
2450
+ #[ test]
2451
+ fn parse_group_by_with_modifier ( ) {
2452
+ let clauses = [ "x" , "a, b" , "ALL" ] ;
2453
+ let modifiers = [
2454
+ "WITH ROLLUP" ,
2455
+ "WITH CUBE" ,
2456
+ "WITH TOTALS" ,
2457
+ "WITH ROLLUP WITH CUBE" ,
2458
+ ] ;
2459
+ let expected_modifiers = [
2460
+ vec ! [ GroupByWithModifier :: Rollup ] ,
2461
+ vec ! [ GroupByWithModifier :: Cube ] ,
2462
+ vec ! [ GroupByWithModifier :: Totals ] ,
2463
+ vec ! [ GroupByWithModifier :: Rollup , GroupByWithModifier :: Cube ] ,
2464
+ ] ;
2465
+ let dialects = all_dialects_where ( |d| d. supports_group_by_with_modifier ( ) ) ;
2466
+
2467
+ for clause in & clauses {
2468
+ for ( modifier, expected_modifier) in modifiers. iter ( ) . zip ( expected_modifiers. iter ( ) ) {
2469
+ let sql = format ! ( "SELECT * FROM t GROUP BY {clause} {modifier}" ) ;
2470
+ match dialects. verified_stmt ( & sql) {
2471
+ Statement :: Query ( query) => {
2472
+ let group_by = & query. body . as_select ( ) . unwrap ( ) . group_by ;
2473
+ if clause == & "ALL" {
2474
+ assert_eq ! ( group_by, & GroupByExpr :: All ( expected_modifier. to_vec( ) ) ) ;
2475
+ } else {
2476
+ assert_eq ! (
2477
+ group_by,
2478
+ & GroupByExpr :: Expressions (
2479
+ clause
2480
+ . split( ", " )
2481
+ . map( |c| Identifier ( Ident :: new( c) ) )
2482
+ . collect( ) ,
2483
+ expected_modifier. to_vec( )
2484
+ )
2485
+ ) ;
2486
+ }
2487
+ }
2488
+ _ => unreachable ! ( ) ,
2489
+ }
2490
+ }
2491
+ }
2492
+
2493
+ // invalid cases
2494
+ let invalid_cases = [
2495
+ "SELECT * FROM t GROUP BY x WITH" ,
2496
+ "SELECT * FROM t GROUP BY x WITH ROLLUP CUBE" ,
2497
+ "SELECT * FROM t GROUP BY x WITH WITH ROLLUP" ,
2498
+ "SELECT * FROM t GROUP BY WITH ROLLUP" ,
2499
+ ] ;
2500
+ for sql in invalid_cases {
2501
+ dialects
2502
+ . parse_sql_statements ( sql)
2503
+ . expect_err ( "Expected: one of ROLLUP or CUBE or TOTALS, found: WITH" ) ;
2504
+ }
2505
+ }
2506
+
2507
+ #[ test]
2508
+ fn parse_group_by_special_grouping_sets ( ) {
2509
+ let sql = "SELECT a, b, SUM(c) FROM tab1 GROUP BY a, b GROUPING SETS ((a, b), (a), (b), ())" ;
2510
+ match all_dialects ( ) . verified_stmt ( sql) {
2511
+ Statement :: Query ( query) => {
2512
+ let group_by = & query. body . as_select ( ) . unwrap ( ) . group_by ;
2513
+ assert_eq ! (
2514
+ group_by,
2515
+ & GroupByExpr :: Expressions (
2516
+ vec![
2517
+ Expr :: Identifier ( Ident :: new( "a" ) ) ,
2518
+ Expr :: Identifier ( Ident :: new( "b" ) )
2519
+ ] ,
2520
+ vec![ GroupByWithModifier :: GroupingSets ( Expr :: GroupingSets ( vec![
2521
+ vec![
2522
+ Expr :: Identifier ( Ident :: new( "a" ) ) ,
2523
+ Expr :: Identifier ( Ident :: new( "b" ) )
2524
+ ] ,
2525
+ vec![ Expr :: Identifier ( Ident :: new( "a" ) ) , ] ,
2526
+ vec![ Expr :: Identifier ( Ident :: new( "b" ) ) ] ,
2527
+ vec![ ]
2528
+ ] ) ) ]
2529
+ )
2530
+ ) ;
2531
+ }
2532
+ _ => unreachable ! ( ) ,
2533
+ }
2534
+ }
2535
+
2450
2536
#[ test]
2451
2537
fn parse_select_having ( ) {
2452
2538
let sql = "SELECT foo FROM bar GROUP BY foo HAVING COUNT(*) > 1" ;
0 commit comments