From 02abc69c77ccc54d98a799fc3cafd0b6e2a695dd Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 29 Apr 2023 22:48:16 -0400 Subject: [PATCH] update pyright, numpy. fix issues 664 and 662 --- pandas-stubs/core/arrays/interval.pyi | 16 +++++++++++-- pandas-stubs/core/indexes/interval.pyi | 14 +++++------ pandas-stubs/core/series.pyi | 32 ++++++++++++++++++-------- pyproject.toml | 6 ++--- tests/test_interval.py | 11 +++++++++ tests/test_io.py | 2 +- tests/test_scalars.py | 24 +++++++++---------- 7 files changed, 71 insertions(+), 34 deletions(-) diff --git a/pandas-stubs/core/arrays/interval.pyi b/pandas-stubs/core/arrays/interval.pyi index 55052733d..eefa791ee 100644 --- a/pandas-stubs/core/arrays/interval.pyi +++ b/pandas-stubs/core/arrays/interval.pyi @@ -1,5 +1,10 @@ +from typing import overload + import numpy as np -from pandas import Index +from pandas import ( + Index, + Series, +) from pandas.core.arrays.base import ExtensionArray as ExtensionArray from typing_extensions import Self @@ -9,7 +14,9 @@ from pandas._libs.interval import ( ) from pandas._typing import ( Axis, + Scalar, TakeIndexer, + np_ndarray_bool, ) class IntervalArray(IntervalMixin, ExtensionArray): @@ -70,5 +77,10 @@ class IntervalArray(IntervalMixin, ExtensionArray): def __arrow_array__(self, type=...): ... def to_tuples(self, na_tuple: bool = ...): ... def repeat(self, repeats, axis: Axis | None = ...): ... - def contains(self, other): ... + @overload + def contains(self, other: Series) -> Series[bool]: ... + @overload + def contains( + self, other: Scalar | ExtensionArray | Index | np.ndarray + ) -> np_ndarray_bool: ... def overlaps(self, other: Interval) -> bool: ... diff --git a/pandas-stubs/core/indexes/interval.pyi b/pandas-stubs/core/indexes/interval.pyi index 3ed0de84a..c1474174b 100644 --- a/pandas-stubs/core/indexes/interval.pyi +++ b/pandas-stubs/core/indexes/interval.pyi @@ -168,7 +168,7 @@ class IntervalIndex(ExtensionIndex, IntervalMixin, Generic[IntervalT]): ) -> IntervalIndex[Interval[pd.Timedelta]]: ... @overload @classmethod - def from_tuples( + def from_tuples( # pyright: ignore[reportOverlappingOverload] cls, data: Sequence[tuple[int, int]], closed: IntervalClosedType = ..., @@ -217,7 +217,7 @@ class IntervalIndex(ExtensionIndex, IntervalMixin, Generic[IntervalT]): ) -> IntervalIndex[pd.Interval[pd.Timedelta]]: ... def to_tuples(self, na_tuple: bool = ...) -> pd.Index: ... @overload - def __contains__(self, key: IntervalT) -> bool: ... # type: ignore[misc] + def __contains__(self, key: IntervalT) -> bool: ... # type: ignore[misc] # pyright: ignore[reportOverlappingOverload] @overload def __contains__(self, key: object) -> Literal[False]: ... def astype(self, dtype: DtypeArg, copy: bool = ...) -> IntervalIndex: ... @@ -292,13 +292,13 @@ class IntervalIndex(ExtensionIndex, IntervalMixin, Generic[IntervalT]): @overload def __lt__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... @overload # type: ignore[override] - def __eq__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc] + def __eq__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc] # pyright: ignore[reportOverlappingOverload] @overload def __eq__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[misc] @overload def __eq__(self, other: object) -> Literal[False]: ... @overload # type: ignore[override] - def __ne__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc] + def __ne__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc] # pyright: ignore[reportOverlappingOverload] @overload def __ne__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[misc] @overload @@ -307,7 +307,7 @@ class IntervalIndex(ExtensionIndex, IntervalMixin, Generic[IntervalT]): # misc here because int and float overlap but interval has distinct types # int gets hit first and so the correct type is returned @overload -def interval_range( # type: ignore[misc] +def interval_range( # type: ignore[misc] # pyright: ignore[reportOverlappingOverload] start: int = ..., end: int = ..., periods: int | None = ..., @@ -318,7 +318,7 @@ def interval_range( # type: ignore[misc] # Overlaps since int is a subclass of float @overload -def interval_range( # pyright: reportOverlappingOverload=false +def interval_range( # pyright: ignore[reportOverlappingOverload] start: int, *, end: None = ..., @@ -328,7 +328,7 @@ def interval_range( # pyright: reportOverlappingOverload=false closed: IntervalClosedType = ..., ) -> IntervalIndex[Interval[int]]: ... @overload -def interval_range( # pyright: reportOverlappingOverload=false +def interval_range( # pyright: ignore[reportOverlappingOverload] *, start: None = ..., end: int, diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 37ed3d005..c53c288bb 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -38,6 +38,7 @@ from pandas import ( ) from pandas.core.arrays.base import ExtensionArray from pandas.core.arrays.categorical import CategoricalAccessor +from pandas.core.arrays.interval import IntervalArray from pandas.core.groupby.generic import ( _SeriesGroupByNonScalar, _SeriesGroupByScalar, @@ -76,7 +77,10 @@ from typing_extensions import ( ) import xarray as xr -from pandas._libs.interval import Interval +from pandas._libs.interval import ( + Interval, + _OrderableT, +) from pandas._libs.missing import NAType from pandas._libs.tslibs import BaseOffset from pandas._typing import ( @@ -245,43 +249,49 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): @overload def __new__( cls, - data: IntervalIndex[Interval[int]], + data: IntervalIndex[Interval[int]] | Interval[int] | Sequence[Interval[int]], index: Axes | None = ..., dtype=..., name: Hashable | None = ..., copy: bool = ..., fastpath: bool = ..., - ) -> Series[Interval[int]]: ... + ) -> IntervalSeries[int]: ... @overload def __new__( cls, - data: IntervalIndex[Interval[float]], + data: IntervalIndex[Interval[float]] + | Interval[float] + | Sequence[Interval[float]], index: Axes | None = ..., dtype=..., name: Hashable | None = ..., copy: bool = ..., fastpath: bool = ..., - ) -> Series[Interval[float]]: ... + ) -> IntervalSeries[float]: ... @overload def __new__( cls, - data: IntervalIndex[Interval[Timestamp]], + data: IntervalIndex[Interval[Timestamp]] + | Interval[Timestamp] + | Sequence[Interval[Timestamp]], index: Axes | None = ..., dtype=..., name: Hashable | None = ..., copy: bool = ..., fastpath: bool = ..., - ) -> Series[Interval[Timestamp]]: ... + ) -> IntervalSeries[Timestamp]: ... @overload def __new__( cls, - data: IntervalIndex[Interval[Timedelta]], + data: IntervalIndex[Interval[Timedelta]] + | Interval[Timedelta] + | Sequence[Interval[Timedelta]], index: Axes | None = ..., dtype=..., name: Hashable | None = ..., copy: bool = ..., fastpath: bool = ..., - ) -> Series[Interval[Timedelta]]: ... + ) -> IntervalSeries[Timedelta]: ... @overload def __new__( cls, @@ -1997,3 +2007,7 @@ class OffsetSeries(Series): def __radd__(self, other: Period) -> PeriodSeries: ... @overload def __radd__(self, other: BaseOffset) -> OffsetSeries: ... + +class IntervalSeries(Series, Generic[_OrderableT]): + @property + def array(self) -> IntervalArray: ... diff --git a/pyproject.toml b/pyproject.toml index e635f2582..58eeaf955 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,11 @@ types-pytz = ">= 2022.1.1" mypy = "1.2.0" pyarrow = ">=10.0.1" pytest = ">=7.1.2" -pyright = ">= 1.1.300" +pyright = ">= 1.1.305" poethepoet = ">=0.16.5" loguru = ">=0.6.0" -pandas = "2.0.0" -numpy = ">=1.24.1" +pandas = "2.0.1" +numpy = ">=1.24.3" typing-extensions = ">=4.4.0" matplotlib = ">=3.5.1" pre-commit = ">=2.19.0" diff --git a/tests/test_interval.py b/tests/test_interval.py index 52ac83fb8..7bdc8f4b1 100644 --- a/tests/test_interval.py +++ b/tests/test_interval.py @@ -1,5 +1,7 @@ from __future__ import annotations +import numpy as np +from numpy import typing as npt import pandas as pd from typing_extensions import assert_type @@ -84,3 +86,12 @@ def test_interval_length() -> None: if TYPE_CHECKING_INVALID_USAGE: pd.Timestamp("2001-01-02") in i3 # type: ignore[operator] # pyright: ignore[reportGeneralTypeIssues] i3 + pd.Timedelta(seconds=20) # type: ignore[operator] # pyright: ignore[reportGeneralTypeIssues] + + +def test_interval_array_contains(): + df = pd.DataFrame({"A": range(1, 10)}) + obj = pd.Interval(1, 4) + ser = pd.Series(obj, index=df.index) + arr = ser.array + check(assert_type(arr.contains(df["A"]), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(arr.contains(3), npt.NDArray[np.bool_]), np.ndarray) diff --git a/tests/test_io.py b/tests/test_io.py index e17e2de25..654f8bb7d 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -570,7 +570,7 @@ def test_feather(): with pytest_warns_bounded( FutureWarning, match="is_sparse is deprecated and will be removed in a future version", - lower="2.0.00", + lower="2.0.99", ): check(assert_type(DF.to_feather(bio), None), type(None)) bio.seek(0) diff --git a/tests/test_scalars.py b/tests/test_scalars.py index 5d99902da..510536502 100644 --- a/tests/test_scalars.py +++ b/tests/test_scalars.py @@ -880,8 +880,8 @@ def test_timedelta_cmp() -> None: le = check(assert_type(c_dt_timedelta <= td, bool), bool) assert gt != le - gt_b = check(assert_type(c_timedelta64 > td, Any), np.bool_) - le_b = check(assert_type(c_timedelta64 <= td, Any), np.bool_) + gt_b = check(assert_type(c_timedelta64 > td, Any), bool) + le_b = check(assert_type(c_timedelta64 <= td, Any), bool) assert gt_b != le_b gt_a = check( @@ -948,8 +948,8 @@ def test_timedelta_cmp() -> None: ge = check(assert_type(c_dt_timedelta >= td, bool), bool) assert lt != ge - lt_b = check(assert_type(c_timedelta64 < td, Any), np.bool_) - ge_b = check(assert_type(c_timedelta64 >= td, Any), np.bool_) + lt_b = check(assert_type(c_timedelta64 < td, Any), bool) + ge_b = check(assert_type(c_timedelta64 >= td, Any), bool) assert lt_b != ge_b lt_a = check( @@ -1038,8 +1038,8 @@ def test_timedelta_cmp_rhs() -> None: ne = check(assert_type(c_dt_timedelta != td, bool), bool) assert eq != ne - eq = check(assert_type(c_timedelta64 == td, Any), np.bool_) - ne = check(assert_type(c_timedelta64 != td, Any), np.bool_) + eq = check(assert_type(c_timedelta64 == td, Any), bool) + ne = check(assert_type(c_timedelta64 != td, Any), bool) assert eq != ne eq_a = check(assert_type(c_ndarray_td64 == td, Any), np.ndarray, np.bool_) @@ -1270,8 +1270,8 @@ def test_timestamp_cmp() -> None: check(assert_type(ts > c_series_dt64, "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(ts <= c_series_dt64, "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(c_np_dt64 > ts, Any), np.bool_) - check(assert_type(c_np_dt64 <= ts, Any), np.bool_) + check(assert_type(c_np_dt64 > ts, Any), bool) + check(assert_type(c_np_dt64 <= ts, Any), bool) gt = check(assert_type(c_dt_datetime > ts, bool), bool) lte = check(assert_type(c_dt_datetime <= ts, bool), bool) @@ -1314,8 +1314,8 @@ def test_timestamp_cmp() -> None: lt = check(assert_type(c_dt_datetime < ts, bool), bool) assert gte != lt - check(assert_type(c_np_dt64 >= ts, Any), np.bool_) - check(assert_type(c_np_dt64 < ts, Any), np.bool_) + check(assert_type(c_np_dt64 >= ts, Any), bool) + check(assert_type(c_np_dt64 < ts, Any), bool) check(assert_type(c_datetimeindex >= ts, np_ndarray_bool), np.ndarray, np.bool_) check(assert_type(c_datetimeindex < ts, np_ndarray_bool), np.ndarray, np.bool_) @@ -1388,8 +1388,8 @@ def test_timestamp_eq_ne_rhs() -> None: [1, 2, 3], dtype="datetime64[ns]" ) - eq_a = check(assert_type(c_np_dt64 == ts, Any), np.bool_) - ne_a = check(assert_type(c_np_dt64 != ts, Any), np.bool_) + eq_a = check(assert_type(c_np_dt64 == ts, Any), bool) + ne_a = check(assert_type(c_np_dt64 != ts, Any), bool) assert eq_a != ne_a eq = check(assert_type(c_dt_datetime == ts, bool), bool)