From 7798fdc13144211eb5947787304231157e29b9b8 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 8 Dec 2020 14:21:28 -0800 Subject: [PATCH 1/5] REF: share IntervalIndex.intersection with Index.intersection --- pandas/core/dtypes/dtypes.py | 12 ++++++ pandas/core/indexes/interval.py | 23 ++--------- .../dtypes/cast/test_find_common_type.py | 38 ++++++++++++++++++- pandas/tests/indexes/interval/test_setops.py | 30 +++++++++------ 4 files changed, 72 insertions(+), 31 deletions(-) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 3c5421ae433b6..4544708120f9e 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -1171,3 +1171,15 @@ def __from_arrow__( results.append(iarr) return IntervalArray._concat_same_type(results) + + def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]: + # NB: this doesn't handle checking for closed match + if not all(isinstance(x, IntervalDtype) for x in dtypes): + return np.dtype(object) + + from pandas.core.dtypes.cast import find_common_type + + common = find_common_type([x.subtype for x in dtypes]) + if common == object: + return np.dtype(object) + return IntervalDtype(common) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index ee25a9d81a60f..c79dabe9eb1de 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -810,9 +810,11 @@ def _is_comparable_dtype(self, dtype: DtypeObj) -> bool: return not is_object_dtype(common_subtype) def _should_compare(self, other) -> bool: - if not super()._should_compare(other): - return False other = unpack_nested_dtype(other) + if is_object_dtype(other.dtype): + return True + if not self._is_comparable_dtype(other.dtype): + return False return other.closed == self.closed # TODO: use should_compare and get rid of _is_non_comparable_own_type @@ -972,23 +974,6 @@ def _assert_can_do_setop(self, other): "and have compatible dtypes" ) - @Appender(Index.intersection.__doc__) - def 
intersection(self, other, sort=False) -> Index: - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, _ = self._convert_can_do_setop(other) - - if self.equals(other): - if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) - - if not isinstance(other, IntervalIndex): - return self.astype(object).intersection(other) - - result = self._intersection(other, sort=sort) - return self._wrap_setop_result(other, result) - def _intersection(self, other, sort): """ intersection specialized to the case with matching dtypes. diff --git a/pandas/tests/dtypes/cast/test_find_common_type.py b/pandas/tests/dtypes/cast/test_find_common_type.py index 8dac92f469703..7b1aa12dc0cc4 100644 --- a/pandas/tests/dtypes/cast/test_find_common_type.py +++ b/pandas/tests/dtypes/cast/test_find_common_type.py @@ -2,7 +2,12 @@ import pytest from pandas.core.dtypes.cast import find_common_type -from pandas.core.dtypes.dtypes import CategoricalDtype, DatetimeTZDtype, PeriodDtype +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + DatetimeTZDtype, + IntervalDtype, + PeriodDtype, +) @pytest.mark.parametrize( @@ -120,3 +125,34 @@ def test_period_dtype_mismatch(dtype2): dtype = PeriodDtype(freq="D") assert find_common_type([dtype, dtype2]) == object assert find_common_type([dtype2, dtype]) == object + + +interval_dtypes = [ + IntervalDtype(np.int64), + IntervalDtype(np.float64), + IntervalDtype(np.uint64), + IntervalDtype(DatetimeTZDtype(unit="ns", tz="US/Eastern")), + IntervalDtype("M8[ns]"), + IntervalDtype("m8[ns]"), +] + + +@pytest.mark.parametrize("left", interval_dtypes) +@pytest.mark.parametrize("right", interval_dtypes) +def test_interval_dtype(left, right): + result = find_common_type([left, right]) + + if left is right: + assert result is left + + elif left.subtype.kind in ["i", "u", "f"]: + # i.e. 
numeric + if right.subtype.kind in ["i", "u", "f"]: + # both numeric -> common numeric subtype + expected = IntervalDtype(np.float64) + assert result == expected + else: + assert result == object + + else: + assert result == object diff --git a/pandas/tests/indexes/interval/test_setops.py b/pandas/tests/indexes/interval/test_setops.py index 0ef833bb93ded..278c27b302ef9 100644 --- a/pandas/tests/indexes/interval/test_setops.py +++ b/pandas/tests/indexes/interval/test_setops.py @@ -61,17 +61,6 @@ def test_intersection(self, closed, sort): tm.assert_index_equal(index.intersection(index, sort=sort), index) - # GH 19101: empty result, same dtype - other = monotonic_index(300, 314, closed=closed) - expected = empty_index(dtype="int64", closed=closed) - result = index.intersection(other, sort=sort) - tm.assert_index_equal(result, expected) - - # GH 19101: empty result, different dtypes - other = monotonic_index(300, 314, dtype="float64", closed=closed) - result = index.intersection(other, sort=sort) - tm.assert_index_equal(result, expected) - # GH 26225: nested intervals index = IntervalIndex.from_tuples([(1, 2), (1, 3), (1, 4), (0, 2)]) other = IntervalIndex.from_tuples([(1, 2), (1, 3)]) @@ -100,6 +89,25 @@ def test_intersection(self, closed, sort): result = index.intersection(other) tm.assert_index_equal(result, expected) + def test_intersection_empty_result(self, closed, sort): + index = monotonic_index(0, 11, closed=closed) + + # GH 19101: empty result, same dtype + other = monotonic_index(300, 314, closed=closed) + expected = empty_index(dtype="int64", closed=closed) + result = index.intersection(other, sort=sort) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, different numeric dtypes -> common dtype is float64 + other = monotonic_index(300, 314, dtype="float64", closed=closed) + result = index.intersection(other, sort=sort) + expected = other[:0] + tm.assert_index_equal(result, expected) + + other = monotonic_index(300, 314, dtype="uint64", 
closed=closed) + result = index.intersection(other, sort=sort) + tm.assert_index_equal(result, expected) + def test_difference(self, closed, sort): index = IntervalIndex.from_arrays([1, 0, 3, 2], [1, 2, 3, 4], closed=closed) result = index.difference(index[:1], sort=sort) From 14147977a0f1e7cc0fa5bc5dee4dd4df2861b78b Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 8 Dec 2020 16:40:34 -0800 Subject: [PATCH 2/5] mypy fixup --- pandas/core/dtypes/dtypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 4544708120f9e..dd4c7c1eedb49 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -1179,7 +1179,7 @@ def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]: from pandas.core.dtypes.cast import find_common_type - common = find_common_type([x.subtype for x in dtypes]) + common = find_common_type([cast("IntervalDtype", x).subtype for x in dtypes]) if common == object: return np.dtype(object) return IntervalDtype(common) From d9e7c1e4151ea6a29a6ce83ea90bf56c93907bd1 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 8 Dec 2020 17:43:14 -0800 Subject: [PATCH 3/5] REF: use find_common_type in Index.union --- pandas/core/indexes/base.py | 61 +++++++++-------------------- pandas/core/indexes/datetimelike.py | 3 -- pandas/core/indexes/numeric.py | 25 ------------ pandas/core/indexes/timedeltas.py | 2 +- 4 files changed, 19 insertions(+), 72 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 11b7acc0a9deb..ac9f57fa16e63 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2592,47 +2592,6 @@ def _get_reconciled_name_object(self, other): return self.rename(name) return self - @final - def _union_incompatible_dtypes(self, other, sort): - """ - Casts this and other index to object dtype to allow the formation - of a union between incompatible types. 
- - Parameters - ---------- - other : Index or array-like - sort : False or None, default False - Whether to sort the resulting index. - - * False : do not sort the result. - * None : sort the result, except when `self` and `other` are equal - or when the values cannot be compared. - - Returns - ------- - Index - """ - this = self.astype(object, copy=False) - # cast to Index for when `other` is list-like - other = Index(other).astype(object, copy=False) - return Index.union(this, other, sort=sort).astype(object, copy=False) - - def _can_union_without_object_cast(self, other) -> bool: - """ - Check whether this and the other dtype are compatible with each other. - Meaning a union can be formed between them without needing to be cast - to dtype object. - - Parameters - ---------- - other : Index or array-like - - Returns - ------- - bool - """ - return type(self) is type(other) and is_dtype_equal(self.dtype, other.dtype) - @final def _validate_sort_keyword(self, sort): if sort not in [None, False]: @@ -2696,8 +2655,24 @@ def union(self, other, sort=None): self._assert_can_do_setop(other) other, result_name = self._convert_can_do_setop(other) - if not self._can_union_without_object_cast(other): - return self._union_incompatible_dtypes(other, sort=sort) + if not is_dtype_equal(self.dtype, other.dtype): + dtype = find_common_type([self.dtype, other.dtype]) + if self._is_numeric_dtype and other._is_numeric_dtype: + # Right now, we treat union(int, float) a bit special. + # See https://github.com/pandas-dev/pandas/issues/26778 for discussion + # We may change union(int, float) to go to object. 
+ # float | [u]int -> float (the special case) + # <T> | <T> -> T + # <T> | <U> -> object + if not (is_integer_dtype(self.dtype) and is_integer_dtype(other.dtype)): + dtype = "float64" + else: + # one is int64 other is uint64 + dtype = object + + left = self.astype(dtype, copy=False) + right = other.astype(dtype, copy=False) + return left.union(right, sort=sort) result = self._union(other, sort=sort) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index f0d4d36531e0d..220cd5363e78f 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -597,9 +597,6 @@ def insert(self, loc: int, item): # -------------------------------------------------------------------- # Join/Set Methods - def _can_union_without_object_cast(self, other) -> bool: - return is_dtype_equal(self.dtype, other.dtype) - def _get_join_freq(self, other): """ Get the freq to attach to the result of a join operation. diff --git a/pandas/core/indexes/numeric.py b/pandas/core/indexes/numeric.py index ed76e26a57634..a31e7de9ec768 100644 --- a/pandas/core/indexes/numeric.py +++ b/pandas/core/indexes/numeric.py @@ -177,23 +177,6 @@ def insert(self, loc: int, item): return super().insert(loc, item) - def _union(self, other, sort): - # Right now, we treat union(int, float) a bit special. - # See https://github.com/pandas-dev/pandas/issues/26778 for discussion - # We may change union(int, float) to go to object. 
- # float | [u]int -> float (the special case) - # <T> | <T> -> T - # <T> | <U> -> object - needs_cast = (is_integer_dtype(self.dtype) and is_float_dtype(other.dtype)) or ( - is_integer_dtype(other.dtype) and is_float_dtype(self.dtype) - ) - if needs_cast: - first = self.astype("float") - second = other.astype("float") - return first._union(second, sort) - else: - return super()._union(other, sort) - _num_index_shared_docs[ "class_descr" @@ -253,10 +236,6 @@ def _assert_safe_casting(cls, data, subarr): if not np.array_equal(data, subarr): raise TypeError("Unsafe NumPy casting, you must explicitly cast") - def _can_union_without_object_cast(self, other) -> bool: - # See GH#26778, further casting may occur in NumericIndex._union - return other.dtype == "f8" or other.dtype == self.dtype - def __contains__(self, key) -> bool: """ Check if key is a float and has a decimal. If it has, return False. @@ -417,7 +396,3 @@ def __contains__(self, other: Any) -> bool: return True return is_float(other) and np.isnan(other) and self.hasnans - - def _can_union_without_object_cast(self, other) -> bool: - # See GH#26778, further casting may occur in NumericIndex._union - return is_numeric_dtype(other.dtype) diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index fcab3e1f6a0a4..24cc2ae0c4d48 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -108,7 +108,7 @@ class TimedeltaIndex(DatetimeTimedeltaMixin): _comparables = ["name", "freq"] _attributes = ["name", "freq"] - _is_numeric_dtype = True + _is_numeric_dtype = False _data: TimedeltaArray From bed4db189190c7aae515274b6dd756a029ca46c5 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 8 Dec 2020 18:48:50 -0800 Subject: [PATCH 4/5] TST: update IntervalIndex.union tests --- pandas/tests/indexes/interval/test_setops.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pandas/tests/indexes/interval/test_setops.py b/pandas/tests/indexes/interval/test_setops.py 
index 278c27b302ef9..13f50ae4bab7a 100644 --- a/pandas/tests/indexes/interval/test_setops.py +++ b/pandas/tests/indexes/interval/test_setops.py @@ -38,10 +38,14 @@ def test_union_empty_result(self, closed, sort): result = index.union(index, sort=sort) tm.assert_index_equal(result, index) - # GH 19101: empty result, different dtypes -> common dtype is object + # GH 19101: empty result, different numeric dtypes -> common dtype is f8 other = empty_index(dtype="float64", closed=closed) result = index.union(other, sort=sort) - expected = Index([], dtype=object) + expected = other + tm.assert_index_equal(result, expected) + + other = empty_index(dtype="uint64", closed=closed) + result = index.union(other, sort=sort) tm.assert_index_equal(result, expected) def test_intersection(self, closed, sort): From 12bb20265ce670ba8b9a16fa406f66d32465f6ee Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 13 Dec 2020 21:33:38 -0800 Subject: [PATCH 5/5] test reversed ops --- pandas/tests/indexes/interval/test_setops.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pandas/tests/indexes/interval/test_setops.py b/pandas/tests/indexes/interval/test_setops.py index 13f50ae4bab7a..7bfe81e0645cb 100644 --- a/pandas/tests/indexes/interval/test_setops.py +++ b/pandas/tests/indexes/interval/test_setops.py @@ -44,10 +44,16 @@ def test_union_empty_result(self, closed, sort): expected = other tm.assert_index_equal(result, expected) + result = other.union(index, sort=sort) + tm.assert_index_equal(result, expected) + other = empty_index(dtype="uint64", closed=closed) result = index.union(other, sort=sort) tm.assert_index_equal(result, expected) + result = other.union(index, sort=sort) + tm.assert_index_equal(result, expected) + def test_intersection(self, closed, sort): index = monotonic_index(0, 11, closed=closed) other = monotonic_index(5, 13, closed=closed)