Skip to content

Commit c09ac01

Browse files
authored
ENH: Add support for ea dtypes in merge (#49876)
1 parent 7374a0d commit c09ac01

File tree

5 files changed

+124
-22
lines changed

5 files changed

+124
-22
lines changed

asv_bench/benchmarks/join_merge.py

+32
Original file line number | Diff line number | Diff line change
@@ -273,6 +273,38 @@ def time_merge_dataframes_cross(self, sort):
273273
merge(self.left.loc[:2000], self.right.loc[:2000], how="cross", sort=sort)
274274

275275

276+
class MergeEA:
    """Benchmark ``merge`` on a key column stored with a nullable dtype.

    The benchmark is parametrized over the dtype names listed in ``params``;
    each name is handed to ``setup`` and ``time_merge`` as ``dtype``.
    """

    params = [
        "Int64",
        "Int32",
        "Int16",
        "UInt64",
        "UInt32",
        "UInt16",
        "Float64",
        "Float32",
    ]
    param_names = ["dtype"]

    def setup(self, dtype):
        """Build the two frames that ``time_merge`` joins.

        Parameters
        ----------
        dtype : str
            Dtype name used for the "key" column of both frames.
        """
        N = 10_000
        indices = np.arange(1, N)
        # Left keys: the first 8000 values, tiled 10 times (80000 rows).
        left_key = np.tile(indices[:8000], 10)
        # Right keys: everything from position 2000 onwards (7999 rows), so
        # the two sides only partially overlap.
        right_key = indices[2000:]
        # Payload columns are sized from the keys so they cannot drift out of
        # sync if N or the slice bounds above are ever changed.
        self.left = DataFrame(
            {
                "key": Series(left_key, dtype=dtype),
                "value": np.random.randn(len(left_key)),
            }
        )
        self.right = DataFrame(
            {
                "key": Series(right_key, dtype=dtype),
                "value2": np.random.randn(len(right_key)),
            }
        )

    def time_merge(self, dtype):
        """Time a default ``merge`` of the two frames prepared in ``setup``."""
        # No ``on`` is given; "key" is the only column name the frames share.
        merge(self.left, self.right)
306+
307+
276308
class I8Merge:
277309

278310
params = ["inner", "outer", "left", "right"]

doc/source/whatsnew/v2.0.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@ Other enhancements
5757
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
5858
- Added new argument ``use_nullable_dtypes`` to :func:`read_csv` and :func:`read_excel` to enable automatic conversion to nullable dtypes (:issue:`36712`)
5959
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
60+
- Added support for extension array dtypes in :func:`merge` (:issue:`44240`)
6061
- Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
6162
- :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)
6263
- Fix ``test`` optional_extra by adding missing test package ``pytest-asyncio`` (:issue:`48361`)

pandas/_libs/hashtable.pyi

+3
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from typing import (
2+
Any,
23
Hashable,
34
Literal,
45
)
@@ -13,6 +14,7 @@ def unique_label_indices(
1314

1415
class Factorizer:
1516
count: int
17+
uniques: Any
1618
def __init__(self, size_hint: int) -> None: ...
1719
def get_count(self) -> int: ...
1820
def factorize(
@@ -21,6 +23,7 @@ class Factorizer:
2123
sort: bool = ...,
2224
na_sentinel=...,
2325
na_value=...,
26+
mask=...,
2427
) -> npt.NDArray[np.intp]: ...
2528

2629
class ObjectFactorizer(Factorizer):

pandas/core/reshape/merge.py

+64-22
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
)
4343
from pandas.util._exceptions import find_stack_level
4444

45+
from pandas.core.dtypes.base import ExtensionDtype
4546
from pandas.core.dtypes.cast import find_common_type
4647
from pandas.core.dtypes.common import (
4748
ensure_float64,
@@ -79,7 +80,10 @@
7980
Series,
8081
)
8182
import pandas.core.algorithms as algos
82-
from pandas.core.arrays import ExtensionArray
83+
from pandas.core.arrays import (
84+
BaseMaskedArray,
85+
ExtensionArray,
86+
)
8387
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
8488
import pandas.core.common as com
8589
from pandas.core.construction import extract_array
@@ -92,6 +96,24 @@
9296
from pandas.core import groupby
9397
from pandas.core.arrays import DatetimeArray
9498

99+
_factorizers = {
100+
np.int64: libhashtable.Int64Factorizer,
101+
np.longlong: libhashtable.Int64Factorizer,
102+
np.int32: libhashtable.Int32Factorizer,
103+
np.int16: libhashtable.Int16Factorizer,
104+
np.int8: libhashtable.Int8Factorizer,
105+
np.uint64: libhashtable.UInt64Factorizer,
106+
np.uint32: libhashtable.UInt32Factorizer,
107+
np.uint16: libhashtable.UInt16Factorizer,
108+
np.uint8: libhashtable.UInt8Factorizer,
109+
np.bool_: libhashtable.UInt8Factorizer,
110+
np.float64: libhashtable.Float64Factorizer,
111+
np.float32: libhashtable.Float32Factorizer,
112+
np.complex64: libhashtable.Complex64Factorizer,
113+
np.complex128: libhashtable.Complex128Factorizer,
114+
np.object_: libhashtable.ObjectFactorizer,
115+
}
116+
95117

