Skip to content

REF: unify Index union methods #38382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 14, 2020
61 changes: 18 additions & 43 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2592,47 +2592,6 @@ def _get_reconciled_name_object(self, other):
return self.rename(name)
return self

@final
def _union_incompatible_dtypes(self, other, sort):
"""
Casts this and other index to object dtype to allow the formation
of a union between incompatible types.

Parameters
----------
other : Index or array-like
sort : False or None, default False
Whether to sort the resulting index.

* False : do not sort the result.
* None : sort the result, except when `self` and `other` are equal
or when the values cannot be compared.

Returns
-------
Index
"""
this = self.astype(object, copy=False)
# cast to Index for when `other` is list-like
other = Index(other).astype(object, copy=False)
return Index.union(this, other, sort=sort).astype(object, copy=False)

def _can_union_without_object_cast(self, other) -> bool:
"""
Check whether this and the other dtype are compatible with each other.
Meaning a union can be formed between them without needing to be cast
to dtype object.

Parameters
----------
other : Index or array-like

Returns
-------
bool
"""
return type(self) is type(other) and is_dtype_equal(self.dtype, other.dtype)

@final
def _validate_sort_keyword(self, sort):
if sort not in [None, False]:
Expand Down Expand Up @@ -2696,8 +2655,24 @@ def union(self, other, sort=None):
self._assert_can_do_setop(other)
other, result_name = self._convert_can_do_setop(other)

if not self._can_union_without_object_cast(other):
return self._union_incompatible_dtypes(other, sort=sort)
if not is_dtype_equal(self.dtype, other.dtype):
dtype = find_common_type([self.dtype, other.dtype])
if self._is_numeric_dtype and other._is_numeric_dtype:
# Right now, we treat union(int, float) a bit special.
# See https://github.com/pandas-dev/pandas/issues/26778 for discussion
# We may change union(int, float) to go to object.
# float | [u]int -> float (the special case)
# <T> | <T> -> T
# <T> | <U> -> object
if not (is_integer_dtype(self.dtype) and is_integer_dtype(other.dtype)):
dtype = "float64"
else:
# one is int64 other is uint64
dtype = object

left = self.astype(dtype, copy=False)
right = other.astype(dtype, copy=False)
return left.union(right, sort=sort)

result = self._union(other, sort=sort)

Expand Down
3 changes: 0 additions & 3 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,9 +597,6 @@ def insert(self, loc: int, item):
# --------------------------------------------------------------------
# Join/Set Methods

def _can_union_without_object_cast(self, other) -> bool:
return is_dtype_equal(self.dtype, other.dtype)

def _get_join_freq(self, other):
"""
Get the freq to attach to the result of a join operation.
Expand Down
25 changes: 0 additions & 25 deletions pandas/core/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,23 +182,6 @@ def _is_all_dates(self) -> bool:
"""
return False

def _union(self, other, sort):
# Right now, we treat union(int, float) a bit special.
# See https://github.com/pandas-dev/pandas/issues/26778 for discussion
# We may change union(int, float) to go to object.
# float | [u]int -> float (the special case)
# <T> | <T> -> T
# <T> | <U> -> object
needs_cast = (is_integer_dtype(self.dtype) and is_float_dtype(other.dtype)) or (
is_integer_dtype(other.dtype) and is_float_dtype(self.dtype)
)
if needs_cast:
first = self.astype("float")
second = other.astype("float")
return first._union(second, sort)
else:
return super()._union(other, sort)


_num_index_shared_docs[
"class_descr"
Expand Down Expand Up @@ -258,10 +241,6 @@ def _assert_safe_casting(cls, data, subarr):
if not np.array_equal(data, subarr):
raise TypeError("Unsafe NumPy casting, you must explicitly cast")

def _can_union_without_object_cast(self, other) -> bool:
# See GH#26778, further casting may occur in NumericIndex._union
return other.dtype == "f8" or other.dtype == self.dtype

def __contains__(self, key) -> bool:
"""
Check if key is a float and has a decimal. If it has, return False.
Expand Down Expand Up @@ -422,7 +401,3 @@ def __contains__(self, other: Any) -> bool:
return True

return is_float(other) and np.isnan(other) and self.hasnans

def _can_union_without_object_cast(self, other) -> bool:
# See GH#26778, further casting may occur in NumericIndex._union
return is_numeric_dtype(other.dtype)
2 changes: 1 addition & 1 deletion pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class TimedeltaIndex(DatetimeTimedeltaMixin):

_comparables = ["name", "freq"]
_attributes = ["name", "freq"]
_is_numeric_dtype = True
_is_numeric_dtype = False

_data: TimedeltaArray

Expand Down
14 changes: 12 additions & 2 deletions pandas/tests/indexes/interval/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,20 @@ def test_union_empty_result(self, closed, sort):
result = index.union(index, sort=sort)
tm.assert_index_equal(result, index)

# GH 19101: empty result, different dtypes -> common dtype is object
# GH 19101: empty result, different numeric dtypes -> common dtype is f8
other = empty_index(dtype="float64", closed=closed)
result = index.union(other, sort=sort)
expected = Index([], dtype=object)
expected = other
tm.assert_index_equal(result, expected)

other = index.union(index, sort=sort)
tm.assert_index_equal(result, expected)

other = empty_index(dtype="uint64", closed=closed)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test the reverse as well

result = index.union(other, sort=sort)
tm.assert_index_equal(result, expected)

result = other.union(index, sort=sort)
tm.assert_index_equal(result, expected)

def test_intersection(self, closed, sort):
Expand Down