Skip to content

Commit 117e97a

Browse files
authored
Type args and kwargs in pipe (#823)
* Test dataframe pipe typing
* Test series pipe typing
* Remove pipe annotations from DataFrame
* Type args and kwargs parameters in generic pipe
1 parent a370cab commit 117e97a

File tree

5 files changed, with 239 additions and 15 deletions.

pandas-stubs/_typing.pyi

+5-1
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,10 @@ from pandas.core.generic import NDFrame
2828
from pandas.core.groupby.grouper import Grouper
2929
from pandas.core.indexes.base import Index
3030
from pandas.core.series import Series
31-
from typing_extensions import TypeAlias
31+
from typing_extensions import (
32+
ParamSpec,
33+
TypeAlias,
34+
)
3235

3336
from pandas._libs.interval import Interval
3437
from pandas._libs.tslibs import (
@@ -447,6 +450,7 @@ JSONSerializable: TypeAlias = PythonScalar | list | dict
447450
Axes: TypeAlias = AnyArrayLike | list | dict | range | tuple
448451
Renamer: TypeAlias = Mapping[Any, Label] | Callable[[Any], Label]
449452
T = TypeVar("T")
453+
P = ParamSpec("P")
450454
FuncType: TypeAlias = Callable[..., Any]
451455
F = TypeVar("F", bound=FuncType)
452456
HashableT = TypeVar("HashableT", bound=Hashable)

pandas-stubs/core/frame.pyi

-7
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,6 @@ from pandas._typing import (
113113
StorageOptions,
114114
StrLike,
115115
Suffixes,
116-
T as TType,
117116
TimestampConvention,
118117
ValidationOptions,
119118
WriteBuffer,
@@ -1829,12 +1828,6 @@ class DataFrame(NDFrame, OpsMixin):
18291828
freq=...,
18301829
**kwargs,
18311830
) -> DataFrame: ...
1832-
def pipe(
1833-
self,
1834-
func: Callable[..., TType] | tuple[Callable[..., TType], _str],
1835-
*args,
1836-
**kwargs,
1837-
) -> TType: ...
18381831
def pop(self, item: _str) -> Series: ...
18391832
def pow(
18401833
self,

pandas-stubs/core/generic.pyi

+17-2
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,10 @@ from pandas import Index
1919
import pandas.core.indexing as indexing
2020
from pandas.core.series import Series
2121
import sqlalchemy.engine
22-
from typing_extensions import Self
22+
from typing_extensions import (
23+
Concatenate,
24+
Self,
25+
)
2326

2427
from pandas._typing import (
2528
S1,
@@ -40,6 +43,7 @@ from pandas._typing import (
4043
IgnoreRaise,
4144
IndexLabel,
4245
Level,
46+
P,
4347
ReplaceMethod,
4448
SortKind,
4549
StorageOptions,
@@ -352,8 +356,19 @@ class NDFrame(indexing.IndexingMixin):
352356
) -> Self: ...
353357
def head(self, n: int = ...) -> Self: ...
354358
def tail(self, n: int = ...) -> Self: ...
359+
@overload
360+
def pipe(
361+
self,
362+
func: Callable[Concatenate[Self, P], T],
363+
*args: P.args,
364+
**kwargs: P.kwargs,
365+
) -> T: ...
366+
@overload
355367
def pipe(
356-
self, func: Callable[..., T] | tuple[Callable[..., T], str], *args, **kwargs
368+
self,
369+
func: tuple[Callable[..., T], str],
370+
*args: Any,
371+
**kwargs: Any,
357372
) -> T: ...
358373
def __finalize__(self, other, method=..., **kwargs) -> Self: ...
359374
def __setattr__(self, name: _str, value) -> None: ...

tests/test_frame.py

+107-5
Original file line number | Diff line number | Diff line change
@@ -1436,21 +1436,123 @@ def foo(df: pd.DataFrame) -> pd.DataFrame:
14361436
.pipe(foo)
14371437
)
14381438

1439+
df = pd.DataFrame({"a": [1], "b": [2]})
14391440
check(assert_type(val, pd.DataFrame), pd.DataFrame)
14401441

1441-
check(assert_type(pd.DataFrame({"a": [1]}).pipe(foo), pd.DataFrame), pd.DataFrame)
1442+
check(assert_type(df.pipe(foo), pd.DataFrame), pd.DataFrame)
14421443

14431444
def bar(val: Styler) -> Styler:
14441445
return val
14451446

1446-
check(
1447-
assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(bar), Styler), Styler
1448-
)
1447+
check(assert_type(df.style.pipe(bar), Styler), Styler)
14491448

14501449
def baz(val: Styler) -> str:
14511450
return val.to_latex()
14521451

1453-
check(assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(baz), str), str)
1452+
check(assert_type(df.style.pipe(baz), str), str)
1453+
1454+
def qux(
1455+
df: pd.DataFrame,
1456+
positional_only: int,
1457+
/,
1458+
argument_1: list[float],
1459+
argument_2: str,
1460+
*,
1461+
keyword_only: tuple[int, int],
1462+
) -> pd.DataFrame:
1463+
return pd.DataFrame(df)
1464+
1465+
check(
1466+
assert_type(
1467+
df.pipe(qux, 1, [1.0, 2.0], argument_2="hi", keyword_only=(1, 2)),
1468+
pd.DataFrame,
1469+
),
1470+
pd.DataFrame,
1471+
)
1472+
1473+
if TYPE_CHECKING_INVALID_USAGE:
1474+
df.pipe(
1475+
qux,
1476+
"a", # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
1477+
[1.0, 2.0],
1478+
argument_2="hi",
1479+
keyword_only=(1, 2),
1480+
)
1481+
df.pipe(
1482+
qux,
1483+
1,
1484+
[1.0, "b"], # type: ignore[list-item] # pyright: ignore[reportGeneralTypeIssues]
1485+
argument_2="hi",
1486+
keyword_only=(1, 2),
1487+
)
1488+
df.pipe(
1489+
qux,
1490+
1,
1491+
[1.0, 2.0],
1492+
argument_2=11, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
1493+
keyword_only=(1, 2),
1494+
)
1495+
df.pipe(
1496+
qux,
1497+
1,
1498+
[1.0, 2.0],
1499+
argument_2="hi",
1500+
keyword_only=(1,), # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
1501+
)
1502+
df.pipe( # type: ignore[call-arg]
1503+
qux,
1504+
1,
1505+
[1.0, 2.0],
1506+
argument_3="hi", # pyright: ignore[reportGeneralTypeIssues]
1507+
keyword_only=(1, 2),
1508+
)
1509+
df.pipe( # type: ignore[misc]
1510+
qux,
1511+
1,
1512+
[1.0, 2.0],
1513+
11, # type: ignore[arg-type]
1514+
(1, 2), # pyright: ignore[reportGeneralTypeIssues]
1515+
)
1516+
df.pipe( # type: ignore[call-arg]
1517+
qux,
1518+
positional_only=1, # pyright: ignore[reportGeneralTypeIssues]
1519+
argument_1=[1.0, 2.0],
1520+
argument_2=11, # type: ignore[arg-type]
1521+
keyword_only=(1, 2),
1522+
)
1523+
1524+
def dataframe_not_first_arg(x: int, df: pd.DataFrame) -> pd.DataFrame:
1525+
return df
1526+
1527+
check(
1528+
assert_type(
1529+
df.pipe(
1530+
(
1531+
dataframe_not_first_arg,
1532+
"df",
1533+
),
1534+
1,
1535+
),
1536+
pd.DataFrame,
1537+
),
1538+
pd.DataFrame,
1539+
)
1540+
1541+
if TYPE_CHECKING_INVALID_USAGE:
1542+
df.pipe(
1543+
(
1544+
dataframe_not_first_arg, # type: ignore[arg-type]
1545+
1, # pyright: ignore[reportGeneralTypeIssues]
1546+
),
1547+
1,
1548+
)
1549+
df.pipe(
1550+
( # pyright: ignore[reportGeneralTypeIssues]
1551+
1, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
1552+
"df",
1553+
),
1554+
1,
1555+
)
14541556

14551557

14561558
# set_flags() method added in 1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html

tests/test_series.py

+110
Original file line number | Diff line number | Diff line change
@@ -2914,3 +2914,113 @@ def test_timedeltaseries_operators() -> None:
29142914
pd.Series,
29152915
pd.Timedelta,
29162916
)
2917+
2918+
2919+
def test_pipe() -> None:
2920+
ser = pd.Series(range(10))
2921+
2922+
def first_arg_series(
2923+
ser: pd.Series,
2924+
positional_only: int,
2925+
/,
2926+
argument_1: list[float],
2927+
argument_2: str,
2928+
*,
2929+
keyword_only: tuple[int, int],
2930+
) -> pd.Series:
2931+
return ser
2932+
2933+
check(
2934+
assert_type(
2935+
ser.pipe(
2936+
first_arg_series,
2937+
1,
2938+
[1.0, 2.0],
2939+
argument_2="hi",
2940+
keyword_only=(1, 2),
2941+
),
2942+
pd.Series,
2943+
),
2944+
pd.Series,
2945+
)
2946+
2947+
if TYPE_CHECKING_INVALID_USAGE:
2948+
ser.pipe(
2949+
first_arg_series,
2950+
"a", # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
2951+
[1.0, 2.0],
2952+
argument_2="hi",
2953+
keyword_only=(1, 2),
2954+
)
2955+
ser.pipe(
2956+
first_arg_series,
2957+
1,
2958+
[1.0, "b"], # type: ignore[list-item] # pyright: ignore[reportGeneralTypeIssues]
2959+
argument_2="hi",
2960+
keyword_only=(1, 2),
2961+
)
2962+
ser.pipe(
2963+
first_arg_series,
2964+
1,
2965+
[1.0, 2.0],
2966+
argument_2=11, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
2967+
keyword_only=(1, 2),
2968+
)
2969+
ser.pipe(
2970+
first_arg_series,
2971+
1,
2972+
[1.0, 2.0],
2973+
argument_2="hi",
2974+
keyword_only=(1,), # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
2975+
)
2976+
ser.pipe( # type: ignore[call-arg]
2977+
first_arg_series,
2978+
1,
2979+
[1.0, 2.0],
2980+
argument_3="hi", # pyright: ignore[reportGeneralTypeIssues]
2981+
keyword_only=(1, 2),
2982+
)
2983+
ser.pipe( # type: ignore[misc]
2984+
first_arg_series,
2985+
1,
2986+
[1.0, 2.0],
2987+
11, # type: ignore[arg-type]
2988+
(1, 2), # pyright: ignore[reportGeneralTypeIssues]
2989+
)
2990+
ser.pipe( # type: ignore[call-arg]
2991+
first_arg_series,
2992+
positional_only=1, # pyright: ignore[reportGeneralTypeIssues]
2993+
argument_1=[1.0, 2.0],
2994+
argument_2=11, # type: ignore[arg-type]
2995+
keyword_only=(1, 2),
2996+
)
2997+
2998+
def first_arg_not_series(argument_1: int, ser: pd.Series) -> pd.Series:
2999+
return ser
3000+
3001+
check(
3002+
assert_type(
3003+
ser.pipe(
3004+
(first_arg_not_series, "ser"),
3005+
1,
3006+
),
3007+
pd.Series,
3008+
),
3009+
pd.Series,
3010+
)
3011+
3012+
if TYPE_CHECKING_INVALID_USAGE:
3013+
ser.pipe(
3014+
(
3015+
first_arg_not_series, # type: ignore[arg-type]
3016+
1, # pyright: ignore[reportGeneralTypeIssues]
3017+
),
3018+
1,
3019+
)
3020+
ser.pipe(
3021+
(
3022+
1, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
3023+
"df",
3024+
),
3025+
1,
3026+
)

0 commit comments

Comments
 (0)