diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index e07a8fa0469f4..283345d554a51 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -210,7 +210,7 @@ Other ^^^^^ - Appending a dictionary to a :class:`DataFrame` without passing ``ignore_index=True`` will raise ``TypeError: Can only append a dict if ignore_index=True`` instead of ``TypeError: Can only append a Series if ignore_index=True or if the Series has a name`` (:issue:`30871`) -- +- Set operations on an object-dtype :class:`Index` now always return object-dtype results (:issue:`31401`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 22ba317e78e63..729e49e7ea76a 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -522,6 +522,7 @@ def _shallow_copy(self, values=None, **kwargs): values = self.values attributes = self._get_attributes_dict() + attributes.update(kwargs) return self._simple_new(values, **attributes) @@ -2566,6 +2567,7 @@ def _union(self, other, sort): # worth making this faster? a very unusual case value_set = set(lvals) result.extend([x for x in rvals if x not in value_set]) + result = Index(result)._values # do type inference here else: # find indexes of things in "other" that are not in "self" if self.is_unique: @@ -2595,7 +2597,8 @@ def _union(self, other, sort): return self._wrap_setop_result(other, result) def _wrap_setop_result(self, other, result): - return self._constructor(result, name=get_op_result_name(self, other)) + name = get_op_result_name(self, other) + return self._shallow_copy(result, name=name) # TODO: standardize return type of non-union setops type(self vs other) def intersection(self, other, sort=False): @@ -2652,9 +2655,10 @@ def intersection(self, other, sort=False): if self.is_monotonic and other.is_monotonic: try: result = self._inner_indexer(lvals, rvals)[0] - return self._wrap_setop_result(other, result) except TypeError: pass + else: + return self._wrap_setop_result(other, result) try: indexer = Index(rvals).get_indexer(lvals) diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index d556c014467cf..3a0ca4e76b8a6 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -29,7 +29,6 @@ from pandas.core.indexes.base import Index, _index_shared_docs, maybe_extract_name from pandas.core.indexes.extension import ExtensionIndex, inherit_names import pandas.core.missing as missing -from pandas.core.ops import get_op_result_name _index_doc_kwargs = dict(ibase._index_doc_kwargs) _index_doc_kwargs.update(dict(target_klass="CategoricalIndex")) @@ -386,12 +385,6 @@ def _has_complex_internals(self) -> bool: # used to avoid libreduction code paths, which raise or require conversion return True - def _wrap_setop_result(self, other, result): - name = get_op_result_name(self, other) - # We use _shallow_copy rather than the Index implementation - # (which uses _constructor) in order to preserve dtype. - return self._shallow_copy(result, name=name) - @Appender(Index.__contains__.__doc__) def __contains__(self, key: Any) -> bool: # if key is a NaN, check if any NaN is in self. diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index e3eeca2c45e76..d3bd65080132f 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -789,11 +789,10 @@ def _union(self, other, sort): if this._can_fast_union(other): return this._fast_union(other, sort=sort) else: - result = Index._union(this, other, sort=sort) - if isinstance(result, type(self)): - assert result._data.dtype == this.dtype - if result.freq is None: - result._set_freq("infer") + i8self = Int64Index._simple_new(self.asi8, name=self.name) + i8other = Int64Index._simple_new(other.asi8, name=other.name) + i8result = i8self._union(i8other, sort=sort) + result = type(self)(i8result, dtype=self.dtype, freq="infer") return result # -------------------------------------------------------------------- diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index 2b4636155111f..6828b2ca96add 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -29,7 +29,6 @@ from pandas.core.indexes.base import Index, InvalidIndexError, maybe_extract_name from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin from pandas.core.indexes.extension import inherit_names -from pandas.core.ops import get_op_result_name import pandas.core.tools.datetimes as tools from pandas.tseries.frequencies import Resolution, to_offset @@ -348,18 +347,9 @@ def union_many(self, others): if this._can_fast_union(other): this = this._fast_union(other) else: - dtype = this.dtype this = Index.union(this, other) - if isinstance(this, DatetimeIndex): - # TODO: we shouldn't be setting attributes like this; - # in all the tests this equality already holds - this._data._dtype = dtype return this - def _wrap_setop_result(self, other, result): - name = get_op_result_name(self, other) - return self._shallow_copy(result, name=name, freq=None) - # -------------------------------------------------------------------- def _get_time_micros(self): diff --git a/pandas/tests/indexes/test_base.py b/pandas/tests/indexes/test_base.py index e72963de09ab4..811bbe4eddfa9 100644 --- a/pandas/tests/indexes/test_base.py +++ b/pandas/tests/indexes/test_base.py @@ -1047,6 +1047,32 @@ def test_setops_disallow_true(self, method): with pytest.raises(ValueError, match="The 'sort' keyword only takes"): getattr(idx1, method)(idx2, sort=True) + def test_setops_preserve_object_dtype(self): + idx = pd.Index([1, 2, 3], dtype=object) + result = idx.intersection(idx[1:]) + expected = idx[1:] + tm.assert_index_equal(result, expected) + + # if other is not monotonic increasing, intersection goes through + # a different route + result = idx.intersection(idx[1:][::-1]) + tm.assert_index_equal(result, expected) + + result = idx._union(idx[1:], sort=None) + expected = idx + tm.assert_index_equal(result, expected) + + result = idx.union(idx[1:], sort=None) + tm.assert_index_equal(result, expected) + + # if other is not monotonic increasing, _union goes through + # a different route + result = idx._union(idx[1:][::-1], sort=None) + tm.assert_index_equal(result, expected) + + result = idx.union(idx[1:][::-1], sort=None) + tm.assert_index_equal(result, expected) + def test_map_identity_mapping(self, indices): # GH 12766 tm.assert_index_equal(indices, indices.map(lambda x: x))