Skip to content

Commit 55fca56

Browse files
authored
update pyright, numpy. fix issues 664 and 662 (#667)
1 parent 14ce816 commit 55fca56

File tree

7 files changed

+71
-34
lines changed

7 files changed

+71
-34
lines changed

pandas-stubs/core/arrays/interval.pyi

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
from typing import overload
2+
13
import numpy as np
2-
from pandas import Index
4+
from pandas import (
5+
Index,
6+
Series,
7+
)
38
from pandas.core.arrays.base import ExtensionArray as ExtensionArray
49
from typing_extensions import Self
510

@@ -9,7 +14,9 @@ from pandas._libs.interval import (
914
)
1015
from pandas._typing import (
1116
Axis,
17+
Scalar,
1218
TakeIndexer,
19+
np_ndarray_bool,
1320
)
1421

1522
class IntervalArray(IntervalMixin, ExtensionArray):
@@ -70,5 +77,10 @@ class IntervalArray(IntervalMixin, ExtensionArray):
7077
def __arrow_array__(self, type=...): ...
7178
def to_tuples(self, na_tuple: bool = ...): ...
7279
def repeat(self, repeats, axis: Axis | None = ...): ...
73-
def contains(self, other): ...
80+
@overload
81+
def contains(self, other: Series) -> Series[bool]: ...
82+
@overload
83+
def contains(
84+
self, other: Scalar | ExtensionArray | Index | np.ndarray
85+
) -> np_ndarray_bool: ...
7486
def overlaps(self, other: Interval) -> bool: ...

pandas-stubs/core/indexes/interval.pyi

+7-7
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class IntervalIndex(ExtensionIndex, IntervalMixin, Generic[IntervalT]):
168168
) -> IntervalIndex[Interval[pd.Timedelta]]: ...
169169
@overload
170170
@classmethod
171-
def from_tuples(
171+
def from_tuples( # pyright: ignore[reportOverlappingOverload]
172172
cls,
173173
data: Sequence[tuple[int, int]],
174174
closed: IntervalClosedType = ...,
@@ -217,7 +217,7 @@ class IntervalIndex(ExtensionIndex, IntervalMixin, Generic[IntervalT]):
217217
) -> IntervalIndex[pd.Interval[pd.Timedelta]]: ...
218218
def to_tuples(self, na_tuple: bool = ...) -> pd.Index: ...
219219
@overload
220-
def __contains__(self, key: IntervalT) -> bool: ... # type: ignore[misc]
220+
def __contains__(self, key: IntervalT) -> bool: ... # type: ignore[misc] # pyright: ignore[reportOverlappingOverload]
221221
@overload
222222
def __contains__(self, key: object) -> Literal[False]: ...
223223
def astype(self, dtype: DtypeArg, copy: bool = ...) -> IntervalIndex: ...
@@ -292,13 +292,13 @@ class IntervalIndex(ExtensionIndex, IntervalMixin, Generic[IntervalT]):
292292
@overload
293293
def __lt__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ...
294294
@overload # type: ignore[override]
295-
def __eq__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc]
295+
def __eq__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc] # pyright: ignore[reportOverlappingOverload]
296296
@overload
297297
def __eq__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[misc]
298298
@overload
299299
def __eq__(self, other: object) -> Literal[False]: ...
300300
@overload # type: ignore[override]
301-
def __ne__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc]
301+
def __ne__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc] # pyright: ignore[reportOverlappingOverload]
302302
@overload
303303
def __ne__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[misc]
304304
@overload
@@ -307,7 +307,7 @@ class IntervalIndex(ExtensionIndex, IntervalMixin, Generic[IntervalT]):
307307
# misc here because int and float overlap but interval has distinct types
308308
# int gets hit first and so the correct type is returned
309309
@overload
310-
def interval_range( # type: ignore[misc]
310+
def interval_range( # type: ignore[misc] # pyright: ignore[reportOverlappingOverload]
311311
start: int = ...,
312312
end: int = ...,
313313
periods: int | None = ...,
@@ -318,7 +318,7 @@ def interval_range( # type: ignore[misc]
318318

319319
# Overlaps since int is a subclass of float
320320
@overload
321-
def interval_range( # pyright: reportOverlappingOverload=false
321+
def interval_range( # pyright: ignore[reportOverlappingOverload]
322322
start: int,
323323
*,
324324
end: None = ...,
@@ -328,7 +328,7 @@ def interval_range( # pyright: reportOverlappingOverload=false
328328
closed: IntervalClosedType = ...,
329329
) -> IntervalIndex[Interval[int]]: ...
330330
@overload
331-
def interval_range( # pyright: reportOverlappingOverload=false
331+
def interval_range( # pyright: ignore[reportOverlappingOverload]
332332
*,
333333
start: None = ...,
334334
end: int,

pandas-stubs/core/series.pyi

+23-9
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ from pandas import (
3838
)
3939
from pandas.core.arrays.base import ExtensionArray
4040
from pandas.core.arrays.categorical import CategoricalAccessor
41+
from pandas.core.arrays.interval import IntervalArray
4142
from pandas.core.groupby.generic import (
4243
_SeriesGroupByNonScalar,
4344
_SeriesGroupByScalar,
@@ -76,7 +77,10 @@ from typing_extensions import (
7677
)
7778
import xarray as xr
7879

79-
from pandas._libs.interval import Interval
80+
from pandas._libs.interval import (
81+
Interval,
82+
_OrderableT,
83+
)
8084
from pandas._libs.missing import NAType
8185
from pandas._libs.tslibs import BaseOffset
8286
from pandas._typing import (
@@ -245,43 +249,49 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
245249
@overload
246250
def __new__(
247251
cls,
248-
data: IntervalIndex[Interval[int]],
252+
data: IntervalIndex[Interval[int]] | Interval[int] | Sequence[Interval[int]],
249253
index: Axes | None = ...,
250254
dtype=...,
251255
name: Hashable | None = ...,
252256
copy: bool = ...,
253257
fastpath: bool = ...,
254-
) -> Series[Interval[int]]: ...
258+
) -> IntervalSeries[int]: ...
255259
@overload
256260
def __new__(
257261
cls,
258-
data: IntervalIndex[Interval[float]],
262+
data: IntervalIndex[Interval[float]]
263+
| Interval[float]
264+
| Sequence[Interval[float]],
259265
index: Axes | None = ...,
260266
dtype=...,
261267
name: Hashable | None = ...,
262268
copy: bool = ...,
263269
fastpath: bool = ...,
264-
) -> Series[Interval[float]]: ...
270+
) -> IntervalSeries[float]: ...
265271
@overload
266272
def __new__(
267273
cls,
268-
data: IntervalIndex[Interval[Timestamp]],
274+
data: IntervalIndex[Interval[Timestamp]]
275+
| Interval[Timestamp]
276+
| Sequence[Interval[Timestamp]],
269277
index: Axes | None = ...,
270278
dtype=...,
271279
name: Hashable | None = ...,
272280
copy: bool = ...,
273281
fastpath: bool = ...,
274-
) -> Series[Interval[Timestamp]]: ...
282+
) -> IntervalSeries[Timestamp]: ...
275283
@overload
276284
def __new__(
277285
cls,
278-
data: IntervalIndex[Interval[Timedelta]],
286+
data: IntervalIndex[Interval[Timedelta]]
287+
| Interval[Timedelta]
288+
| Sequence[Interval[Timedelta]],
279289
index: Axes | None = ...,
280290
dtype=...,
281291
name: Hashable | None = ...,
282292
copy: bool = ...,
283293
fastpath: bool = ...,
284-
) -> Series[Interval[Timedelta]]: ...
294+
) -> IntervalSeries[Timedelta]: ...
285295
@overload
286296
def __new__(
287297
cls,
@@ -1997,3 +2007,7 @@ class OffsetSeries(Series):
19972007
def __radd__(self, other: Period) -> PeriodSeries: ...
19982008
@overload
19992009
def __radd__(self, other: BaseOffset) -> OffsetSeries: ...
2010+
2011+
class IntervalSeries(Series, Generic[_OrderableT]):
2012+
@property
2013+
def array(self) -> IntervalArray: ...

pyproject.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ types-pytz = ">= 2022.1.1"
3838
mypy = "1.2.0"
3939
pyarrow = ">=10.0.1"
4040
pytest = ">=7.1.2"
41-
pyright = ">= 1.1.300"
41+
pyright = ">= 1.1.305"
4242
poethepoet = ">=0.16.5"
4343
loguru = ">=0.6.0"
44-
pandas = "2.0.0"
45-
numpy = ">=1.24.1"
44+
pandas = "2.0.1"
45+
numpy = ">=1.24.3"
4646
typing-extensions = ">=4.4.0"
4747
matplotlib = ">=3.5.1"
4848
pre-commit = ">=2.19.0"

tests/test_interval.py

+11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import numpy as np
4+
from numpy import typing as npt
35
import pandas as pd
46
from typing_extensions import assert_type
57

@@ -84,3 +86,12 @@ def test_interval_length() -> None:
8486
if TYPE_CHECKING_INVALID_USAGE:
8587
pd.Timestamp("2001-01-02") in i3 # type: ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
8688
i3 + pd.Timedelta(seconds=20) # type: ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
89+
90+
91+
def test_interval_array_contains():
92+
df = pd.DataFrame({"A": range(1, 10)})
93+
obj = pd.Interval(1, 4)
94+
ser = pd.Series(obj, index=df.index)
95+
arr = ser.array
96+
check(assert_type(arr.contains(df["A"]), "pd.Series[bool]"), pd.Series, np.bool_)
97+
check(assert_type(arr.contains(3), npt.NDArray[np.bool_]), np.ndarray)

tests/test_io.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def test_feather():
570570
with pytest_warns_bounded(
571571
FutureWarning,
572572
match="is_sparse is deprecated and will be removed in a future version",
573-
lower="2.0.00",
573+
lower="2.0.99",
574574
):
575575
check(assert_type(DF.to_feather(bio), None), type(None))
576576
bio.seek(0)

tests/test_scalars.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -880,8 +880,8 @@ def test_timedelta_cmp() -> None:
880880
le = check(assert_type(c_dt_timedelta <= td, bool), bool)
881881
assert gt != le
882882

883-
gt_b = check(assert_type(c_timedelta64 > td, Any), np.bool_)
884-
le_b = check(assert_type(c_timedelta64 <= td, Any), np.bool_)
883+
gt_b = check(assert_type(c_timedelta64 > td, Any), bool)
884+
le_b = check(assert_type(c_timedelta64 <= td, Any), bool)
885885
assert gt_b != le_b
886886

887887
gt_a = check(
@@ -948,8 +948,8 @@ def test_timedelta_cmp() -> None:
948948
ge = check(assert_type(c_dt_timedelta >= td, bool), bool)
949949
assert lt != ge
950950

951-
lt_b = check(assert_type(c_timedelta64 < td, Any), np.bool_)
952-
ge_b = check(assert_type(c_timedelta64 >= td, Any), np.bool_)
951+
lt_b = check(assert_type(c_timedelta64 < td, Any), bool)
952+
ge_b = check(assert_type(c_timedelta64 >= td, Any), bool)
953953
assert lt_b != ge_b
954954

955955
lt_a = check(
@@ -1038,8 +1038,8 @@ def test_timedelta_cmp_rhs() -> None:
10381038
ne = check(assert_type(c_dt_timedelta != td, bool), bool)
10391039
assert eq != ne
10401040

1041-
eq = check(assert_type(c_timedelta64 == td, Any), np.bool_)
1042-
ne = check(assert_type(c_timedelta64 != td, Any), np.bool_)
1041+
eq = check(assert_type(c_timedelta64 == td, Any), bool)
1042+
ne = check(assert_type(c_timedelta64 != td, Any), bool)
10431043
assert eq != ne
10441044

10451045
eq_a = check(assert_type(c_ndarray_td64 == td, Any), np.ndarray, np.bool_)
@@ -1270,8 +1270,8 @@ def test_timestamp_cmp() -> None:
12701270
check(assert_type(ts > c_series_dt64, "pd.Series[bool]"), pd.Series, np.bool_)
12711271
check(assert_type(ts <= c_series_dt64, "pd.Series[bool]"), pd.Series, np.bool_)
12721272

1273-
check(assert_type(c_np_dt64 > ts, Any), np.bool_)
1274-
check(assert_type(c_np_dt64 <= ts, Any), np.bool_)
1273+
check(assert_type(c_np_dt64 > ts, Any), bool)
1274+
check(assert_type(c_np_dt64 <= ts, Any), bool)
12751275

12761276
gt = check(assert_type(c_dt_datetime > ts, bool), bool)
12771277
lte = check(assert_type(c_dt_datetime <= ts, bool), bool)
@@ -1314,8 +1314,8 @@ def test_timestamp_cmp() -> None:
13141314
lt = check(assert_type(c_dt_datetime < ts, bool), bool)
13151315
assert gte != lt
13161316

1317-
check(assert_type(c_np_dt64 >= ts, Any), np.bool_)
1318-
check(assert_type(c_np_dt64 < ts, Any), np.bool_)
1317+
check(assert_type(c_np_dt64 >= ts, Any), bool)
1318+
check(assert_type(c_np_dt64 < ts, Any), bool)
13191319

13201320
check(assert_type(c_datetimeindex >= ts, np_ndarray_bool), np.ndarray, np.bool_)
13211321
check(assert_type(c_datetimeindex < ts, np_ndarray_bool), np.ndarray, np.bool_)
@@ -1388,8 +1388,8 @@ def test_timestamp_eq_ne_rhs() -> None:
13881388
[1, 2, 3], dtype="datetime64[ns]"
13891389
)
13901390

1391-
eq_a = check(assert_type(c_np_dt64 == ts, Any), np.bool_)
1392-
ne_a = check(assert_type(c_np_dt64 != ts, Any), np.bool_)
1391+
eq_a = check(assert_type(c_np_dt64 == ts, Any), bool)
1392+
ne_a = check(assert_type(c_np_dt64 != ts, Any), bool)
13931393
assert eq_a != ne_a
13941394

13951395
eq = check(assert_type(c_dt_datetime == ts, bool), bool)

0 commit comments

Comments
 (0)