Skip to content

Commit 0cf1cb9

Browse files
mutricylLaurent Mutricy
and
Laurent Mutricy
authored
887 select_dtypes stubs fixing (pandas-dev#900)
* adding tests for select_dtypes * adding tests for select_dtypes, resolves pandas-dev#887 * better handling of exclusions in arguments --------- Co-authored-by: Laurent Mutricy <[email protected]>
1 parent 0342d90 commit 0cf1cb9

File tree

2 files changed

+93
-3
lines changed

2 files changed

+93
-3
lines changed

pandas-stubs/core/frame.pyi

+45-3
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ from pandas.core.window.rolling import (
5050
Rolling,
5151
Window,
5252
)
53-
from typing_extensions import Self
53+
from typing_extensions import (
54+
Never,
55+
Self,
56+
TypeAlias,
57+
)
5458
import xarray as xr
5559

5660
from pandas._libs.lib import NoDefault
@@ -112,6 +116,7 @@ from pandas._typing import (
112116
SortKind,
113117
StataDateFormat,
114118
StorageOptions,
119+
StrDtypeArg,
115120
StrLike,
116121
Suffixes,
117122
T as _T,
@@ -608,10 +613,47 @@ class DataFrame(NDFrame, OpsMixin):
608613
self, expr: _str, *, inplace: Literal[False] = ..., **kwargs
609614
) -> DataFrame: ...
610615
def eval(self, expr: _str, *, inplace: _bool = ..., **kwargs): ...
616+
AstypeArgExt: TypeAlias = (
617+
AstypeArg
618+
| Literal[
619+
"number",
620+
"datetime64",
621+
"datetime",
622+
"timedelta",
623+
"timedelta64",
624+
"datetimetz",
625+
"datetime64[ns]",
626+
]
627+
)
628+
AstypeArgExtList: TypeAlias = AstypeArgExt | list[AstypeArgExt]
629+
@overload
630+
def select_dtypes(
631+
self, include: StrDtypeArg, exclude: AstypeArgExtList | None = ...
632+
) -> Never: ...
633+
@overload
634+
def select_dtypes(
635+
self, include: AstypeArgExtList | None, exclude: StrDtypeArg
636+
) -> Never: ...
637+
@overload
638+
def select_dtypes(self, exclude: StrDtypeArg) -> Never: ...
639+
@overload
640+
def select_dtypes(self, include: list[Never], exclude: list[Never]) -> Never: ...
641+
@overload
642+
def select_dtypes(
643+
self,
644+
include: AstypeArgExtList,
645+
exclude: AstypeArgExtList | None = ...,
646+
) -> DataFrame: ...
647+
@overload
648+
def select_dtypes(
649+
self,
650+
include: AstypeArgExtList | None,
651+
exclude: AstypeArgExtList,
652+
) -> DataFrame: ...
653+
@overload
611654
def select_dtypes(
612655
self,
613-
include: _str | list[_str] | None = ...,
614-
exclude: _str | list[_str] | None = ...,
656+
exclude: AstypeArgExtList,
615657
) -> DataFrame: ...
616658
def insert(
617659
self,

tests/test_frame.py

+48
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import pytest
3939
from typing_extensions import (
4040
TypeAlias,
41+
assert_never,
4142
assert_type,
4243
)
4344
import xarray as xr
@@ -3154,6 +3155,53 @@ def test_convert_dtypes_dtype_backend() -> None:
31543155
check(assert_type(dfn, pd.DataFrame), pd.DataFrame)
31553156

31563157

3158+
def test_select_dtypes() -> None:
3159+
df = pd.DataFrame({"a": [1, 2] * 3, "b": [True, False] * 3, "c": [1.0, 2.0] * 3})
3160+
check(assert_type(df.select_dtypes("number"), pd.DataFrame), pd.DataFrame)
3161+
check(assert_type(df.select_dtypes(np.number), pd.DataFrame), pd.DataFrame)
3162+
check(assert_type(df.select_dtypes(object), pd.DataFrame), pd.DataFrame)
3163+
check(assert_type(df.select_dtypes(include="bool"), pd.DataFrame), pd.DataFrame)
3164+
check(
3165+
assert_type(df.select_dtypes(include=["float64"], exclude=None), pd.DataFrame),
3166+
pd.DataFrame,
3167+
)
3168+
check(
3169+
assert_type(df.select_dtypes(exclude=["int64"], include=None), pd.DataFrame),
3170+
pd.DataFrame,
3171+
)
3172+
check(
3173+
assert_type(df.select_dtypes(exclude=["int64", object]), pd.DataFrame),
3174+
pd.DataFrame,
3175+
)
3176+
check(
3177+
assert_type(
3178+
df.select_dtypes(
3179+
exclude=[
3180+
np.datetime64,
3181+
"datetime64",
3182+
"datetime",
3183+
np.timedelta64,
3184+
"timedelta",
3185+
"timedelta64",
3186+
"category",
3187+
"datetimetz",
3188+
"datetime64[ns]",
3189+
]
3190+
),
3191+
pd.DataFrame,
3192+
),
3193+
pd.DataFrame,
3194+
)
3195+
if TYPE_CHECKING_INVALID_USAGE:
3196+
# include and exclude shall not be both empty
3197+
assert_never(df.select_dtypes([], []))
3198+
assert_never(df.select_dtypes())
3199+
# str like dtypes are not allowed
3200+
assert_never(df.select_dtypes(str))
3201+
assert_never(df.select_dtypes(exclude=str))
3202+
assert_never(df.select_dtypes(None, str))
3203+
3204+
31573205
def test_to_json_mode() -> None:
31583206
df = pd.DataFrame(
31593207
[["a", "b"], ["c", "d"]],

0 commit comments

Comments
 (0)