File tree 2 files changed +38
-7
lines changed
2 files changed +38
-7
lines changed Original file line number Diff line number Diff line change @@ -1088,14 +1088,32 @@ class DataFrame(NDFrame, OpsMixin):
1088
1088
** kwargs ,
1089
1089
) -> DataFrame : ...
1090
1090
@overload
1091
- def apply (self , f : Callable ) -> Series : ...
1091
+ def apply (
1092
+ self ,
1093
+ f : Callable [..., Series ],
1094
+ axis : AxisType = ...,
1095
+ raw : _bool = ...,
1096
+ result_type : Literal ["expand" , "reduce" , "broadcast" ] | None = ...,
1097
+ args = ...,
1098
+ ** kwargs ,
1099
+ ) -> DataFrame : ...
1100
+ @overload
1101
+ def apply (
1102
+ self ,
1103
+ f : Callable [..., Scalar ],
1104
+ axis : AxisType = ...,
1105
+ raw : _bool = ...,
1106
+ result_type : Literal ["expand" , "reduce" ] | None = ...,
1107
+ args = ...,
1108
+ ** kwargs ,
1109
+ ) -> Series : ...
1092
1110
@overload
1093
1111
def apply (
1094
1112
self ,
1095
- f : Callable ,
1096
- axis : AxisType ,
1113
+ f : Callable [..., Scalar ],
1114
+ result_type : Literal ["broadcast" ],
1115
+ axis : AxisType = ...,
1097
1116
raw : _bool = ...,
1098
- result_type : _str | None = ...,
1099
1117
args = ...,
1100
1118
** kwargs ,
1101
1119
) -> DataFrame : ...
Original file line number Diff line number Diff line change @@ -460,9 +460,22 @@ def test_types_unique() -> None:
460
460
461
461
def test_types_apply () -> None :
462
462
df = pd .DataFrame (data = {"col1" : [2 , 1 ], "col2" : [3 , 4 ]})
463
- df .apply (lambda x : x ** 2 )
464
- df .apply (np .exp )
465
- df .apply (str )
463
+
464
+ def returns_series (x : pd .Series ) -> pd .Series :
465
+ return x ** 2
466
+
467
+ check (assert_type (df .apply (returns_series ), pd .DataFrame ), pd .DataFrame )
468
+
469
+ def returns_scalar (x : pd .Series ) -> float :
470
+ return 2
471
+
472
+ check (assert_type (df .apply (returns_scalar ), pd .Series ), pd .Series )
473
+ check (
474
+ assert_type (df .apply (returns_scalar , result_type = "broadcast" ), pd .DataFrame ),
475
+ pd .DataFrame ,
476
+ )
477
+ check (assert_type (df .apply (np .exp ), pd .DataFrame ), pd .DataFrame )
478
+ check (assert_type (df .apply (str ), pd .Series ), pd .Series )
466
479
467
480
468
481
def test_types_applymap () -> None :
You can’t perform that action at this time.
0 commit comments