Skip to content

Commit e6f0d1d

Browse files
authored
ENH: IntervalArray.min/max (pandas-dev#44746)
1 parent b7ac9be commit e6f0d1d

File tree

5 files changed

+119
-12
lines changed

5 files changed

+119
-12
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ Other enhancements
227227
``USFederalHolidayCalendar``. See also `Other API changes`_.
228228
- :meth:`.Rolling.var`, :meth:`.Expanding.var`, :meth:`.Rolling.std`, :meth:`.Expanding.std` now support `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`44461`)
229229
- :meth:`Series.info` has been added, for compatibility with :meth:`DataFrame.info` (:issue:`5167`)
230+
- Implemented :meth:`IntervalArray.min`, :meth:`IntervalArray.max`, as a result of which ``min`` and ``max`` now work for :class:`IntervalIndex`, :class:`Series` and :class:`DataFrame` with ``IntervalDtype`` (:issue:`44746`)
230231
- :meth:`UInt64Index.map` now retains ``dtype`` where possible (:issue:`44609`)
231232
-
232233

pandas/core/arrays/interval.py

+30
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,36 @@ def argsort(
790790
ascending=ascending, kind=kind, na_position=na_position, **kwargs
791791
)
792792

793+
def min(self, *, axis: int | None = None, skipna: bool = True):
794+
nv.validate_minmax_axis(axis, self.ndim)
795+
796+
if not len(self):
797+
return self._na_value
798+
799+
mask = self.isna()
800+
if mask.any():
801+
if not skipna:
802+
return self._na_value
803+
return self[~mask].min()
804+
805+
indexer = self.argsort()[0]
806+
return self[indexer]
807+
808+
def max(self, *, axis: int | None = None, skipna: bool = True):
809+
nv.validate_minmax_axis(axis, self.ndim)
810+
811+
if not len(self):
812+
return self._na_value
813+
814+
mask = self.isna()
815+
if mask.any():
816+
if not skipna:
817+
return self._na_value
818+
return self[~mask].max()
819+
820+
indexer = self.argsort()[-1]
821+
return self[indexer]
822+
793823
def fillna(
794824
self: IntervalArrayT, value=None, method=None, limit=None
795825
) -> IntervalArrayT:

pandas/tests/arrays/interval/test_interval.py

+67
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def test_shift_datetime(self):
103103
class TestSetitem:
104104
def test_set_na(self, left_right_dtypes):
105105
left, right = left_right_dtypes
106+
left = left.copy(deep=True)
107+
right = right.copy(deep=True)
106108
result = IntervalArray.from_arrays(left, right)
107109

108110
if result.dtype.subtype.kind not in ["m", "M"]:
@@ -161,6 +163,71 @@ def test_repr():
161163
assert result == expected
162164

163165

166+
class TestReductions:
167+
def test_min_max_invalid_axis(self, left_right_dtypes):
168+
left, right = left_right_dtypes
169+
left = left.copy(deep=True)
170+
right = right.copy(deep=True)
171+
arr = IntervalArray.from_arrays(left, right)
172+
173+
msg = "`axis` must be fewer than the number of dimensions"
174+
for axis in [-2, 1]:
175+
with pytest.raises(ValueError, match=msg):
176+
arr.min(axis=axis)
177+
with pytest.raises(ValueError, match=msg):
178+
arr.max(axis=axis)
179+
180+
msg = "'>=' not supported between"
181+
with pytest.raises(TypeError, match=msg):
182+
arr.min(axis="foo")
183+
with pytest.raises(TypeError, match=msg):
184+
arr.max(axis="foo")
185+
186+
def test_min_max(self, left_right_dtypes, index_or_series_or_array):
187+
# GH#44746
188+
left, right = left_right_dtypes
189+
left = left.copy(deep=True)
190+
right = right.copy(deep=True)
191+
arr = IntervalArray.from_arrays(left, right)
192+
193+
# The expected results below are only valid if monotonic
194+
assert left.is_monotonic_increasing
195+
assert Index(arr).is_monotonic_increasing
196+
197+
MIN = arr[0]
198+
MAX = arr[-1]
199+
200+
indexer = np.arange(len(arr))
201+
np.random.shuffle(indexer)
202+
arr = arr.take(indexer)
203+
204+
arr_na = arr.insert(2, np.nan)
205+
206+
arr = index_or_series_or_array(arr)
207+
arr_na = index_or_series_or_array(arr_na)
208+
209+
for skipna in [True, False]:
210+
res = arr.min(skipna=skipna)
211+
assert res == MIN
212+
assert type(res) == type(MIN)
213+
214+
res = arr.max(skipna=skipna)
215+
assert res == MAX
216+
assert type(res) == type(MAX)
217+
218+
res = arr_na.min(skipna=False)
219+
assert np.isnan(res)
220+
res = arr_na.max(skipna=False)
221+
assert np.isnan(res)
222+
223+
res = arr_na.min(skipna=True)
224+
assert res == MIN
225+
assert type(res) == type(MIN)
226+
res = arr_na.max(skipna=True)
227+
assert res == MAX
228+
assert type(res) == type(MAX)
229+
230+
164231
# ----------------------------------------------------------------------------
165232
# Arrow interaction
166233

pandas/tests/extension/test_interval.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818

1919
from pandas.core.dtypes.dtypes import IntervalDtype
2020

21-
from pandas import Interval
21+
from pandas import (
22+
Interval,
23+
Series,
24+
)
2225
from pandas.core.arrays import IntervalArray
2326
from pandas.tests.extension import base
2427

@@ -99,17 +102,31 @@ class TestInterface(BaseInterval, base.BaseInterfaceTests):
99102

100103

101104
class TestReduce(base.BaseNoReduceTests):
102-
pass
105+
@pytest.mark.parametrize("skipna", [True, False])
106+
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
107+
op_name = all_numeric_reductions
108+
ser = Series(data)
109+
110+
if op_name in ["min", "max"]:
111+
# IntervalArray *does* implement these
112+
assert getattr(ser, op_name)(skipna=skipna) in data
113+
assert getattr(data, op_name)(skipna=skipna) in data
114+
return
115+
116+
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
103117

104118

105119
class TestMethods(BaseInterval, base.BaseMethodsTests):
106120
@pytest.mark.skip(reason="addition is not defined for intervals")
107121
def test_combine_add(self, data_repeated):
108122
pass
109123

110-
@pytest.mark.skip(reason="Not Applicable")
124+
@pytest.mark.xfail(
125+
reason="Raises with incorrect message bc it disallows *all* listlikes "
126+
"instead of just wrong-length listlikes"
127+
)
111128
def test_fillna_length_mismatch(self, data_missing):
112-
pass
129+
super().test_fillna_length_mismatch(data_missing)
113130

114131

115132
class TestMissing(BaseInterval, base.BaseMissingTests):

pandas/tests/series/test_ufunc.py

-8
Original file line numberDiff line numberDiff line change
@@ -277,14 +277,6 @@ def test_reduce(values, box, request):
277277
if values.dtype.kind in ["i", "f"]:
278278
# ATM Index casts to object, so we get python ints/floats
279279
same_type = False
280-
elif isinstance(values, pd.IntervalIndex):
281-
mark = pytest.mark.xfail(reason="IntervalArray.min/max not implemented")
282-
request.node.add_marker(mark)
283-
284-
elif box is pd.Series or box is pd.DataFrame:
285-
if isinstance(values, pd.IntervalIndex):
286-
mark = pytest.mark.xfail(reason="IntervalArray.min/max not implemented")
287-
request.node.add_marker(mark)
288280

289281
if values.dtype == "i8" and box is pd.array:
290282
# FIXME: pd.array casts to Int64

0 commit comments

Comments
 (0)