Skip to content

ENH: Add support for ea dtypes in merge #49876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits on
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions asv_bench/benchmarks/join_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,38 @@ def time_merge_dataframes_cross(self, sort):
merge(self.left.loc[:2000], self.right.loc[:2000], how="cross", sort=sort)


class MergeEA:
    """ASV benchmark: merge two DataFrames keyed on a nullable (extension) dtype.

    The left frame tiles a subset of the key range so the join fans out;
    the right frame holds the tail of the key range, so the two only
    partially overlap.
    """

    params = [
        "Int64",
        "Int32",
        "Int16",
        "UInt64",
        "UInt32",
        "UInt16",
        "Float64",
        "Float32",
    ]
    param_names = ["dtype"]

    def setup(self, dtype):
        N = 10_000
        indices = np.arange(1, N)
        key = np.tile(indices[:8000], 10)
        # Derive the value-column lengths from the key arrays instead of
        # hard-coding 80000 / 7999, so changing N or the slice bounds cannot
        # silently desynchronize the column lengths.
        self.left = DataFrame(
            {"key": Series(key, dtype=dtype), "value": np.random.randn(len(key))}
        )
        right_key = indices[2000:]
        self.right = DataFrame(
            {
                "key": Series(right_key, dtype=dtype),
                "value2": np.random.randn(len(right_key)),
            }
        )

    def time_merge(self, dtype):
        merge(self.left, self.right)


class I8Merge:

