Skip to content

ENH: IntervalArray.min/max #44746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ Other enhancements
``USFederalHolidayCalendar``. See also `Other API changes`_.
- :meth:`.Rolling.var`, :meth:`.Expanding.var`, :meth:`.Rolling.std`, :meth:`.Expanding.std` now support `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`44461`)
- :meth:`Series.info` has been added, for compatibility with :meth:`DataFrame.info` (:issue:`5167`)
- Implemented :meth:`IntervalArray.min`, :meth:`IntervalArray.max`, as a result of which ``min`` and ``max`` now work for :class:`IntervalIndex`, :class:`Series` and :class:`DataFrame` with ``IntervalDtype`` (:issue:`44746`)
- :meth:`UInt64Index.map` now retains ``dtype`` where possible (:issue:`44609`)
-

Expand Down
30 changes: 30 additions & 0 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,36 @@ def argsort(
ascending=ascending, kind=kind, na_position=na_position, **kwargs
)

def min(self, *, axis: int | None = None, skipna: bool = True):
nv.validate_minmax_axis(axis, self.ndim)

if not len(self):
return self._na_value

mask = self.isna()
if mask.any():
if not skipna:
return self._na_value
return self[~mask].min()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do this instead?

Suggested change
return self[~mask].min()
self = self[~mask]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. follow-up OK?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure


indexer = self.argsort()[0]
return self[indexer]

def max(self, *, axis: int | None = None, skipna: bool = True):
    """
    Return the maximum element of the array.

    Parameters
    ----------
    axis : int or None, default None
        Checked against ``self.ndim`` by ``nv.validate_minmax_axis``;
        not otherwise used.
    skipna : bool, default True
        If True, missing entries are dropped before the maximum is taken.
        If False and any entry is missing, ``self._na_value`` is returned.

    Returns
    -------
    scalar
        The element that ``argsort`` places last among the non-missing
        entries, or ``self._na_value`` when there is nothing to compare
        (empty array, or every entry missing).

    Raises
    ------
    Whatever ``nv.validate_minmax_axis`` raises for an unsupported ``axis``.
    """
    nv.validate_minmax_axis(axis, self.ndim)

    candidates = self
    missing = self.isna()
    if missing.any():
        if not skipna:
            return self._na_value
        # Filter once rather than re-entering max(), which would repeat
        # the axis validation and the NA scan on the filtered array.
        candidates = self[~missing]

    if not len(candidates):
        # Empty input, or nothing left once missing entries are dropped.
        return self._na_value

    # The last position of the sort order is the largest element.
    last = candidates.argsort()[-1]
    return candidates[last]

def fillna(
self: IntervalArrayT, value=None, method=None, limit=None
) -> IntervalArrayT:
Expand Down
67 changes: 67 additions & 0 deletions pandas/tests/arrays/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def test_shift_datetime(self):
class TestSetitem:
def test_set_na(self, left_right_dtypes):
left, right = left_right_dtypes
left = left.copy(deep=True)
right = right.copy(deep=True)
result = IntervalArray.from_arrays(left, right)

if result.dtype.subtype.kind not in ["m", "M"]:
Expand Down Expand Up @@ -161,6 +163,71 @@ def test_repr():
assert result == expected


class TestReductions:
    """Tests for the ``min`` / ``max`` reductions on interval data."""

    def test_min_max_invalid_axis(self, left_right_dtypes):
        """``min`` and ``max`` raise for an out-of-range or non-integer axis."""
        left, right = left_right_dtypes
        # Work on copies so the fixture's objects are never shared with
        # the array under test.
        left = left.copy(deep=True)
        right = right.copy(deep=True)
        arr = IntervalArray.from_arrays(left, right)

        msg = "`axis` must be fewer than the number of dimensions"
        for axis in [-2, 1]:
            with pytest.raises(ValueError, match=msg):
                arr.min(axis=axis)
            with pytest.raises(ValueError, match=msg):
                arr.max(axis=axis)

        msg = "'>=' not supported between"
        with pytest.raises(TypeError, match=msg):
            arr.min(axis="foo")
        with pytest.raises(TypeError, match=msg):
            arr.max(axis="foo")

    def test_min_max(self, left_right_dtypes, index_or_series_or_array):
        """``min``/``max`` find the extreme intervals, with and without NA."""
        # GH#44746
        left, right = left_right_dtypes
        left = left.copy(deep=True)
        right = right.copy(deep=True)
        arr = IntervalArray.from_arrays(left, right)

        # The expected results below are only valid if monotonic
        assert left.is_monotonic_increasing
        assert Index(arr).is_monotonic_increasing

        expected_min = arr[0]
        expected_max = arr[-1]

        # Seeded permutation: the order is scrambled but reproducible, so
        # a failure can be re-run with the same arrangement.
        indexer = np.random.default_rng(2).permutation(len(arr))
        arr = arr.take(indexer)

        arr_na = arr.insert(2, np.nan)

        arr = index_or_series_or_array(arr)
        arr_na = index_or_series_or_array(arr_na)

        # Without missing values, skipna must not affect the result.
        for skipna in [True, False]:
            res = arr.min(skipna=skipna)
            assert res == expected_min
            assert type(res) is type(expected_min)

            res = arr.max(skipna=skipna)
            assert res == expected_max
            assert type(res) is type(expected_max)

        # With a missing value present and skipna=False the result is NA.
        res = arr_na.min(skipna=False)
        assert np.isnan(res)
        res = arr_na.max(skipna=False)
        assert np.isnan(res)

        # With skipna=True the missing value is ignored.
        res = arr_na.min(skipna=True)
        assert res == expected_min
        assert type(res) is type(expected_min)
        res = arr_na.max(skipna=True)
        assert res == expected_max
        assert type(res) is type(expected_max)


# ----------------------------------------------------------------------------
# Arrow interaction

Expand Down
25 changes: 21 additions & 4 deletions pandas/tests/extension/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

from pandas.core.dtypes.dtypes import IntervalDtype

from pandas import Interval
from pandas import (
Interval,
Series,
)
from pandas.core.arrays import IntervalArray
from pandas.tests.extension import base

Expand Down Expand Up @@ -99,17 +102,31 @@ class TestInterface(BaseInterval, base.BaseInterfaceTests):


class TestReduce(base.BaseNoReduceTests):
pass
@pytest.mark.parametrize("skipna", [True, False])
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
op_name = all_numeric_reductions
ser = Series(data)

if op_name in ["min", "max"]:
# IntervalArray *does* implement these
assert getattr(ser, op_name)(skipna=skipna) in data
assert getattr(data, op_name)(skipna=skipna) in data
return

super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)


class TestMethods(BaseInterval, base.BaseMethodsTests):
@pytest.mark.skip(reason="addition is not defined for intervals")
def test_combine_add(self, data_repeated):
pass

@pytest.mark.skip(reason="Not Applicable")
@pytest.mark.xfail(
reason="Raises with incorrect message bc it disallows *all* listlikes "
"instead of just wrong-length listlikes"
)
def test_fillna_length_mismatch(self, data_missing):
pass
super().test_fillna_length_mismatch(data_missing)


class TestMissing(BaseInterval, base.BaseMissingTests):
Expand Down
8 changes: 0 additions & 8 deletions pandas/tests/series/test_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,6 @@ def test_reduce(values, box, request):
if values.dtype.kind in ["i", "f"]:
# ATM Index casts to object, so we get python ints/floats
same_type = False
elif isinstance(values, pd.IntervalIndex):
mark = pytest.mark.xfail(reason="IntervalArray.min/max not implemented")
request.node.add_marker(mark)

elif box is pd.Series or box is pd.DataFrame:
if isinstance(values, pd.IntervalIndex):
mark = pytest.mark.xfail(reason="IntervalArray.min/max not implemented")
request.node.add_marker(mark)

if values.dtype == "i8" and box is pd.array:
# FIXME: pd.array casts to Int64
Expand Down