Skip to content

Commit cf8c680

Browse files
authored
BUG: preserve object dtype for Index set ops (#31401)
1 parent b7d3d83 commit cf8c680

File tree

6 files changed

+37
-25
lines changed

6 files changed

+37
-25
lines changed

doc/source/whatsnew/v1.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ Other
220220
^^^^^
221221
- Appending a dictionary to a :class:`DataFrame` without passing ``ignore_index=True`` will raise ``TypeError: Can only append a dict if ignore_index=True``
222222
instead of ``TypeError: Can only append a Series if ignore_index=True or if the Series has a name`` (:issue:`30871`)
223-
-
223+
- Set operations on an object-dtype :class:`Index` now always return object-dtype results (:issue:`31401`)
224224

225225
.. ---------------------------------------------------------------------------
226226

pandas/core/indexes/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ def _shallow_copy(self, values=None, **kwargs):
526526
values = self.values
527527

528528
attributes = self._get_attributes_dict()
529+
529530
attributes.update(kwargs)
530531

531532
return self._simple_new(values, **attributes)
@@ -2570,6 +2571,7 @@ def _union(self, other, sort):
25702571
# worth making this faster? a very unusual case
25712572
value_set = set(lvals)
25722573
result.extend([x for x in rvals if x not in value_set])
2574+
result = Index(result)._values # do type inference here
25732575
else:
25742576
# find indexes of things in "other" that are not in "self"
25752577
if self.is_unique:
@@ -2599,7 +2601,8 @@ def _union(self, other, sort):
25992601
return self._wrap_setop_result(other, result)
26002602

26012603
def _wrap_setop_result(self, other, result):
2602-
return self._constructor(result, name=get_op_result_name(self, other))
2604+
name = get_op_result_name(self, other)
2605+
return self._shallow_copy(result, name=name)
26032606

26042607
# TODO: standardize return type of non-union setops type(self vs other)
26052608
def intersection(self, other, sort=False):
@@ -2656,9 +2659,10 @@ def intersection(self, other, sort=False):
26562659
if self.is_monotonic and other.is_monotonic:
26572660
try:
26582661
result = self._inner_indexer(lvals, rvals)[0]
2659-
return self._wrap_setop_result(other, result)
26602662
except TypeError:
26612663
pass
2664+
else:
2665+
return self._wrap_setop_result(other, result)
26622666

26632667
try:
26642668
indexer = Index(rvals).get_indexer(lvals)

pandas/core/indexes/category.py

-7
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from pandas.core.indexes.base import Index, _index_shared_docs, maybe_extract_name
2929
from pandas.core.indexes.extension import ExtensionIndex, inherit_names
3030
import pandas.core.missing as missing
31-
from pandas.core.ops import get_op_result_name
3231

3332
if TYPE_CHECKING:
3433
from pandas import Series
@@ -388,12 +387,6 @@ def _has_complex_internals(self) -> bool:
388387
# used to avoid libreduction code paths, which raise or require conversion
389388
return True
390389

391-
def _wrap_setop_result(self, other, result):
392-
name = get_op_result_name(self, other)
393-
# We use _shallow_copy rather than the Index implementation
394-
# (which uses _constructor) in order to preserve dtype.
395-
return self._shallow_copy(result, name=name)
396-
397390
@Appender(Index.__contains__.__doc__)
398391
def __contains__(self, key: Any) -> bool:
399392
# if key is a NaN, check if any NaN is in self.

pandas/core/indexes/datetimelike.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -801,11 +801,10 @@ def _union(self, other, sort):
801801
if this._can_fast_union(other):
802802
return this._fast_union(other, sort=sort)
803803
else:
804-
result = Index._union(this, other, sort=sort)
805-
if isinstance(result, type(self)):
806-
assert result._data.dtype == this.dtype
807-
if result.freq is None:
808-
result._set_freq("infer")
804+
i8self = Int64Index._simple_new(self.asi8, name=self.name)
805+
i8other = Int64Index._simple_new(other.asi8, name=other.name)
806+
i8result = i8self._union(i8other, sort=sort)
807+
result = type(self)(i8result, dtype=self.dtype, freq="infer")
809808
return result
810809

811810
# --------------------------------------------------------------------

pandas/core/indexes/datetimes.py

-10
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from pandas.core.indexes.base import Index, InvalidIndexError, maybe_extract_name
3030
from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin
3131
from pandas.core.indexes.extension import inherit_names
32-
from pandas.core.ops import get_op_result_name
3332
import pandas.core.tools.datetimes as tools
3433

3534
from pandas.tseries.frequencies import Resolution, to_offset
@@ -347,18 +346,9 @@ def union_many(self, others):
347346
if this._can_fast_union(other):
348347
this = this._fast_union(other)
349348
else:
350-
dtype = this.dtype
351349
this = Index.union(this, other)
352-
if isinstance(this, DatetimeIndex):
353-
# TODO: we shouldn't be setting attributes like this;
354-
# in all the tests this equality already holds
355-
this._data._dtype = dtype
356350
return this
357351

358-
def _wrap_setop_result(self, other, result):
359-
name = get_op_result_name(self, other)
360-
return self._shallow_copy(result, name=name, freq=None)
361-
362352
# --------------------------------------------------------------------
363353

364354
def _get_time_micros(self):

pandas/tests/indexes/test_base.py

+26
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,32 @@ def test_setops_disallow_true(self, method):
10471047
with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
10481048
getattr(idx1, method)(idx2, sort=True)
10491049

1050+
def test_setops_preserve_object_dtype(self):
1051+
idx = pd.Index([1, 2, 3], dtype=object)
1052+
result = idx.intersection(idx[1:])
1053+
expected = idx[1:]
1054+
tm.assert_index_equal(result, expected)
1055+
1056+
# if other is not monotonic increasing, intersection goes through
1057+
# a different route
1058+
result = idx.intersection(idx[1:][::-1])
1059+
tm.assert_index_equal(result, expected)
1060+
1061+
result = idx._union(idx[1:], sort=None)
1062+
expected = idx
1063+
tm.assert_index_equal(result, expected)
1064+
1065+
result = idx.union(idx[1:], sort=None)
1066+
tm.assert_index_equal(result, expected)
1067+
1068+
# if other is not monotonic increasing, _union goes through
1069+
# a different route
1070+
result = idx._union(idx[1:][::-1], sort=None)
1071+
tm.assert_index_equal(result, expected)
1072+
1073+
result = idx.union(idx[1:][::-1], sort=None)
1074+
tm.assert_index_equal(result, expected)
1075+
10501076
def test_map_identity_mapping(self, indices):
10511077
# GH 12766
10521078
tm.assert_index_equal(indices, indices.map(lambda x: x))

0 commit comments

Comments
 (0)