Skip to content

REF: share IntervalIndex.intersection with Index.intersection #38373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,3 +1171,15 @@ def __from_arrow__(
results.append(iarr)

return IntervalArray._concat_same_type(results)

def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
# NB: this doesn't handle checking for closed match
if not all(isinstance(x, IntervalDtype) for x in dtypes):
return np.dtype(object)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the wrong return value, see docstring in the base class:

def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
"""
Return the common dtype, if one exists.
Used in `find_common_type` implementation. This is for example used
to determine the resulting dtype in a concat operation.
If no common dtype exists, return None (which gives the other dtypes
the chance to determine a common dtype). If all dtypes in the list
return None, then the common dtype will be "object" dtype (this means
it is never needed to return "object" dtype from this method itself).
Parameters
----------
dtypes : list of dtypes
The dtypes for which to determine a common dtype. This is a list
of np.dtype or ExtensionDtype instances.
Returns
-------
Common dtype (np.dtype or ExtensionDtype) or None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, will address in follow-up


from pandas.core.dtypes.cast import find_common_type

common = find_common_type([cast("IntervalDtype", x).subtype for x in dtypes])
if common == object:
return np.dtype(object)
return IntervalDtype(common)
23 changes: 4 additions & 19 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,9 +810,11 @@ def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
return not is_object_dtype(common_subtype)

def _should_compare(self, other) -> bool:
if not super()._should_compare(other):
return False
other = unpack_nested_dtype(other)
if is_object_dtype(other.dtype):
return True
if not self._is_comparable_dtype(other.dtype):
return False
return other.closed == self.closed

# TODO: use should_compare and get rid of _is_non_comparable_own_type
Expand Down Expand Up @@ -972,23 +974,6 @@ def _assert_can_do_setop(self, other):
"and have compatible dtypes"
)

@Appender(Index.intersection.__doc__)
def intersection(self, other, sort=False) -> Index:
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
other, _ = self._convert_can_do_setop(other)

if self.equals(other):
if self.has_duplicates:
return self.unique()._get_reconciled_name_object(other)
return self._get_reconciled_name_object(other)

if not isinstance(other, IntervalIndex):
return self.astype(object).intersection(other)

result = self._intersection(other, sort=sort)
return self._wrap_setop_result(other, result)

def _intersection(self, other, sort):
"""
intersection specialized to the case with matching dtypes.
Expand Down
38 changes: 37 additions & 1 deletion pandas/tests/dtypes/cast/test_find_common_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import pytest

from pandas.core.dtypes.cast import find_common_type
from pandas.core.dtypes.dtypes import CategoricalDtype, DatetimeTZDtype, PeriodDtype
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
DatetimeTZDtype,
IntervalDtype,
PeriodDtype,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -120,3 +125,34 @@ def test_period_dtype_mismatch(dtype2):
dtype = PeriodDtype(freq="D")
assert find_common_type([dtype, dtype2]) == object
assert find_common_type([dtype2, dtype]) == object


interval_dtypes = [
IntervalDtype(np.int64),
IntervalDtype(np.float64),
IntervalDtype(np.uint64),
IntervalDtype(DatetimeTZDtype(unit="ns", tz="US/Eastern")),
IntervalDtype("M8[ns]"),
IntervalDtype("m8[ns]"),
]


@pytest.mark.parametrize("left", interval_dtypes)
@pytest.mark.parametrize("right", interval_dtypes)
def test_interval_dtype(left, right):
result = find_common_type([left, right])

if left is right:
assert result is left

elif left.subtype.kind in ["i", "u", "f"]:
# i.e. numeric
if right.subtype.kind in ["i", "u", "f"]:
# both numeric -> common numeric subtype
expected = IntervalDtype(np.float64)
assert result == expected
else:
assert result == object

else:
assert result == object
30 changes: 19 additions & 11 deletions pandas/tests/indexes/interval/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,6 @@ def test_intersection(self, closed, sort):

tm.assert_index_equal(index.intersection(index, sort=sort), index)

# GH 19101: empty result, same dtype
other = monotonic_index(300, 314, closed=closed)
expected = empty_index(dtype="int64", closed=closed)
result = index.intersection(other, sort=sort)
tm.assert_index_equal(result, expected)

# GH 19101: empty result, different dtypes
other = monotonic_index(300, 314, dtype="float64", closed=closed)
result = index.intersection(other, sort=sort)
tm.assert_index_equal(result, expected)

# GH 26225: nested intervals
index = IntervalIndex.from_tuples([(1, 2), (1, 3), (1, 4), (0, 2)])
other = IntervalIndex.from_tuples([(1, 2), (1, 3)])
Expand Down Expand Up @@ -100,6 +89,25 @@ def test_intersection(self, closed, sort):
result = index.intersection(other)
tm.assert_index_equal(result, expected)

def test_intersection_empty_result(self, closed, sort):
index = monotonic_index(0, 11, closed=closed)

# GH 19101: empty result, same dtype
other = monotonic_index(300, 314, closed=closed)
expected = empty_index(dtype="int64", closed=closed)
result = index.intersection(other, sort=sort)
tm.assert_index_equal(result, expected)

# GH 19101: empty result, different numeric dtypes -> common dtype is float64
other = monotonic_index(300, 314, dtype="float64", closed=closed)
result = index.intersection(other, sort=sort)
expected = other[:0]
tm.assert_index_equal(result, expected)

other = monotonic_index(300, 314, dtype="uint64", closed=closed)
result = index.intersection(other, sort=sort)
tm.assert_index_equal(result, expected)

def test_difference(self, closed, sort):
index = IntervalIndex.from_arrays([1, 0, 3, 2], [1, 2, 3, 4], closed=closed)
result = index.difference(index[:1], sort=sort)
Expand Down