params = ["inner", "outer", "left", "right"]
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Other enhancements
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
- Added new argument ``use_nullable_dtypes`` to :func:`read_csv` and :func:`read_excel` to enable automatic conversion to nullable dtypes (:issue:`36712`)
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
- Added support for extension array dtypes in :func:`merge` (:issue:`44240`)
- Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
- :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)
- Fix ``test`` optional_extra by adding missing test package ``pytest-asyncio`` (:issue:`48361`)
Expand Down
3 changes: 3 additions & 0 deletions pandas/_libs/hashtable.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import (
Any,
Hashable,
Literal,
)
Expand All @@ -13,6 +14,7 @@ def unique_label_indices(

class Factorizer:
count: int
uniques: Any
def __init__(self, size_hint: int) -> None: ...
def get_count(self) -> int: ...
def factorize(
Expand All @@ -21,6 +23,7 @@ class Factorizer:
sort: bool = ...,
na_sentinel=...,
na_value=...,
mask=...,
) -> npt.NDArray[np.intp]: ...

class ObjectFactorizer(Factorizer):
Expand Down
86 changes: 64 additions & 22 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.cast import find_common_type
from pandas.core.dtypes.common import (
ensure_float64,
Expand Down Expand Up @@ -79,7 +80,10 @@
Series,
)
import pandas.core.algorithms as algos
from pandas.core.arrays import ExtensionArray
from pandas.core.arrays import (
BaseMaskedArray,
ExtensionArray,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
import pandas.core.common as com
from pandas.core.construction import extract_array
Expand All @@ -92,6 +96,24 @@
from pandas.core import groupby
from pandas.core.arrays import DatetimeArray

_factorizers = {
np.int64: libhashtable.Int64Factorizer,
np.longlong: libhashtable.Int64Factorizer,
np.int32: libhashtable.Int32Factorizer,
np.int16: libhashtable.Int16Factorizer,
np.int8: libhashtable.Int8Factorizer,
np.uint64: libhashtable.UInt64Factorizer,
np.uint32: libhashtable.UInt32Factorizer,
np.uint16: libhashtable.UInt16Factorizer,
np.uint8: libhashtable.UInt8Factorizer,
np.bool_: libhashtable.UInt8Factorizer,
np.float64: libhashtable.Float64Factorizer,
np.float32: libhashtable.Float32Factorizer,
np.complex64: libhashtable.Complex64Factorizer,
np.complex128: libhashtable.Complex128Factorizer,
np.object_: libhashtable.ObjectFactorizer,
}


@Substitution("\nleft : DataFrame or named Series")
@Appender(_merge_doc, indents=0)
Expand Down Expand Up @@ -2335,25 +2357,37 @@ def _factorize_keys(
rk = ensure_int64(rk.codes)

elif isinstance(lk, ExtensionArray) and is_dtype_equal(lk.dtype, rk.dtype):
lk, _ = lk._values_for_factorize()

# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
# "_values_for_factorize"
rk, _ = rk._values_for_factorize() # type: ignore[union-attr]

klass: type[libhashtable.Factorizer] | type[libhashtable.Int64Factorizer]
if is_integer_dtype(lk.dtype) and is_integer_dtype(rk.dtype):
# GH#23917 TODO: needs tests for case where lk is integer-dtype
# and rk is datetime-dtype
klass = libhashtable.Int64Factorizer
lk = ensure_int64(np.asarray(lk))
rk = ensure_int64(np.asarray(rk))

elif needs_i8_conversion(lk.dtype) and is_dtype_equal(lk.dtype, rk.dtype):
# GH#23917 TODO: Needs tests for non-matching dtypes
klass = libhashtable.Int64Factorizer
lk = ensure_int64(np.asarray(lk, dtype=np.int64))
rk = ensure_int64(np.asarray(rk, dtype=np.int64))
if not isinstance(lk, BaseMaskedArray):
lk, _ = lk._values_for_factorize()

# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
# "_values_for_factorize"
rk, _ = rk._values_for_factorize() # type: ignore[union-attr]

klass: type[libhashtable.Factorizer]
if is_numeric_dtype(lk.dtype):
if not is_dtype_equal(lk, rk):
dtype = find_common_type([lk.dtype, rk.dtype])
if isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
if not isinstance(lk, ExtensionArray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of all of the if checks can we set up another function or class to dispatch to that handles this more gracefully? I think ideally could also eliminate the need for the factorizer dict

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but would prefer to do this as a follow up.

Getting rid of the dict is probably not easy if doable at all

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not for the dict at least all these if checks I think should be wrapped in a separate function. We are losing a good bit of abstraction here compared to the previous code

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep this makes sense, but as I said would prefer doing this as a follow up.

lk = cls._from_sequence(lk, dtype=dtype, copy=False)
else:
lk = lk.astype(dtype)

if not isinstance(rk, ExtensionArray):
rk = cls._from_sequence(rk, dtype=dtype, copy=False)
else:
rk = rk.astype(dtype)
else:
lk = lk.astype(dtype)
rk = rk.astype(dtype)
if isinstance(lk, BaseMaskedArray):
# Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]";
# expected type "Type[object]"
klass = _factorizers[lk.dtype.type] # type: ignore[index]
else:
klass = _factorizers[lk.dtype.type]

else:
klass = libhashtable.ObjectFactorizer
Expand All @@ -2362,8 +2396,16 @@ def _factorize_keys(

rizer = klass(max(len(lk), len(rk)))

llab = rizer.factorize(lk)
rlab = rizer.factorize(rk)
if isinstance(lk, BaseMaskedArray):
assert isinstance(rk, BaseMaskedArray)
llab = rizer.factorize(lk._data, mask=lk._mask)
rlab = rizer.factorize(rk._data, mask=rk._mask)
else:
# Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type
# "Union[ndarray[Any, dtype[signedinteger[_64Bit]]],
# ndarray[Any, dtype[object_]]]"; expected "ndarray[Any, dtype[object_]]"
llab = rizer.factorize(lk) # type: ignore[arg-type]
rlab = rizer.factorize(rk) # type: ignore[arg-type]
assert llab.dtype == np.dtype(np.intp), llab.dtype
assert rlab.dtype == np.dtype(np.intp), rlab.dtype

Expand Down
24 changes: 24 additions & 0 deletions pandas/tests/reshape/merge/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2714,3 +2714,27 @@ def test_merge_different_index_names():
result = merge(left, right, left_on="c", right_on="d")
expected = DataFrame({"a_x": [1], "a_y": 1})
tm.assert_frame_equal(result, expected)


def test_merge_ea(any_numeric_ea_dtype, join_type):
# GH#44240
left = DataFrame({"a": [1, 2, 3], "b": 1}, dtype=any_numeric_ea_dtype)
right = DataFrame({"a": [1, 2, 3], "c": 2}, dtype=any_numeric_ea_dtype)
result = left.merge(right, how=join_type)
expected = DataFrame({"a": [1, 2, 3], "b": 1, "c": 2}, dtype=any_numeric_ea_dtype)
tm.assert_frame_equal(result, expected)


def test_merge_ea_and_non_ea(any_numeric_ea_dtype, join_type):
# GH#44240
left = DataFrame({"a": [1, 2, 3], "b": 1}, dtype=any_numeric_ea_dtype)
right = DataFrame({"a": [1, 2, 3], "c": 2}, dtype=any_numeric_ea_dtype.lower())
result = left.merge(right, how=join_type)
expected = DataFrame(
{
"a": Series([1, 2, 3], dtype=any_numeric_ea_dtype),
"b": Series([1, 1, 1], dtype=any_numeric_ea_dtype),
"c": Series([2, 2, 2], dtype=any_numeric_ea_dtype.lower()),
}
)
tm.assert_frame_equal(result, expected)