diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 9168041a4f474..10c52c00b3c4d 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -801,6 +801,7 @@ Other - Fixed bug in metadata propagation incorrectly copying DataFrame columns as metadata when the column name overlaps with the metadata name (:issue:`37037`) - Fixed metadata propagation in the :class:`Series.dt`, :class:`Series.str` accessors, :class:`DataFrame.duplicated`, :class:`DataFrame.stack`, :class:`DataFrame.unstack`, :class:`DataFrame.pivot`, :class:`DataFrame.append`, :class:`DataFrame.diff`, :class:`DataFrame.applymap` and :class:`DataFrame.update` methods (:issue:`28283`, :issue:`37381`) - Fixed metadata propagation when selecting columns with ``DataFrame.__getitem__`` (:issue:`28283`) +- Bug in :meth:`Index.intersection` failing to set the correct name on the returned :class:`Index` when ``other`` is not an :class:`Index` (:issue:`38111`) - Bug in :meth:`Index.union` behaving differently depending on whether operand is an :class:`Index` or other list-like (:issue:`36384`) - Bug in :meth:`Index.intersection` with non-matching numeric dtypes casting to ``object`` dtype instead of minimal common dtype (:issue:`38122`) - Passing an array with 2 or more dimensions to the :class:`Series` constructor now raises the more specific ``ValueError`` rather than a bare ``Exception`` (:issue:`35744`) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 09fe885e47754..141b626d15d9d 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2821,7 +2821,7 @@ def intersection(self, other, sort=False): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other = ensure_index(other) + other, _ = self._convert_can_do_setop(other) if self.equals(other) and not self.has_duplicates: return self._get_reconciled_name_object(other) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py 
index 28ff5a8bacc71..9b8703f5c2fff 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -686,10 +686,17 @@ def intersection(self, other, sort=False): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) + other, _ = self._convert_can_do_setop(other) if self.equals(other): return self._get_reconciled_name_object(other) + return self._intersection(other, sort=sort) + + def _intersection(self, other: Index, sort=False) -> Index: + """ + intersection specialized to the case with matching dtypes. + """ if len(self) == 0: return self.copy()._get_reconciled_name_object(other) if len(other) == 0: @@ -704,10 +711,11 @@ def intersection(self, other, sort=False): return result elif not self._can_fast_intersect(other): - result = Index.intersection(self, other, sort=sort) - # We need to invalidate the freq because Index.intersection + result = Index._intersection(self, other, sort=sort) + # We need to invalidate the freq because Index._intersection # uses _shallow_copy on a view of self._data, which will preserve # self.freq if we're not careful. 
+ result = self._wrap_setop_result(other, result) return result._with_freq(None)._with_freq("infer") # to make our life easier, "sort" the two ranges diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index bd92926941aa1..dd9e16cb6cd5f 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -124,7 +124,11 @@ def setop_check(method): def wrapped(self, other, sort=False): self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other = ensure_index(other) + other, _ = self._convert_can_do_setop(other) + + if op_name == "intersection": + if self.equals(other): + return self._get_reconciled_name_object(other) if not isinstance(other, IntervalIndex): result = getattr(self.astype(object), op_name)(other) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 4aedf03ca1800..b9acb12890ecb 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3603,7 +3603,12 @@ def intersection(self, other, sort=False): if self.equals(other): if self.has_duplicates: return self.unique().rename(result_names) - return self.rename(result_names) + return self._get_reconciled_name_object(other) + + return self._intersection(other, sort=sort) + + def _intersection(self, other, sort=False): + other, result_names = self._convert_can_do_setop(other) if not is_object_dtype(other.dtype): # The intersection is empty @@ -3721,7 +3726,7 @@ def _convert_can_do_setop(self, other): else: msg = "other must be a MultiIndex or a list of tuples" try: - other = MultiIndex.from_tuples(other) + other = MultiIndex.from_tuples(other, names=self.names) except (ValueError, TypeError) as err: # ValueError raised by tupels_to_object_array if we # have non-object dtype diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index b223e583d0ce0..3f70582be267c 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -639,15 +639,19 @@ def _setop(self, other, 
sort, opname: str): def intersection(self, other, sort=False): self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other = ensure_index(other) + other, _ = self._convert_can_do_setop(other) if self.equals(other): return self._get_reconciled_name_object(other) - elif is_object_dtype(other.dtype): + return self._intersection(other, sort=sort) + + def _intersection(self, other, sort=False): + + if is_object_dtype(other.dtype): return self.astype("O").intersection(other, sort=sort) - elif not is_dtype_equal(self.dtype, other.dtype): + elif not self._is_comparable_dtype(other.dtype): # We can infer that the intersection is empty. # assert_can_do_setop ensures that this is not just a mismatched freq this = self[:0].astype("O") diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index 669bf115df104..6380551fc202c 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -15,6 +15,7 @@ from pandas.core.dtypes.common import ( ensure_platform_int, ensure_python_int, + is_dtype_equal, is_float, is_integer, is_list_like, @@ -504,11 +505,21 @@ def intersection(self, other, sort=False): intersection : Index """ self._validate_sort_keyword(sort) + self._assert_can_do_setop(other) + other, _ = self._convert_can_do_setop(other) if self.equals(other): return self._get_reconciled_name_object(other) + return self._intersection(other, sort=sort) + + def _intersection(self, other, sort=False): + if not isinstance(other, RangeIndex): + if is_dtype_equal(other.dtype, self.dtype): + # Int64Index + result = super()._intersection(other, sort=sort) + return self._wrap_setop_result(other, result) return super().intersection(other, sort=sort) if not len(self) or not len(other): diff --git a/pandas/tests/indexes/datetimes/test_setops.py b/pandas/tests/indexes/datetimes/test_setops.py index c8edd30e3f7aa..3b6d29a15e7dc 100644 --- a/pandas/tests/indexes/datetimes/test_setops.py +++ b/pandas/tests/indexes/datetimes/test_setops.py @@ 
-471,10 +471,11 @@ def test_intersection_bug(self): def test_intersection_list(self): # GH#35876 + # values is not an Index -> no name -> retain "a" values = [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")] idx = DatetimeIndex(values, name="a") res = idx.intersection(values) - tm.assert_index_equal(res, idx.rename(None)) + tm.assert_index_equal(res, idx) def test_month_range_union_tz_pytz(self, sort): from pytz import timezone diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 2675c4569a8e9..b6e793ba334ff 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -98,13 +98,20 @@ def test_compatible_inconsistent_pairs(idx_fact1, idx_fact2): ("Period[D]", "float64", "object"), ], ) -def test_union_dtypes(left, right, expected): +@pytest.mark.parametrize("names", [("foo", "foo", "foo"), ("foo", "bar", None)]) +def test_union_dtypes(left, right, expected, names): left = pandas_dtype(left) right = pandas_dtype(right) - a = pd.Index([], dtype=left) - b = pd.Index([], dtype=right) - result = a.union(b).dtype - assert result == expected + a = pd.Index([], dtype=left, name=names[0]) + b = pd.Index([], dtype=right, name=names[1]) + result = a.union(b) + assert result.dtype == expected + assert result.name == names[2] + + # Testing name retention + # TODO: pin down desired dtype; do we want it to be commutative? 
+ result = a.intersection(b) + assert result.name == names[2] def test_dunder_inplace_setops_deprecated(index): @@ -388,6 +395,25 @@ def test_intersect_unequal(self, index, fname, sname, expected_name): expected = index[1:].set_names(expected_name).sort_values() tm.assert_index_equal(intersect, expected) + def test_intersection_name_retention_with_nameless(self, index): + if isinstance(index, MultiIndex): + index = index.rename(list(range(index.nlevels))) + else: + index = index.rename("foo") + + other = np.asarray(index) + + result = index.intersection(other) + assert result.name == index.name + + # empty other, same dtype + result = index.intersection(other[:0]) + assert result.name == index.name + + # empty `self` + result = index[:0].intersection(other) + assert result.name == index.name + def test_difference_preserves_type_empty(self, index, sort): # GH#20040 # If taking difference of a set and itself, it