96118
@Substitution("\nleft : DataFrame or named Series")
97119
@Appender(_merge_doc, indents=0)
@@ -2335,25 +2357,37 @@ def _factorize_keys(
23352357
rk = ensure_int64(rk.codes)
23362358

23372359
elif isinstance(lk, ExtensionArray) and is_dtype_equal(lk.dtype, rk.dtype):
2338-
lk, _ = lk._values_for_factorize()
2339-
2340-
# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
2341-
# "_values_for_factorize"
2342-
rk, _ = rk._values_for_factorize() # type: ignore[union-attr]
2343-
2344-
klass: type[libhashtable.Factorizer] | type[libhashtable.Int64Factorizer]
2345-
if is_integer_dtype(lk.dtype) and is_integer_dtype(rk.dtype):
2346-
# GH#23917 TODO: needs tests for case where lk is integer-dtype
2347-
# and rk is datetime-dtype
2348-
klass = libhashtable.Int64Factorizer
2349-
lk = ensure_int64(np.asarray(lk))
2350-
rk = ensure_int64(np.asarray(rk))
2351-
2352-
elif needs_i8_conversion(lk.dtype) and is_dtype_equal(lk.dtype, rk.dtype):
2353-
# GH#23917 TODO: Needs tests for non-matching dtypes
2354-
klass = libhashtable.Int64Factorizer
2355-
lk = ensure_int64(np.asarray(lk, dtype=np.int64))
2356-
rk = ensure_int64(np.asarray(rk, dtype=np.int64))
2360+
if not isinstance(lk, BaseMaskedArray):
2361+
lk, _ = lk._values_for_factorize()
2362+
2363+
# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
2364+
# "_values_for_factorize"
2365+
rk, _ = rk._values_for_factorize() # type: ignore[union-attr]
2366+
2367+
klass: type[libhashtable.Factorizer]
2368+
if is_numeric_dtype(lk.dtype):
2369+
if not is_dtype_equal(lk, rk):
2370+
dtype = find_common_type([lk.dtype, rk.dtype])
2371+
if isinstance(dtype, ExtensionDtype):
2372+
cls = dtype.construct_array_type()
2373+
if not isinstance(lk, ExtensionArray):
2374+
lk = cls._from_sequence(lk, dtype=dtype, copy=False)
2375+
else:
2376+
lk = lk.astype(dtype)
2377+
2378+
if not isinstance(rk, ExtensionArray):
2379+
rk = cls._from_sequence(rk, dtype=dtype, copy=False)
2380+
else:
2381+
rk = rk.astype(dtype)
2382+
else:
2383+
lk = lk.astype(dtype)
2384+
rk = rk.astype(dtype)
2385+
if isinstance(lk, BaseMaskedArray):
2386+
# Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]";
2387+
# expected type "Type[object]"
2388+
klass = _factorizers[lk.dtype.type] # type: ignore[index]
2389+
else:
2390+
klass = _factorizers[lk.dtype.type]
23572391

23582392
else:
23592393
klass = libhashtable.ObjectFactorizer
@@ -2362,8 +2396,16 @@ def _factorize_keys(
23622396

23632397
rizer = klass(max(len(lk), len(rk)))
23642398

2365-
llab = rizer.factorize(lk)
2366-
rlab = rizer.factorize(rk)
2399+
if isinstance(lk, BaseMaskedArray):
2400+
assert isinstance(rk, BaseMaskedArray)
2401+
llab = rizer.factorize(lk._data, mask=lk._mask)
2402+
rlab = rizer.factorize(rk._data, mask=rk._mask)
2403+
else:
2404+
# Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type
2405+
# "Union[ndarray[Any, dtype[signedinteger[_64Bit]]],
2406+
# ndarray[Any, dtype[object_]]]"; expected "ndarray[Any, dtype[object_]]"
2407+
llab = rizer.factorize(lk) # type: ignore[arg-type]
2408+
rlab = rizer.factorize(rk) # type: ignore[arg-type]
23672409
assert llab.dtype == np.dtype(np.intp), llab.dtype
23682410
assert rlab.dtype == np.dtype(np.intp), rlab.dtype
23692411

pandas/tests/reshape/merge/test_merge.py

+24
Original file line number | Diff line number | Diff line change
@@ -2714,3 +2714,27 @@ def test_merge_different_index_names():
27142714
result = merge(left, right, left_on="c", right_on="d")
27152715
expected = DataFrame({"a_x": [1], "a_y": 1})
27162716
tm.assert_frame_equal(result, expected)
2717+
2718+
2719+
def test_merge_ea(any_numeric_ea_dtype, join_type):
2720+
# GH#44240
2721+
left = DataFrame({"a": [1, 2, 3], "b": 1}, dtype=any_numeric_ea_dtype)
2722+
right = DataFrame({"a": [1, 2, 3], "c": 2}, dtype=any_numeric_ea_dtype)
2723+
result = left.merge(right, how=join_type)
2724+
expected = DataFrame({"a": [1, 2, 3], "b": 1, "c": 2}, dtype=any_numeric_ea_dtype)
2725+
tm.assert_frame_equal(result, expected)
2726+
2727+
2728+
def test_merge_ea_and_non_ea(any_numeric_ea_dtype, join_type):
2729+
# GH#44240
2730+
left = DataFrame({"a": [1, 2, 3], "b": 1}, dtype=any_numeric_ea_dtype)
2731+
right = DataFrame({"a": [1, 2, 3], "c": 2}, dtype=any_numeric_ea_dtype.lower())
2732+
result = left.merge(right, how=join_type)
2733+
expected = DataFrame(
2734+
{
2735+
"a": Series([1, 2, 3], dtype=any_numeric_ea_dtype),
2736+
"b": Series([1, 1, 1], dtype=any_numeric_ea_dtype),
2737+
"c": Series([2, 2, 2], dtype=any_numeric_ea_dtype.lower()),
2738+
}
2739+
)
2740+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)