Skip to content

Commit 3c7df2f

Browse files
authored
Infer dtype of Series in more cases (#766)
* Infer dtype of Series in more cases
* ignores
* address some of the feedback
* One of pandas's Windows/Linux inconsistencies
* more changes; incl. np.intp/int64 -> np.integer
* everything(?) but agg
* return series
1 parent 82f8994 commit 3c7df2f

10 files changed

+286
-187
lines changed

.pre-commit-config.yaml

+1-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ repos:
1111
hooks:
1212
- id: isort
1313
- repo: https://github.com/astral-sh/ruff-pre-commit
14-
rev: v0.0.283
14+
rev: v0.0.285
1515
hooks:
1616
- id: ruff
1717
args: [

pandas-stubs/_typing.pyi

+8-8
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ DtypeBackend: TypeAlias = Literal["pyarrow", "numpy_nullable"]
9191

9292
BooleanDtypeArg: TypeAlias = (
9393
# Builtin bool type and its string alias
94-
type[bool] # noqa: PYI030,PYI055
94+
type[bool] # noqa: PYI030
9595
| Literal["bool"]
9696
# Pandas nullable boolean type and its string alias
9797
| pd.BooleanDtype
@@ -105,7 +105,7 @@ BooleanDtypeArg: TypeAlias = (
105105
)
106106
IntDtypeArg: TypeAlias = (
107107
# Builtin integer type and its string alias
108-
type[int] # noqa: PYI030,PYI055
108+
type[int] # noqa: PYI030
109109
| Literal["int"]
110110
# Pandas nullable integer types and their string aliases
111111
| pd.Int8Dtype
@@ -137,7 +137,7 @@ IntDtypeArg: TypeAlias = (
137137
)
138138
UIntDtypeArg: TypeAlias = (
139139
# Pandas nullable unsigned integer types and their string aliases
140-
pd.UInt8Dtype # noqa: PYI030,PYI055
140+
pd.UInt8Dtype # noqa: PYI030
141141
| pd.UInt16Dtype
142142
| pd.UInt32Dtype
143143
| pd.UInt64Dtype
@@ -166,7 +166,7 @@ UIntDtypeArg: TypeAlias = (
166166
)
167167
FloatDtypeArg: TypeAlias = (
168168
# Builtin float type and its string alias
169-
type[float] # noqa: PYI030,PYI055
169+
type[float] # noqa: PYI030
170170
| Literal["float"]
171171
# Pandas nullable float types and their string aliases
172172
| pd.Float32Dtype
@@ -197,7 +197,7 @@ FloatDtypeArg: TypeAlias = (
197197
)
198198
ComplexDtypeArg: TypeAlias = (
199199
# Builtin complex type and its string alias
200-
type[complex] # noqa: PYI030,PYI055
200+
type[complex] # noqa: PYI030
201201
| Literal["complex"]
202202
# Numpy complex types and their aliases
203203
# https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.csingle
@@ -326,7 +326,7 @@ TimestampDtypeArg: TypeAlias = Literal[
326326

327327
StrDtypeArg: TypeAlias = (
328328
# Builtin str type and its string alias
329-
type[str] # noqa: PYI030,PYI055
329+
type[str] # noqa: PYI030
330330
| Literal["str"]
331331
# Pandas nullable string type and its string alias
332332
| pd.StringDtype
@@ -340,7 +340,7 @@ StrDtypeArg: TypeAlias = (
340340
)
341341
BytesDtypeArg: TypeAlias = (
342342
# Builtin bytes type and its string alias
343-
type[bytes] # noqa: PYI030,PYI055
343+
type[bytes] # noqa: PYI030
344344
| Literal["bytes"]
345345
# Numpy bytes type and its string alias
346346
# https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bytes_
@@ -353,7 +353,7 @@ CategoryDtypeArg: TypeAlias = CategoricalDtype | Literal["category"]
353353

354354
ObjectDtypeArg: TypeAlias = (
355355
# Builtin object type and its string alias
356-
type[object] # noqa: PYI030,PYI055
356+
type[object] # noqa: PYI030
357357
| Literal["object"]
358358
# Numpy object type and its string alias
359359
# https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.object_

pandas-stubs/core/frame.pyi

+1-1
Original file line number | Diff line number | Diff line change
@@ -555,7 +555,7 @@ class DataFrame(NDFrame, OpsMixin):
555555
@overload
556556
def __getitem__( # type: ignore[misc]
557557
self,
558-
key: Series[_bool]
558+
key: Series
559559
| DataFrame
560560
| Index
561561
| np_ndarray_str

pandas-stubs/core/indexes/period.pyi

+1-3
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,7 @@ import numpy as np
66
import pandas as pd
77
from pandas import Index
88
from pandas.core.indexes.accessors import PeriodIndexFieldOps
9-
from pandas.core.indexes.datetimelike import (
10-
DatetimeIndexOpsMixin as DatetimeIndexOpsMixin,
11-
)
9+
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
1210
from pandas.core.indexes.timedeltas import TimedeltaIndex
1311
from typing_extensions import Self
1412

pandas-stubs/core/series.pyi

+71-49
Original file line number | Diff line number | Diff line change
@@ -211,46 +211,57 @@ class _LocIndexerSeries(_LocIndexer, Generic[S1]):
211211
value: S1 | ArrayLike | Series[S1] | None,
212212
) -> None: ...
213213

214+
_ListLike: TypeAlias = (
215+
ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | IndexOpsMixin[S1]
216+
)
217+
214218
class Series(IndexOpsMixin[S1], NDFrame):
215-
_ListLike: TypeAlias = ArrayLike | dict[_str, np.ndarray] | list | tuple | Index
216219
__hash__: ClassVar[None]
217220

218-
# TODO: can __new__ be converted to __init__? Pandas implements __init__
219221
@overload
220-
def __new__(
222+
def __new__( # type: ignore[misc]
221223
cls,
222-
data: DatetimeIndex | Sequence[Timestamp | np.datetime64 | datetime],
224+
data: DatetimeIndex
225+
| Sequence[np.datetime64 | datetime]
226+
| np.datetime64
227+
| datetime,
223228
index: Axes | None = ...,
229+
*,
224230
dtype: TimestampDtypeArg = ...,
225-
name: Hashable | None = ...,
231+
name: Hashable = ...,
226232
copy: bool = ...,
227233
) -> TimestampSeries: ...
228234
@overload
229-
def __new__(
235+
def __new__( # type: ignore[misc]
230236
cls,
231237
data: _ListLike,
232238
index: Axes | None = ...,
233239
*,
234240
dtype: TimestampDtypeArg,
235-
name: Hashable | None = ...,
241+
name: Hashable = ...,
236242
copy: bool = ...,
237243
) -> TimestampSeries: ...
238244
@overload
239-
def __new__(
245+
def __new__( # type: ignore[misc]
240246
cls,
241247
data: PeriodIndex,
242248
index: Axes | None = ...,
249+
*,
243250
dtype: PeriodDtype = ...,
244-
name: Hashable | None = ...,
251+
name: Hashable = ...,
245252
copy: bool = ...,
246253
) -> PeriodSeries: ...
247254
@overload
248-
def __new__(
255+
def __new__( # type: ignore[misc]
249256
cls,
250-
data: TimedeltaIndex | Sequence[Timedelta | np.timedelta64 | timedelta],
257+
data: TimedeltaIndex
258+
| Sequence[np.timedelta64 | timedelta]
259+
| np.timedelta64
260+
| timedelta,
251261
index: Axes | None = ...,
262+
*,
252263
dtype: TimedeltaDtypeArg = ...,
253-
name: Hashable | None = ...,
264+
name: Hashable = ...,
254265
copy: bool = ...,
255266
) -> TimedeltaSeries: ...
256267
@overload
@@ -260,35 +271,39 @@ class Series(IndexOpsMixin[S1], NDFrame):
260271
| Interval[_OrderableT]
261272
| Sequence[Interval[_OrderableT]],
262273
index: Axes | None = ...,
274+
*,
263275
dtype: Literal["Interval"] = ...,
264-
name: Hashable | None = ...,
276+
name: Hashable = ...,
265277
copy: bool = ...,
266278
) -> IntervalSeries[_OrderableT]: ...
267279
@overload
268280
def __new__(
269281
cls,
270-
data: object | _ListLike | Series[S1] | dict[int, S1] | dict[_str, S1] | None,
271-
dtype: type[S1],
282+
data: Scalar | _ListLike | dict[int, Any] | dict[_str, Any] | None,
272283
index: Axes | None = ...,
273-
name: Hashable | None = ...,
284+
*,
285+
dtype: type[S1],
286+
name: Hashable = ...,
274287
copy: bool = ...,
275288
) -> Self: ...
276289
@overload
277290
def __new__(
278291
cls,
279-
data: Series[S1] | dict[int, S1] | dict[_str, S1] = ...,
292+
data: S1 | _ListLike[S1] | dict[int, S1] | dict[_str, S1],
280293
index: Axes | None = ...,
294+
*,
281295
dtype: Dtype = ...,
282-
name: Hashable | None = ...,
296+
name: Hashable = ...,
283297
copy: bool = ...,
284298
) -> Self: ...
285299
@overload
286300
def __new__(
287301
cls,
288-
data: object | _ListLike | None = ...,
302+
data: Scalar | _ListLike | dict[int, Any] | dict[_str, Any] | None = ...,
289303
index: Axes | None = ...,
304+
*,
290305
dtype: Dtype = ...,
291-
name: Hashable | None = ...,
306+
name: Hashable = ...,
292307
copy: bool = ...,
293308
) -> Series: ...
294309
@property
@@ -342,8 +357,8 @@ class Series(IndexOpsMixin[S1], NDFrame):
342357
| Series[S1]
343358
| slice
344359
| MaskType
345-
| tuple[S1 | slice, ...],
346-
) -> Series: ...
360+
| tuple[Hashable | slice, ...],
361+
) -> Self: ...
347362
@overload
348363
def __getitem__(self, idx: int | _str) -> S1: ...
349364
def __setitem__(self, key, value) -> None: ...
@@ -680,11 +695,13 @@ class Series(IndexOpsMixin[S1], NDFrame):
680695
@overload
681696
def dot(self, other: DataFrame) -> Series[S1]: ...
682697
@overload
683-
def dot(self, other: _ListLike) -> np.ndarray: ...
698+
def dot(
699+
self, other: ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | Index[S1]
700+
) -> np.ndarray: ...
684701
def __matmul__(self, other): ...
685702
def __rmatmul__(self, other): ...
686703
@overload
687-
def searchsorted(
704+
def searchsorted( # type: ignore[misc]
688705
self,
689706
value: _ListLike,
690707
side: Literal["left", "right"] = ...,
@@ -781,7 +798,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
781798
ignore_index: _bool = ...,
782799
inplace: Literal[False] = ...,
783800
key: Callable | None = ...,
784-
) -> Series: ...
801+
) -> Self: ...
785802
@overload
786803
def sort_index(
787804
self,
@@ -820,6 +837,14 @@ class Series(IndexOpsMixin[S1], NDFrame):
820837
) -> DataFrame: ...
821838
def map(self, arg, na_action: Literal["ignore"] | None = ...) -> Series[S1]: ...
822839
@overload
840+
def aggregate( # type: ignore[misc]
841+
self: Series[int],
842+
func: Literal["mean"],
843+
axis: AxisIndex = ...,
844+
*args,
845+
**kwargs,
846+
) -> float: ...
847+
@overload
823848
def aggregate(
824849
self,
825850
func: AggFuncTypeBase,
@@ -834,7 +859,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
834859
axis: AxisIndex = ...,
835860
*args,
836861
**kwargs,
837-
) -> Series[S1]: ...
862+
) -> Series: ...
838863
agg = aggregate
839864
@overload
840865
def transform(
@@ -902,7 +927,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
902927
inplace: Literal[False] = ...,
903928
level: Level | None = ...,
904929
errors: IgnoreRaise = ...,
905-
) -> Series: ...
930+
) -> Self: ...
906931
@overload
907932
def rename(
908933
self,
@@ -913,7 +938,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
913938
inplace: Literal[False] = ...,
914939
level: Level | None = ...,
915940
errors: IgnoreRaise = ...,
916-
) -> Series: ...
941+
) -> Self: ...
917942
@overload
918943
def rename(
919944
self,
@@ -932,7 +957,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
932957
copy: _bool = ...,
933958
limit: int | None = ...,
934959
tolerance: float | None = ...,
935-
) -> Series: ...
960+
) -> Self: ...
936961
@overload
937962
def drop(
938963
self,
@@ -956,7 +981,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
956981
level: Level | None = ...,
957982
inplace: Literal[False] = ...,
958983
errors: IgnoreRaise = ...,
959-
) -> Series: ...
984+
) -> Self: ...
960985
@overload
961986
def drop(
962987
self,
@@ -1344,7 +1369,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
13441369
na_option: Literal["keep", "top", "bottom"] = ...,
13451370
ascending: _bool = ...,
13461371
pct: _bool = ...,
1347-
) -> Series: ...
1372+
) -> Series[float]: ...
13481373
def where(
13491374
self,
13501375
cond: Series[S1]
@@ -1431,14 +1456,10 @@ class Series(IndexOpsMixin[S1], NDFrame):
14311456
# just failed to generate these so I couldn't match
14321457
# them up.
14331458
@overload
1434-
def __add__(self, other: TimestampSeries) -> TimestampSeries: ...
1435-
@overload
1436-
def __add__(self, other: DatetimeIndex) -> TimestampSeries: ...
1437-
@overload
1438-
def __add__(self, other: Timestamp) -> TimestampSeries: ...
1459+
def __add__(self, other: S1 | Self) -> Self: ...
14391460
@overload
14401461
def __add__(
1441-
self, other: num | _str | Timedelta | _ListLike | Series[S1] | np.timedelta64
1462+
self, other: num | _str | Timedelta | _ListLike | Series | np.timedelta64
14421463
) -> Series: ...
14431464
# ignore needed for mypy as we want different results based on the arguments
14441465
@overload # type: ignore[override]
@@ -1479,7 +1500,10 @@ class Series(IndexOpsMixin[S1], NDFrame):
14791500
) -> Series[bool]: ...
14801501
@overload
14811502
def __or__(self, other: int | np_ndarray_anyint | Series[int]) -> Series[int]: ...
1482-
def __radd__(self, other: num | _str | _ListLike | Series[S1]) -> Series[S1]: ...
1503+
@overload
1504+
def __radd__(self, other: S1 | Series[S1]) -> Self: ...
1505+
@overload
1506+
def __radd__(self, other: num | _str | _ListLike | Series) -> Series: ...
14831507
# ignore needed for mypy as we want different results based on the arguments
14841508
@overload # type: ignore[override]
14851509
def __rand__( # type: ignore[misc]
@@ -1961,11 +1985,16 @@ class Series(IndexOpsMixin[S1], NDFrame):
19611985
axis: AxisIndex | None = ...,
19621986
copy: _bool = ...,
19631987
inplace: Literal[False] = ...,
1964-
) -> Series: ...
1965-
def set_axis(
1966-
self, labels, *, axis: Axis = ..., copy: _bool = ...
1967-
) -> Series[S1]: ...
1988+
) -> Self: ...
1989+
def set_axis(self, labels, *, axis: Axis = ..., copy: _bool = ...) -> Self: ...
19681990
def __iter__(self) -> Iterator[S1]: ...
1991+
def xs(
1992+
self,
1993+
key: Hashable,
1994+
axis: AxisIndex = ...,
1995+
level: Level | None = ...,
1996+
drop_level: _bool = ...,
1997+
) -> Self: ...
19691998

19701999
class TimestampSeries(Series[Timestamp]):
19712000
# ignore needed because of mypy
@@ -2092,13 +2121,6 @@ class TimedeltaSeries(Series[Timedelta]):
20922121
numeric_only: _bool = ...,
20932122
**kwargs,
20942123
) -> Timedelta: ...
2095-
def xs(
2096-
self,
2097-
key: Hashable,
2098-
axis: AxisIndex = ...,
2099-
level: Level | None = ...,
2100-
drop_level: _bool = ...,
2101-
) -> Series: ...
21022124

21032125
class PeriodSeries(Series[Period]):
21042126
# ignore needed because of mypy

pyproject.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ numpy = [
3939
]
4040

4141
[tool.poetry.dev-dependencies]
42-
mypy = "1.5.0"
42+
mypy = "1.5.1"
4343
pandas = "2.0.3"
4444
pyarrow = ">=10.0.1"
4545
pytest = ">=7.1.2"

0 commit comments

Comments
 (0)