Skip to content

BUG: preserve object dtype for Index set ops #31401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
7 commits merged on Feb 2, 2020
Merged
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Other
^^^^^
- Appending a dictionary to a :class:`DataFrame` without passing ``ignore_index=True`` will raise ``TypeError: Can only append a dict if ignore_index=True``
instead of ``TypeError: Can only append a Series if ignore_index=True or if the Series has a name`` (:issue:`30871`)
-
- Set operations on an object-dtype :class:`Index` now always return object-dtype results (:issue:`31401`)

.. ---------------------------------------------------------------------------

Expand Down
8 changes: 6 additions & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def _shallow_copy(self, values=None, **kwargs):
values = self.values

attributes = self._get_attributes_dict()

attributes.update(kwargs)

return self._simple_new(values, **attributes)
Expand Down Expand Up @@ -2566,6 +2567,7 @@ def _union(self, other, sort):
# worth making this faster? a very unusual case
value_set = set(lvals)
result.extend([x for x in rvals if x not in value_set])
result = Index(result)._values # do type inference here
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why doesn't `_wrap_setop_result` do this type inference itself?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment it does, but in a roundabout manner: `Index._wrap_setop_result` calls `self._constructor` (which is `type(self)`), whereas the subclasses use `_shallow_copy`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

else:
# find indexes of things in "other" that are not in "self"
if self.is_unique:
Expand Down Expand Up @@ -2595,7 +2597,8 @@ def _union(self, other, sort):
return self._wrap_setop_result(other, result)

def _wrap_setop_result(self, other, result):
return self._constructor(result, name=get_op_result_name(self, other))
name = get_op_result_name(self, other)
return self._shallow_copy(result, name=name)

# TODO: standardize return type of non-union setops type(self vs other)
def intersection(self, other, sort=False):
Expand Down Expand Up @@ -2652,9 +2655,10 @@ def intersection(self, other, sort=False):
if self.is_monotonic and other.is_monotonic:
try:
result = self._inner_indexer(lvals, rvals)[0]
return self._wrap_setop_result(other, result)
except TypeError:
pass
else:
return self._wrap_setop_result(other, result)

try:
indexer = Index(rvals).get_indexer(lvals)
Expand Down
7 changes: 0 additions & 7 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pandas.core.indexes.base import Index, _index_shared_docs, maybe_extract_name
from pandas.core.indexes.extension import ExtensionIndex, inherit_names
import pandas.core.missing as missing
from pandas.core.ops import get_op_result_name

_index_doc_kwargs = dict(ibase._index_doc_kwargs)
_index_doc_kwargs.update(dict(target_klass="CategoricalIndex"))
Expand Down Expand Up @@ -386,12 +385,6 @@ def _has_complex_internals(self) -> bool:
# used to avoid libreduction code paths, which raise or require conversion
return True

def _wrap_setop_result(self, other, result):
name = get_op_result_name(self, other)
# We use _shallow_copy rather than the Index implementation
# (which uses _constructor) in order to preserve dtype.
return self._shallow_copy(result, name=name)

@Appender(Index.__contains__.__doc__)
def __contains__(self, key: Any) -> bool:
# if key is a NaN, check if any NaN is in self.
Expand Down
9 changes: 4 additions & 5 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,11 +789,10 @@ def _union(self, other, sort):
if this._can_fast_union(other):
return this._fast_union(other, sort=sort)
else:
result = Index._union(this, other, sort=sort)
if isinstance(result, type(self)):
assert result._data.dtype == this.dtype
if result.freq is None:
result._set_freq("infer")
i8self = Int64Index._simple_new(self.asi8, name=self.name)
i8other = Int64Index._simple_new(other.asi8, name=other.name)
i8result = i8self._union(i8other, sort=sort)
result = type(self)(i8result, dtype=self.dtype, freq="infer")
return result

# --------------------------------------------------------------------
Expand Down
10 changes: 0 additions & 10 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pandas.core.indexes.base import Index, InvalidIndexError, maybe_extract_name
from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin
from pandas.core.indexes.extension import inherit_names
from pandas.core.ops import get_op_result_name
import pandas.core.tools.datetimes as tools

from pandas.tseries.frequencies import Resolution, to_offset
Expand Down Expand Up @@ -348,18 +347,9 @@ def union_many(self, others):
if this._can_fast_union(other):
this = this._fast_union(other)
else:
dtype = this.dtype
this = Index.union(this, other)
if isinstance(this, DatetimeIndex):
# TODO: we shouldn't be setting attributes like this;
# in all the tests this equality already holds
this._data._dtype = dtype
return this

def _wrap_setop_result(self, other, result):
name = get_op_result_name(self, other)
return self._shallow_copy(result, name=name, freq=None)

# --------------------------------------------------------------------

def _get_time_micros(self):
Expand Down
26 changes: 26 additions & 0 deletions pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,32 @@ def test_setops_disallow_true(self, method):
with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
getattr(idx1, method)(idx2, sort=True)

def test_setops_preserve_object_dtype(self):
idx = pd.Index([1, 2, 3], dtype=object)
result = idx.intersection(idx[1:])
expected = idx[1:]
tm.assert_index_equal(result, expected)

# if other is not monotonic increasing, intersection goes through
# a different route
result = idx.intersection(idx[1:][::-1])
tm.assert_index_equal(result, expected)

result = idx._union(idx[1:], sort=None)
expected = idx
tm.assert_index_equal(result, expected)

result = idx.union(idx[1:], sort=None)
tm.assert_index_equal(result, expected)

# if other is not monotonic increasing, _union goes through
# a different route
result = idx._union(idx[1:][::-1], sort=None)
tm.assert_index_equal(result, expected)

result = idx.union(idx[1:][::-1], sort=None)
tm.assert_index_equal(result, expected)

def test_map_identity_mapping(self, indices):
# GH 12766
tm.assert_index_equal(indices, indices.map(lambda x: x))
Expand Down