Skip to content

Commit a54d276

Browse files
Terji PetersenTerji Petersen
Terji Petersen
authored and
Terji Petersen
committed
add typing to IntervalArray._left/_right, II
1 parent 45c0532 commit a54d276

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

pandas/_typing.py

+5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@
4444
from pandas.core.dtypes.dtypes import ExtensionDtype
4545

4646
from pandas import Interval
47+
from pandas.arrays import (
48+
DatetimeArray,
49+
TimedeltaArray,
50+
)
4751
from pandas.core.arrays.base import ExtensionArray
4852
from pandas.core.frame import DataFrame
4953
from pandas.core.generic import NDFrame
@@ -88,6 +92,7 @@
8892

8993
ArrayLike = Union["ExtensionArray", np.ndarray]
9094
AnyArrayLike = Union[ArrayLike, "Index", "Series"]
95+
TimeArrayLike = Union["DatetimeArray", "TimedeltaArray"]
9196

9297
# scalars
9398

pandas/core/arrays/interval.py

+23-18
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
ScalarIndexer,
4040
SequenceIndexer,
4141
SortKind,
42+
TimeArrayLike,
4243
npt,
4344
)
4445
from pandas.compat.numpy import function as nv
@@ -78,6 +79,8 @@
7879
unique,
7980
value_counts,
8081
)
82+
from pandas.core.arrays.datetimes import DatetimeArray
83+
from pandas.core.arrays.timedeltas import TimedeltaArray
8184
from pandas.core.arrays.base import (
8285
ExtensionArray,
8386
_extension_array_shared_docs,
@@ -102,6 +105,7 @@
102105

103106

104107
IntervalArrayT = TypeVar("IntervalArrayT", bound="IntervalArray")
108+
IntervalSideT = Union[TimeArrayLike, np.ndarray]
105109
IntervalOrNA = Union[Interval, float]
106110

107111
_interval_shared_docs: dict[str, str] = {}
@@ -123,8 +127,8 @@
123127
Parameters
124128
----------
125129
data : array-like (1-dimensional)
126-
Array-like containing Interval objects from which to build the
127-
%(klass)s.
130+
Array-like (ndarray, :class:`DateTimeArray`, :class:`TimeDeltaArray`) containing
131+
Interval objects from which to build the %(klass)s.
128132
closed : {'left', 'right', 'both', 'neither'}, default 'right'
129133
Whether the intervals are closed on the left-side, right-side, both or
130134
neither.
@@ -213,8 +217,8 @@ def ndim(self) -> Literal[1]:
213217
return 1
214218

215219
# To make mypy recognize the fields
216-
_left: np.ndarray
217-
_right: np.ndarray
220+
_left: IntervalSideT
221+
_right: IntervalSideT
218222
_dtype: IntervalDtype
219223

220224
# ---------------------------------------------------------------------
@@ -232,8 +236,8 @@ def __new__(
232236
data = extract_array(data, extract_numpy=True)
233237

234238
if isinstance(data, cls):
235-
left: ArrayLike = data._left
236-
right: ArrayLike = data._right
239+
left: IntervalSideT = data._left
240+
right: IntervalSideT = data._right
237241
closed = closed or data.closed
238242
dtype = IntervalDtype(left.dtype, closed=closed)
239243
else:
@@ -276,8 +280,8 @@ def __new__(
276280
@classmethod
277281
def _simple_new(
278282
cls: type[IntervalArrayT],
279-
left,
280-
right,
283+
left: IntervalSideT,
284+
right: IntervalSideT,
281285
dtype: IntervalDtype,
282286
) -> IntervalArrayT:
283287
result = IntervalMixin.__new__(cls)
@@ -295,7 +299,7 @@ def _ensure_simple_new_inputs(
295299
closed: IntervalClosedType | None = None,
296300
copy: bool = False,
297301
dtype: Dtype | None = None,
298-
) -> tuple[ArrayLike, ArrayLike, IntervalDtype]:
302+
) -> tuple[IntervalSideT, IntervalSideT, IntervalDtype]:
299303
"""Ensure correctness of input parameters for cls._simple_new."""
300304
from pandas.core.indexes.base import ensure_index
301305

@@ -1574,9 +1578,11 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
15741578

15751579
if isinstance(self._left, np.ndarray):
15761580
np.putmask(self._left, mask, value_left)
1581+
assert isinstance(self._right, np.ndarray)
15771582
np.putmask(self._right, mask, value_right)
15781583
else:
15791584
self._left._putmask(mask, value_left)
1585+
assert not isinstance(self._right, np.ndarray)
15801586
self._right._putmask(mask, value_right)
15811587

15821588
def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT:
@@ -1604,9 +1610,11 @@ def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT:
16041610
def delete(self: IntervalArrayT, loc) -> IntervalArrayT:
16051611
if isinstance(self._left, np.ndarray):
16061612
new_left = np.delete(self._left, loc)
1613+
assert isinstance(self._right, np.ndarray)
16071614
new_right = np.delete(self._right, loc)
16081615
else:
16091616
new_left = self._left.delete(loc)
1617+
assert not isinstance(self._right, np.ndarray)
16101618
new_right = self._right.delete(loc)
16111619
return self._shallow_copy(left=new_left, right=new_right)
16121620

@@ -1707,7 +1715,7 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
17071715
return isin(self.astype(object), values.astype(object))
17081716

17091717
@property
1710-
def _combined(self) -> ArrayLike:
1718+
def _combined(self) -> IntervalSideT:
17111719
left = self.left._values.reshape(-1, 1)
17121720
right = self.right._values.reshape(-1, 1)
17131721
if needs_i8_conversion(left.dtype):
@@ -1724,15 +1732,12 @@ def _from_combined(self, combined: np.ndarray) -> IntervalArray:
17241732

17251733
dtype = self._left.dtype
17261734
if needs_i8_conversion(dtype):
1727-
# error: "Type[ndarray[Any, Any]]" has no attribute "_from_sequence"
1728-
new_left = type(self._left)._from_sequence( # type: ignore[attr-defined]
1729-
nc[:, 0], dtype=dtype
1730-
)
1731-
# error: "Type[ndarray[Any, Any]]" has no attribute "_from_sequence"
1732-
new_right = type(self._right)._from_sequence( # type: ignore[attr-defined]
1733-
nc[:, 1], dtype=dtype
1734-
)
1735+
assert isinstance(self._left, (DatetimeArray, TimedeltaArray))
1736+
new_left = type(self._left)._from_sequence(nc[:, 0], dtype=dtype)
1737+
assert isinstance(self._right, (DatetimeArray, TimedeltaArray))
1738+
new_right = type(self._right)._from_sequence(nc[:, 1], dtype=dtype)
17351739
else:
1740+
assert isinstance(dtype, np.dtype)
17361741
new_left = nc[:, 0].view(dtype)
17371742
new_right = nc[:, 1].view(dtype)
17381743
return self._shallow_copy(left=new_left, right=new_right)

0 commit comments

Comments
 (0)