@@ -462,21 +462,21 @@ def test_types_unique() -> None:
462
462
463
463
464
464
def test_types_apply () -> None :
465
- df = pd .DataFrame (data = {"col1" : [2 , 1 ], "col2" : [3 , 4 ]})
465
+ df = pd .DataFrame (data = {"col1" : [1 , 2 ], "col2" : [3 , 4 ], "col3" : [5 , 6 ]})
466
+
467
+ def returns_scalar (x : pd .Series ) -> float :
468
+ return 2
466
469
467
470
def returns_series (x : pd .Series ) -> pd .Series :
468
471
return x ** 2
469
472
470
- check (assert_type (df .apply (returns_series ), pd .DataFrame ), pd .DataFrame )
473
+ def returns_listlike_of_2 (x : pd .Series ) -> tuple [int , int ]:
474
+ return (7 , 8 )
471
475
472
- def returns_scalar (x : pd .Series ) -> float :
473
- return 2
476
+ def returns_listlike_of_3 (x : pd .Series ) -> tuple [ int , int , int ] :
477
+ return ( 7 , 8 , 9 )
474
478
475
- check (assert_type (df .apply (returns_scalar ), pd .Series ), pd .Series )
476
- check (
477
- assert_type (df .apply (returns_scalar , result_type = "broadcast" ), pd .DataFrame ),
478
- pd .DataFrame ,
479
- )
479
+ # Misc checks
480
480
check (assert_type (df .apply (np .exp ), pd .DataFrame ), pd .DataFrame )
481
481
check (assert_type (df .apply (str ), pd .Series ), pd .Series )
482
482
@@ -486,12 +486,154 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
486
486
487
487
check (assert_type (df .apply (gethead , args = (4 ,)), pd .DataFrame ), pd .DataFrame )
488
488
489
- def returns_tuple (x : pd .Series ) -> tuple [str , str ]:
490
- return ("a" , "b" )
489
+ # Check scalar/series/list-like for default (None), and result type of expand, reduce, broadcast
490
+ check (
491
+ assert_type (df .apply (returns_scalar ), pd .Series ),
492
+ pd .Series ,
493
+ )
494
+ check (
495
+ assert_type (df .apply (returns_series ), pd .DataFrame ),
496
+ pd .DataFrame ,
497
+ )
498
+ check (
499
+ assert_type (df .apply (returns_listlike_of_3 ), pd .DataFrame ),
500
+ pd .DataFrame ,
501
+ )
502
+
503
+ # While this call works in reality, it errors in the type checker, because this should never be called
504
+ # It does not make sense to pass a result_type of "expand" to a scalar return
505
+ # check(
506
+ # assert_type(
507
+ # df.apply(returns_scalar, result_type="expand"), pd.DataFrame
508
+ # ),
509
+ # pd.DataFrame,
510
+ # )
511
+ check (
512
+ assert_type (df .apply (returns_series , result_type = "expand" ), pd .DataFrame ),
513
+ pd .DataFrame ,
514
+ )
515
+ check (
516
+ assert_type (
517
+ df .apply (returns_listlike_of_3 , result_type = "expand" ), pd .DataFrame
518
+ ),
519
+ pd .DataFrame ,
520
+ )
521
+
522
+ # While this call works in reality, it errors in the type checker, because this should never be called
523
+ # It does not make sense to pass a result_type of "reduce" to a scalar or series return
524
+ # check(
525
+ # assert_type(
526
+ # df.apply(returns_scalar, result_type="reduce"), pd.DataFrame
527
+ # ),
528
+ # pd.DataFrame,
529
+ # )
530
+ # check(
531
+ # assert_type(
532
+ # df.apply(returns_series, result_type="reduce"), pd.Series
533
+ # ),
534
+ # pd.Series,
535
+ # )
536
+ check (
537
+ assert_type (df .apply (returns_listlike_of_3 , result_type = "reduce" ), pd .Series ),
538
+ pd .Series ,
539
+ )
540
+
541
+ # While this call works in reality, it errors in the type checker, because this should never be called
542
+ # It does not make sense to pass a result_type of "broadcast" to a scalar return
543
+ # check(
544
+ # assert_type(
545
+ # df.apply(returns_scalar, result_type="broadcast"), pd.DataFrame
546
+ # ),
547
+ # pd.DataFrame,
548
+ # )
549
+ check (
550
+ assert_type (df .apply (returns_series , result_type = "broadcast" ), pd .DataFrame ),
551
+ pd .DataFrame ,
552
+ )
553
+ check (
554
+ assert_type (
555
+ # Can only broadcast a list-like of 2 elements, not 3, because there are 2 rows
556
+ df .apply (returns_listlike_of_2 , result_type = "broadcast" ),
557
+ pd .DataFrame ,
558
+ ),
559
+ pd .DataFrame ,
560
+ )
561
+
562
+ # Check the same combinations with axis=1
563
+ check (
564
+ assert_type (df .apply (returns_scalar , axis = 1 ), pd .Series ),
565
+ pd .Series ,
566
+ )
567
+ check (
568
+ assert_type (df .apply (returns_series , axis = 1 ), pd .DataFrame ),
569
+ pd .DataFrame ,
570
+ )
571
+ check (
572
+ assert_type (df .apply (returns_listlike_of_3 , axis = 1 ), pd .DataFrame ),
573
+ pd .DataFrame ,
574
+ )
491
575
576
+ # While this call works in reality, it errors in the type checker, because this should never be called
577
+ # It does not make sense to pass a result_type of "expand" to a scalar return
578
+ # check(
579
+ # assert_type(
580
+ # df.apply(returns_scalar, axis=1, result_type="expand"), pd.DataFrame
581
+ # ),
582
+ # pd.DataFrame,
583
+ # )
584
+ check (
585
+ assert_type (
586
+ df .apply (returns_series , axis = 1 , result_type = "expand" ), pd .DataFrame
587
+ ),
588
+ pd .DataFrame ,
589
+ )
492
590
check (
493
591
assert_type (
494
- df .apply (returns_tuple , axis = 1 , result_type = "expand" ), pd .DataFrame
592
+ df .apply (returns_listlike_of_3 , axis = 1 , result_type = "expand" ), pd .DataFrame
593
+ ),
594
+ pd .DataFrame ,
595
+ )
596
+
597
+ # While this call works in reality, it errors in the type checker, because this should never be called
598
+ # It does not make sense to pass a result_type of "reduce" to a scalar or series return
599
+ # check(
600
+ # assert_type(
601
+ # df.apply(returns_scalar, axis=1, result_type="reduce"), pd.DataFrame
602
+ # ),
603
+ # pd.DataFrame,
604
+ # )
605
+ # check(
606
+ # assert_type(
607
+ # df.apply(returns_series, axis=1, result_type="reduce"), pd.DataFrame
608
+ # ),
609
+ # pd.DataFrame,
610
+ # )
611
+ check (
612
+ assert_type (
613
+ df .apply (returns_listlike_of_3 , axis = 1 , result_type = "reduce" ), pd .Series
614
+ ),
615
+ pd .Series ,
616
+ )
617
+
618
+ # While this call works in reality, it errors in the type checker, because this should never be called
619
+ # It does not make sense to pass a result_type of "broadcast" to a scalar return
620
+ # check(
621
+ # assert_type(
622
+ # df.apply(returns_scalar, axis=1, result_type="broadcast"), pd.DataFrame
623
+ # ),
624
+ # pd.DataFrame,
625
+ # )
626
+ check (
627
+ assert_type (
628
+ df .apply (returns_series , axis = 1 , result_type = "broadcast" ), pd .DataFrame
629
+ ),
630
+ pd .DataFrame ,
631
+ )
632
+ check (
633
+ assert_type (
634
+ # Can only broadcast a list-like of 3 elements, not 2, as there are 3 columns
635
+ df .apply (returns_listlike_of_3 , axis = 1 , result_type = "broadcast" ),
636
+ pd .DataFrame ,
495
637
),
496
638
pd .DataFrame ,
497
639
)
0 commit comments