Skip to content

Commit 603e614

Browse files
Fix value_counts and apply. (#401)
* Series.value_counts returns Series[int].
* Series.apply callable result might not be hashable. It's possible for the callable in Series.apply to return something non-hashable like a list, but the result of apply should still be a Series.
* More detailed typing for DataFrame.apply. Whether it returns a Series or a DataFrame depends on the return type of the callable. In the case of the callable returning a scalar, the result is a Series unless the result_type is "broadcast".
* Add test for #393.
1 parent 581c1eb commit 603e614

File tree

4 files changed

+53
-11
lines changed

4 files changed

+53
-11
lines changed

pandas-stubs/core/frame.pyi

+22-4
Original file line number | Diff line number | Diff line change
@@ -1088,14 +1088,32 @@ class DataFrame(NDFrame, OpsMixin):
10881088
**kwargs,
10891089
) -> DataFrame: ...
10901090
@overload
1091-
def apply(self, f: Callable) -> Series: ...
1091+
def apply(
1092+
self,
1093+
f: Callable[..., Series],
1094+
axis: AxisType = ...,
1095+
raw: _bool = ...,
1096+
result_type: Literal["expand", "reduce", "broadcast"] | None = ...,
1097+
args=...,
1098+
**kwargs,
1099+
) -> DataFrame: ...
1100+
@overload
1101+
def apply(
1102+
self,
1103+
f: Callable[..., Scalar],
1104+
axis: AxisType = ...,
1105+
raw: _bool = ...,
1106+
result_type: Literal["expand", "reduce"] | None = ...,
1107+
args=...,
1108+
**kwargs,
1109+
) -> Series: ...
10921110
@overload
10931111
def apply(
10941112
self,
1095-
f: Callable,
1096-
axis: AxisType,
1113+
f: Callable[..., Scalar],
1114+
result_type: Literal["broadcast"],
1115+
axis: AxisType = ...,
10971116
raw: _bool = ...,
1098-
result_type: _str | None = ...,
10991117
args=...,
11001118
**kwargs,
11011119
) -> DataFrame: ...

pandas-stubs/core/series.pyi

+2-2
Original file line number | Diff line number | Diff line change
@@ -685,7 +685,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
685685
@overload
686686
def apply(
687687
self,
688-
func: Callable[..., Hashable],
688+
func: Callable[..., Scalar | Sequence | Mapping],
689689
convertDType: _bool = ...,
690690
args: tuple = ...,
691691
**kwds,
@@ -1193,7 +1193,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
11931193
ascending: _bool = ...,
11941194
bins: int | None = ...,
11951195
dropna: _bool = ...,
1196-
) -> Series[S1]: ...
1196+
) -> Series[int]: ...
11971197
def transpose(self, *args, **kwargs) -> Series[S1]: ...
11981198
@property
11991199
def T(self) -> Series[S1]: ...

tests/test_frame.py

+22-3
Original file line number | Diff line number | Diff line change
@@ -460,9 +460,28 @@ def test_types_unique() -> None:
460460

461461
def test_types_apply() -> None:
462462
df = pd.DataFrame(data={"col1": [2, 1], "col2": [3, 4]})
463-
df.apply(lambda x: x**2)
464-
df.apply(np.exp)
465-
df.apply(str)
463+
464+
def returns_series(x: pd.Series) -> pd.Series:
465+
return x**2
466+
467+
check(assert_type(df.apply(returns_series), pd.DataFrame), pd.DataFrame)
468+
469+
def returns_scalar(x: pd.Series) -> float:
470+
return 2
471+
472+
check(assert_type(df.apply(returns_scalar), pd.Series), pd.Series)
473+
check(
474+
assert_type(df.apply(returns_scalar, result_type="broadcast"), pd.DataFrame),
475+
pd.DataFrame,
476+
)
477+
check(assert_type(df.apply(np.exp), pd.DataFrame), pd.DataFrame)
478+
check(assert_type(df.apply(str), pd.Series), pd.Series)
479+
480+
# GH 393
481+
def gethead(s: pd.Series, y: int) -> pd.Series:
482+
return s.head(y)
483+
484+
check(assert_type(df.apply(gethead, args=(4,)), pd.DataFrame), pd.DataFrame)
466485

467486

468487
def test_types_applymap() -> None:

tests/test_series.py

+7-2
Original file line number | Diff line number | Diff line change
@@ -367,8 +367,8 @@ def test_types_idxmax() -> None:
367367

368368

369369
def test_types_value_counts() -> None:
370-
s = pd.Series([1, 2])
371-
s.value_counts()
370+
s = pd.Series(["a", "b"])
371+
check(assert_type(s.value_counts(), "pd.Series[int]"), pd.Series, int)
372372

373373

374374
def test_types_unique() -> None:
@@ -398,6 +398,11 @@ def retseries(x: float) -> float:
398398

399399
check(assert_type(s.apply(retseries).tolist(), list), list)
400400

401+
def retlist(x: float) -> list:
402+
return [x]
403+
404+
check(assert_type(s.apply(retlist), pd.Series), pd.Series, list)
405+
401406
def get_depth(url: str) -> int:
402407
return len(url)
403408

0 commit comments

Comments
 (0)