From 7098f1017d0a1bff683bcaefd82e832382ad7bdc Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Nov 2020 18:17:49 -0800 Subject: [PATCH 1/7] BUG: name retention in Index.intersection corner cases --- pandas/core/indexes/base.py | 2 +- pandas/core/indexes/datetimelike.py | 20 +++++++++++++------- pandas/core/indexes/interval.py | 6 +++++- pandas/core/indexes/multi.py | 13 +++++++++++-- pandas/core/indexes/period.py | 11 ++++++++--- pandas/core/indexes/range.py | 9 ++++++++- pandas/tests/indexes/test_setops.py | 14 ++++++++++++++ 7 files changed, 60 insertions(+), 15 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index c49f3f9457161..24ce82a23bddf 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2820,7 +2820,7 @@ def intersection(self, other, sort=False): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other = ensure_index(other) + other, _ = self._convert_can_do_setop(other) if self.equals(other): return self._get_reconciled_name_object(other) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 1b18f04ba603d..666c39605ed3c 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -658,6 +658,16 @@ def difference(self, other, sort=None): return new_idx def intersection(self, other, sort=False): + self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + other, _ = self._convert_can_do_setop(other) + + if self.equals(other): + return self._get_reconciled_name_object(other) + + return self._intersection(other, sort=sort) + + def _intersection(self, other, sort=False): """ Specialized intersection for DatetimeIndex/TimedeltaIndex. @@ -684,11 +694,6 @@ def intersection(self, other, sort=False): ------- y : Index or same type as self """ - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - - if self.equals(other): - return self._get_reconciled_name_object(other) if len(self) == 0: return self.copy()._get_reconciled_name_object(other) @@ -704,10 +709,11 @@ def intersection(self, other, sort=False): return result elif not self._can_fast_intersect(other): - result = Index.intersection(self, other, sort=sort) - # We need to invalidate the freq because Index.intersection + result = Index._intersection(self, other, sort=sort) + # We need to invalidate the freq because Index._intersection # uses _shallow_copy on a view of self._data, which will preserve # self.freq if we're not careful. + result = self._wrap_setop_result(other, result) return result._with_freq(None)._with_freq("infer") # to make our life easier, "sort" the two ranges diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index ed92b3dade6a0..65c841a63db91 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -122,8 +122,12 @@ def setop_check(method): @wraps(method) def wrapped(self, other, sort=False): + self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other = ensure_index(other) + other, _ = self._convert_can_do_setop(other) + + if op_name == "intersection" and self.equals(other): + return self._get_reconciled_name_object(other) if not isinstance(other, IntervalIndex): result = getattr(self.astype(object), op_name)(other) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 9b4b459d9a122..4ff832abc60ad 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3576,6 +3576,17 @@ def union(self, other, sort=None): ) def intersection(self, other, sort=False): + self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + + other, result_names = self._convert_can_do_setop(other) + + if self.equals(other): + return self._get_reconciled_name_object(other) + + return self._intersection(other, sort=sort).rename(result_names) + + def _intersection(self, other, sort=False): """ Form the intersection of two MultiIndex objects. @@ -3596,8 +3607,6 @@ def intersection(self, other, sort=False): ------- Index """ - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) other, result_names = self._convert_can_do_setop(other) if self.equals(other): diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 5dff07ee4c6dd..b9729843e1d60 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -625,9 +625,10 @@ def _setop(self, other, sort, opname: str): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - res_name = get_op_result_name(self, other) other = ensure_index(other) + res_name = get_op_result_name(self, other) + i8self = Int64Index._simple_new(self.asi8) i8other = Int64Index._simple_new(other.asi8) i8result = getattr(i8self, opname)(i8other, sort=sort) @@ -639,12 +640,16 @@ def _setop(self, other, sort, opname: str): def intersection(self, other, sort=False): self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other = ensure_index(other) + other, _ = self._convert_can_do_setop(other) if self.equals(other): return self._get_reconciled_name_object(other) - elif is_object_dtype(other.dtype): + return self._intersection(other, sort=sort) + + def _intersection(self, other, sort=False): + + if is_object_dtype(other.dtype): return self.astype("O").intersection(other, sort=sort) elif not is_dtype_equal(self.dtype, other.dtype): diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index 669bf115df104..9e2ad9cc25ac2 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -504,12 +504,19 @@ def intersection(self, other, sort=False): intersection : Index """ self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + other, _ = self._convert_can_do_setop(other) if self.equals(other): return self._get_reconciled_name_object(other) + return self._intersection(other, sort=sort) + + def _intersection(self, other, sort=False): + if not isinstance(other, RangeIndex): - return super().intersection(other, sort=sort) + result = super()._intersection(other, sort=sort) + return self._wrap_setop_result(other, result) if not len(self) or not len(other): return self._simple_new(_empty_range) diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 0973cef7cfdc1..54af607696226 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -378,6 +378,20 @@ def test_intersect_unequal(self, index, fname, sname, expected_name): expected = index[1:].set_names(expected_name).sort_values() tm.assert_index_equal(intersect, expected) + def test_intersection_name_retention_with_nameless(self, index): + if isinstance(index, MultiIndex): + index = index.rename(list(range(index.nlevels))) + else: + index = index.rename("foo") + + other = np.asarray(index) + + result = index.intersection(other) + assert result.name == index.name + + result = index.intersection(other[:0]) + assert result.name == index.name + def test_difference_preserves_type_empty(self, index, sort): # GH#20040 # If taking difference of a set and itself, it From 27704b02994d3e390226b36eacfb96cc04aa75a3 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Nov 2020 18:59:53 -0800 Subject: [PATCH 2/7] cleanup --- pandas/core/indexes/datetimelike.py | 20 +++++++++--------- pandas/core/indexes/interval.py | 5 +++-- pandas/core/indexes/multi.py | 21 ++++++++----------- pandas/core/indexes/period.py | 2 +- pandas/tests/indexes/datetimes/test_setops.py | 3 ++- 5 files changed, 25 insertions(+), 26 deletions(-) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 666c39605ed3c..45980a23cf58a 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -658,16 +658,6 @@ def difference(self, other, sort=None): return new_idx def intersection(self, other, sort=False): - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, _ = self._convert_can_do_setop(other) - - if self.equals(other): - return self._get_reconciled_name_object(other) - - return self._intersection(other, sort=sort) - - def _intersection(self, other, sort=False): """ Specialized intersection for DatetimeIndex/TimedeltaIndex. @@ -694,6 +684,16 @@ def _intersection(self, other, sort=False): ------- y : Index or same type as self """ + self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + other, _ = self._convert_can_do_setop(other) + + if self.equals(other): + return self._get_reconciled_name_object(other) + + return self._intersection(other, sort=sort) + + def _intersection(self, other, sort=False): if len(self) == 0: return self.copy()._get_reconciled_name_object(other) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 65c841a63db91..86970563e3e5c 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -126,8 +126,9 @@ def wrapped(self, other, sort=False): self._assert_can_do_setop(other) other, _ = self._convert_can_do_setop(other) - if op_name == "intersection" and self.equals(other): - return self._get_reconciled_name_object(other) + if op_name == "intersection": + if self.equals(other): + return self._get_reconciled_name_object(other) if not isinstance(other, IntervalIndex): result = getattr(self.astype(object), op_name)(other) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 4ff832abc60ad..ed65aacbd186a 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3576,17 +3576,6 @@ def union(self, other, sort=None): ) def intersection(self, other, sort=False): - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - - other, result_names = self._convert_can_do_setop(other) - - if self.equals(other): - return self._get_reconciled_name_object(other) - - return self._intersection(other, sort=sort).rename(result_names) - - def _intersection(self, other, sort=False): """ Form the intersection of two MultiIndex objects. @@ -3607,10 +3596,18 @@ def _intersection(self, other, sort=False): ------- Index """ + self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + other, result_names = self._convert_can_do_setop(other) if self.equals(other): - return self.rename(result_names) + return self._get_reconciled_name_object(other) + + return self._intersection(other, sort=sort).rename(result_names) + + def _intersection(self, other, sort=False): + other, result_names = self._convert_can_do_setop(other) if not is_object_dtype(other.dtype): # The intersection is empty diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index b9729843e1d60..cbf3e68fe7354 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -652,7 +652,7 @@ def _intersection(self, other, sort=False): if is_object_dtype(other.dtype): return self.astype("O").intersection(other, sort=sort) - elif not is_dtype_equal(self.dtype, other.dtype): + elif not self._is_comparable_dtype(other.dtype): # We can infer that the intersection is empty. # assert_can_do_setop ensures that this is not just a mismatched freq this = self[:0].astype("O") diff --git a/pandas/tests/indexes/datetimes/test_setops.py b/pandas/tests/indexes/datetimes/test_setops.py index c8edd30e3f7aa..3b6d29a15e7dc 100644 --- a/pandas/tests/indexes/datetimes/test_setops.py +++ b/pandas/tests/indexes/datetimes/test_setops.py @@ -471,10 +471,11 @@ def test_intersection_bug(self): def test_intersection_list(self): # GH#35876 + # values is not an Index -> no name -> retain "a" values = [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")] idx = DatetimeIndex(values, name="a") res = idx.intersection(values) - tm.assert_index_equal(res, idx.rename(None)) + tm.assert_index_equal(res, idx) def test_month_range_union_tz_pytz(self, sort): from pytz import timezone From 5ae7748099f279b34a5d7cdc9c33372f36a52a6e Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 27 Nov 2020 07:29:09 -0800 Subject: [PATCH 3/7] standardize, tests --- pandas/core/indexes/multi.py | 8 ++++---- pandas/tests/indexes/test_setops.py | 22 +++++++++++++++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index ed65aacbd186a..19cc75cb8683f 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3598,13 +3598,12 @@ def intersection(self, other, sort=False): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - - other, result_names = self._convert_can_do_setop(other) + other, _ = self._convert_can_do_setop(other) if self.equals(other): return self._get_reconciled_name_object(other) - return self._intersection(other, sort=sort).rename(result_names) + return self._intersection(other, sort=sort) def _intersection(self, other, sort=False): other, result_names = self._convert_can_do_setop(other) @@ -3723,11 +3722,12 @@ def _convert_can_do_setop(self, other): levels=[[]] * self.nlevels, codes=[[]] * self.nlevels, verify_integrity=False, + names=self.names, ) else: msg = "other must be a MultiIndex or a list of tuples" try: - other = MultiIndex.from_tuples(other) + other = MultiIndex.from_tuples(other, names=self.names) except TypeError as err: raise TypeError(msg) from err else: diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 54af607696226..f3c1d1fbdd69a 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -98,13 +98,20 @@ def test_compatible_inconsistent_pairs(idx_fact1, idx_fact2): ("Period[D]", "float64", "object"), ], ) -def test_union_dtypes(left, right, expected): +@pytest.mark.parametrize("names", [("foo", "foo", "foo"), ("foo", "bar", None)]) +def test_union_dtypes(left, right, expected, names): left = pandas_dtype(left) right = pandas_dtype(right) - a = pd.Index([], dtype=left) - b = pd.Index([], dtype=right) - result = a.union(b).dtype - assert result == expected + a = pd.Index([], dtype=left, name=names[0]) + b = pd.Index([], dtype=right, name=names[1]) + result = a.union(b) + assert result.dtype == expected + assert result.name == names[2] + + # Testing name retention + # TODO: pin down desired dtype; do we want it to be commutative? + result = a.intersection(b) + assert result.name == names[2] def test_dunder_inplace_setops_deprecated(index): @@ -389,9 +396,14 @@ def test_intersection_name_retention_with_nameless(self, index): result = index.intersection(other) assert result.name == index.name + # empty other, same dtype result = index.intersection(other[:0]) assert result.name == index.name + # empty `self` + result = index[:0].intersection(other) + assert result.name == index.name + def test_difference_preserves_type_empty(self, index, sort): # GH#20040 # If taking difference of a set and itself, it From 453b1b6727252f03d1ea6fefcead0aab11a26329 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 27 Nov 2020 07:52:15 -0800 Subject: [PATCH 4/7] revert accidental --- pandas/core/indexes/period.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index cbf3e68fe7354..437f54d21be8a 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -625,9 +625,8 @@ def _setop(self, other, sort, opname: str): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other = ensure_index(other) - res_name = get_op_result_name(self, other) + other = ensure_index(other) i8self = Int64Index._simple_new(self.asi8) i8other = Int64Index._simple_new(other.asi8) From 4ae4637fc943ec67963e14747524b2772fe7c1c9 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 28 Nov 2020 10:40:40 -0800 Subject: [PATCH 5/7] whatsnew --- doc/source/whatsnew/v1.2.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 6aff4f4bd41e2..4770607d85181 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -772,6 +772,7 @@ Other - Fixed bug in metadata propagation incorrectly copying DataFrame columns as metadata when the column name overlaps with the metadata name (:issue:`37037`) - Fixed metadata propagation in the :class:`Series.dt`, :class:`Series.str` accessors, :class:`DataFrame.duplicated`, :class:`DataFrame.stack`, :class:`DataFrame.unstack`, :class:`DataFrame.pivot`, :class:`DataFrame.append`, :class:`DataFrame.diff`, :class:`DataFrame.applymap` and :class:`DataFrame.update` methods (:issue:`28283`, :issue:`37381`) - Fixed metadata propagation when selecting columns with ``DataFrame.__getitem__`` (:issue:`28283`) +- Bug in :meth:`Index.intersection` with non-:class:`Index` failing to set the correct name on the returned :class:`Index` (:issue:`38111`) - Bug in :meth:`Index.union` behaving differently depending on whether operand is an :class:`Index` or other list-like (:issue:`36384`) - Passing an array with 2 or more dimensions to the :class:`Series` constructor now raises the more specific ``ValueError`` rather than a bare ``Exception`` (:issue:`35744`) - Bug in ``dir`` where ``dir(obj)`` wouldn't show attributes defined on the instance for pandas objects (:issue:`37173`) From dc8ae4a01c59ba79b013eb0118657866c6adb596 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 29 Nov 2020 11:00:07 -0800 Subject: [PATCH 6/7] Annotate+docstring --- pandas/core/indexes/datetimelike.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 6907aab312233..9b8703f5c2fff 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -693,8 +693,10 @@ def intersection(self, other, sort=False): return self._intersection(other, sort=sort) - def _intersection(self, other, sort=False): - + def _intersection(self, other: Index, sort=False) -> Index: + """ + intersection specialized to the case with matching dtypes. + """ if len(self) == 0: return self.copy()._get_reconciled_name_object(other) if len(other) == 0: From 71dbc0c2ff6c47a29a335baed4320c21fa82b9c6 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 30 Nov 2020 09:55:51 -0800 Subject: [PATCH 7/7] Fix RangeIndex failure --- pandas/core/indexes/range.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index 9e2ad9cc25ac2..6380551fc202c 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -15,6 +15,7 @@ from pandas.core.dtypes.common import ( ensure_platform_int, ensure_python_int, + is_dtype_equal, is_float, is_integer, is_list_like, @@ -515,8 +516,11 @@ def intersection(self, other, sort=False): def _intersection(self, other, sort=False): if not isinstance(other, RangeIndex): - result = super()._intersection(other, sort=sort) - return self._wrap_setop_result(other, result) + if is_dtype_equal(other.dtype, self.dtype): + # Int64Index + result = super()._intersection(other, sort=sort) + return self._wrap_setop_result(other, result) + return super().intersection(other, sort=sort) if not len(self) or not len(other): return self._simple_new(_empty_range)