Skip to content

Commit dad651d

Browse files
authored
pre-commit edit (#40617)
1 parent db01b30 commit dad651d

File tree

2 files changed

+81
-58
lines changed

2 files changed

+81
-58
lines changed

pandas/core/indexes/extension.py

+33-14
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,38 @@ class ExtensionIndex(Index):
231231

232232
_data: Union[IntervalArray, NDArrayBackedExtensionArray]
233233

234+
_data_cls: Union[
235+
Type[Categorical],
236+
Type[DatetimeArray],
237+
Type[TimedeltaArray],
238+
Type[PeriodArray],
239+
Type[IntervalArray],
240+
]
241+
242+
@classmethod
243+
def _simple_new(
244+
cls,
245+
array: Union[IntervalArray, NDArrayBackedExtensionArray],
246+
name: Hashable = None,
247+
):
248+
"""
249+
Construct from an ExtensionArray of the appropriate type.
250+
251+
Parameters
252+
----------
253+
array : ExtensionArray
254+
name : Label, default None
255+
Attached as result.name
256+
"""
257+
assert isinstance(array, cls._data_cls), type(array)
258+
259+
result = object.__new__(cls)
260+
result._data = array
261+
result._name = name
262+
result._cache = {}
263+
result._reset_identity()
264+
return result
265+
234266
__eq__ = _make_wrapped_comparison_op("__eq__")
235267
__ne__ = _make_wrapped_comparison_op("__ne__")
236268
__lt__ = _make_wrapped_comparison_op("__lt__")
@@ -362,30 +394,17 @@ class NDArrayBackedExtensionIndex(ExtensionIndex):
362394

363395
_data: NDArrayBackedExtensionArray
364396

365-
_data_cls: Union[
366-
Type[Categorical],
367-
Type[DatetimeArray],
368-
Type[TimedeltaArray],
369-
Type[PeriodArray],
370-
]
371-
372397
@classmethod
373398
def _simple_new(
374399
cls,
375400
values: NDArrayBackedExtensionArray,
376401
name: Hashable = None,
377402
):
378-
assert isinstance(values, cls._data_cls), type(values)
379-
380-
result = object.__new__(cls)
381-
result._data = values
382-
result._name = name
383-
result._cache = {}
403+
result = super()._simple_new(values, name)
384404

385405
# For groupby perf. See note in indexes/base about _index_data
386406
result._index_data = values._ndarray
387407

388-
result._reset_identity()
389408
return result
390409

391410
def _get_engine_target(self) -> np.ndarray:

pandas/core/indexes/interval.py

+48-44
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,31 @@ def wrapped(self, other, sort=False):
187187
return wrapped
188188

189189

190+
def _setop(op_name: str):
191+
"""
192+
Implement set operation.
193+
"""
194+
195+
def func(self, other, sort=None):
196+
# At this point we are assured
197+
# isinstance(other, IntervalIndex)
198+
# other.closed == self.closed
199+
200+
result = getattr(self._multiindex, op_name)(other._multiindex, sort=sort)
201+
result_name = get_op_result_name(self, other)
202+
203+
# GH 19101: ensure empty results have correct dtype
204+
if result.empty:
205+
result = result._values.astype(self.dtype.subtype)
206+
else:
207+
result = result._values
208+
209+
return type(self).from_tuples(result, closed=self.closed, name=result_name)
210+
211+
func.__name__ = op_name
212+
return setop_check(func)
213+
214+
190215
@Appender(
191216
_interval_shared_docs["class"]
192217
% {
@@ -218,19 +243,38 @@ def wrapped(self, other, sort=False):
218243
}
219244
)
220245
@inherit_names(["set_closed", "to_tuples"], IntervalArray, wrap=True)
221-
@inherit_names(["__array__", "overlaps", "contains"], IntervalArray)
246+
@inherit_names(
247+
[
248+
"__array__",
249+
"overlaps",
250+
"contains",
251+
"closed_left",
252+
"closed_right",
253+
"open_left",
254+
"open_right",
255+
"is_empty",
256+
],
257+
IntervalArray,
258+
)
222259
@inherit_names(["is_non_overlapping_monotonic", "closed"], IntervalArray, cache=True)
223-
class IntervalIndex(IntervalMixin, ExtensionIndex):
260+
class IntervalIndex(ExtensionIndex):
224261
_typ = "intervalindex"
225262
_comparables = ["name"]
226263
_attributes = ["name", "closed"]
227264

265+
# annotate properties pinned via inherit_names
266+
closed: str
267+
is_non_overlapping_monotonic: bool
268+
closed_left: bool
269+
closed_right: bool
270+
228271
# we would like our indexing holder to defer to us
229272
_defer_to_indexing = True
230273

231274
_data: IntervalArray
232275
_values: IntervalArray
233276
_can_hold_strings = False
277+
_data_cls = IntervalArray
234278

235279
# --------------------------------------------------------------------
236280
# Constructors
@@ -241,7 +285,7 @@ def __new__(
241285
closed=None,
242286
dtype: Optional[Dtype] = None,
243287
copy: bool = False,
244-
name=None,
288+
name: Hashable = None,
245289
verify_integrity: bool = True,
246290
):
247291

@@ -258,26 +302,6 @@ def __new__(
258302

259303
return cls._simple_new(array, name)
260304

261-
@classmethod
262-
def _simple_new(cls, array: IntervalArray, name: Hashable = None):
263-
"""
264-
Construct from an IntervalArray
265-
266-
Parameters
267-
----------
268-
array : IntervalArray
269-
name : Label, default None
270-
Attached as result.name
271-
"""
272-
assert isinstance(array, IntervalArray), type(array)
273-
274-
result = IntervalMixin.__new__(cls)
275-
result._data = array
276-
result.name = name
277-
result._cache = {}
278-
result._reset_identity()
279-
return result
280-
281305
@classmethod
282306
@Appender(
283307
_interval_shared_docs["from_breaks"]
@@ -605,7 +629,7 @@ def _searchsorted_monotonic(self, label, side: str = "left"):
605629
"non-overlapping and all monotonic increasing or decreasing"
606630
)
607631

608-
if isinstance(label, IntervalMixin):
632+
if isinstance(label, (IntervalMixin, IntervalIndex)):
609633
raise NotImplementedError("Interval objects are not currently supported")
610634

611635
# GH 20921: "not is_monotonic_increasing" for the second condition
@@ -1012,26 +1036,6 @@ def _intersection_non_unique(self, other: IntervalIndex) -> IntervalIndex:
10121036

10131037
return self[mask]
10141038

1015-
def _setop(op_name: str, sort=None):
1016-
def func(self, other, sort=sort):
1017-
# At this point we are assured
1018-
# isinstance(other, IntervalIndex)
1019-
# other.closed == self.closed
1020-
1021-
result = getattr(self._multiindex, op_name)(other._multiindex, sort=sort)
1022-
result_name = get_op_result_name(self, other)
1023-
1024-
# GH 19101: ensure empty results have correct dtype
1025-
if result.empty:
1026-
result = result._values.astype(self.dtype.subtype)
1027-
else:
1028-
result = result._values
1029-
1030-
return type(self).from_tuples(result, closed=self.closed, name=result_name)
1031-
1032-
func.__name__ = op_name
1033-
return setop_check(func)
1034-
10351039
_union = _setop("union")
10361040
_difference = _setop("difference")
10371041

0 commit comments

Comments
 (0)