diff --git a/asv_bench/benchmarks/join_merge.py b/asv_bench/benchmarks/join_merge.py
index d9fb3c8a8ff89..fdbf325dcf997 100644
--- a/asv_bench/benchmarks/join_merge.py
+++ b/asv_bench/benchmarks/join_merge.py
@@ -273,6 +273,38 @@ def time_merge_dataframes_cross(self, sort):
         merge(self.left.loc[:2000], self.right.loc[:2000], how="cross", sort=sort)
 
 
+class MergeEA:
+
+    params = [
+        "Int64",
+        "Int32",
+        "Int16",
+        "UInt64",
+        "UInt32",
+        "UInt16",
+        "Float64",
+        "Float32",
+    ]
+    param_names = ["dtype"]
+
+    def setup(self, dtype):
+        N = 10_000
+        indices = np.arange(1, N)
+        key = np.tile(indices[:8000], 10)
+        self.left = DataFrame(
+            {"key": Series(key, dtype=dtype), "value": np.random.randn(80000)}
+        )
+        self.right = DataFrame(
+            {
+                "key": Series(indices[2000:], dtype=dtype),
+                "value2": np.random.randn(7999),
+            }
+        )
+
+    def time_merge(self, dtype):
+        merge(self.left, self.right)
+
+
 class I8Merge:
 
     params = ["inner", "outer", "left", "right"]
diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst
index 42170aaa09978..28197c2178683 100644
--- a/doc/source/whatsnew/v2.0.0.rst
+++ b/doc/source/whatsnew/v2.0.0.rst
@@ -57,6 +57,7 @@ Other enhancements
 - :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
 - Added new argument ``use_nullable_dtypes`` to :func:`read_csv` and :func:`read_excel` to enable automatic conversion to nullable dtypes (:issue:`36712`)
 - Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
+- Added support for extension array dtypes in :func:`merge` (:issue:`44240`)
 - Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
 - :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)
 - Fix ``test`` optional_extra by adding missing test package ``pytest-asyncio`` (:issue:`48361`)
diff --git a/pandas/_libs/hashtable.pyi b/pandas/_libs/hashtable.pyi
index eb0b46101c2d8..af47e6c408c05 100644
--- a/pandas/_libs/hashtable.pyi
+++ b/pandas/_libs/hashtable.pyi
@@ -1,4 +1,5 @@
 from typing import (
+    Any,
     Hashable,
     Literal,
 )
@@ -13,6 +14,7 @@ def unique_label_indices(
 
 class Factorizer:
     count: int
+    uniques: Any
     def __init__(self, size_hint: int) -> None: ...
     def get_count(self) -> int: ...
     def factorize(
@@ -21,6 +23,7 @@ class Factorizer:
         sort: bool = ...,
         na_sentinel=...,
         na_value=...,
+        mask=...,
     ) -> npt.NDArray[np.intp]: ...
 
 class ObjectFactorizer(Factorizer):
diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py
index 9e73dcb789075..e818e367ca83d 100644
--- a/pandas/core/reshape/merge.py
+++ b/pandas/core/reshape/merge.py
@@ -42,6 +42,7 @@
 )
 from pandas.util._exceptions import find_stack_level
 
+from pandas.core.dtypes.base import ExtensionDtype
 from pandas.core.dtypes.cast import find_common_type
 from pandas.core.dtypes.common import (
     ensure_float64,
@@ -79,7 +80,10 @@
     Series,
 )
 import pandas.core.algorithms as algos
-from pandas.core.arrays import ExtensionArray
+from pandas.core.arrays import (
+    BaseMaskedArray,
+    ExtensionArray,
+)
 from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
 import pandas.core.common as com
 from pandas.core.construction import extract_array
@@ -92,6 +96,24 @@
     from pandas.core import groupby
     from pandas.core.arrays import DatetimeArray
 
+_factorizers = {
+    np.int64: libhashtable.Int64Factorizer,
+    np.longlong: libhashtable.Int64Factorizer,
+    np.int32: libhashtable.Int32Factorizer,
+    np.int16: libhashtable.Int16Factorizer,
+    np.int8: libhashtable.Int8Factorizer,
+    np.uint64: libhashtable.UInt64Factorizer,
+    np.uint32: libhashtable.UInt32Factorizer,
+    np.uint16: libhashtable.UInt16Factorizer,
+    np.uint8: libhashtable.UInt8Factorizer,
+    np.bool_: libhashtable.UInt8Factorizer,
+    np.float64: libhashtable.Float64Factorizer,
+    np.float32: libhashtable.Float32Factorizer,
+    np.complex64: libhashtable.Complex64Factorizer,
+    np.complex128: libhashtable.Complex128Factorizer,
+    np.object_: libhashtable.ObjectFactorizer,
+}
+
 
 @Substitution("\nleft : DataFrame or named Series")
 @Appender(_merge_doc, indents=0)
@@ -2335,25 +2357,37 @@ def _factorize_keys(
         rk = ensure_int64(rk.codes)
 
     elif isinstance(lk, ExtensionArray) and is_dtype_equal(lk.dtype, rk.dtype):
-        lk, _ = lk._values_for_factorize()
-
-        # error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
-        # "_values_for_factorize"
-        rk, _ = rk._values_for_factorize()  # type: ignore[union-attr]
-
-    klass: type[libhashtable.Factorizer] | type[libhashtable.Int64Factorizer]
-    if is_integer_dtype(lk.dtype) and is_integer_dtype(rk.dtype):
-        # GH#23917 TODO: needs tests for case where lk is integer-dtype
-        #  and rk is datetime-dtype
-        klass = libhashtable.Int64Factorizer
-        lk = ensure_int64(np.asarray(lk))
-        rk = ensure_int64(np.asarray(rk))
-
-    elif needs_i8_conversion(lk.dtype) and is_dtype_equal(lk.dtype, rk.dtype):
-        # GH#23917 TODO: Needs tests for non-matching dtypes
-        klass = libhashtable.Int64Factorizer
-        lk = ensure_int64(np.asarray(lk, dtype=np.int64))
-        rk = ensure_int64(np.asarray(rk, dtype=np.int64))
+        if not isinstance(lk, BaseMaskedArray):
+            lk, _ = lk._values_for_factorize()
+
+            # error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
+            # "_values_for_factorize"
+            rk, _ = rk._values_for_factorize()  # type: ignore[union-attr]
+
+    klass: type[libhashtable.Factorizer]
+    if is_numeric_dtype(lk.dtype):
+        if not is_dtype_equal(lk, rk):
+            dtype = find_common_type([lk.dtype, rk.dtype])
+            if isinstance(dtype, ExtensionDtype):
+                cls = dtype.construct_array_type()
+                if not isinstance(lk, ExtensionArray):
+                    lk = cls._from_sequence(lk, dtype=dtype, copy=False)
+                else:
+                    lk = lk.astype(dtype)
+
+                if not isinstance(rk, ExtensionArray):
+                    rk = cls._from_sequence(rk, dtype=dtype, copy=False)
+                else:
+                    rk = rk.astype(dtype)
+            else:
+                lk = lk.astype(dtype)
+                rk = rk.astype(dtype)
+        if isinstance(lk, BaseMaskedArray):
+            # Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]";
+            # expected type "Type[object]"
+            klass = _factorizers[lk.dtype.type]  # type: ignore[index]
+        else:
+            klass = _factorizers[lk.dtype.type]
 
     else:
         klass = libhashtable.ObjectFactorizer
@@ -2362,8 +2396,16 @@ def _factorize_keys(
 
     rizer = klass(max(len(lk), len(rk)))
 
-    llab = rizer.factorize(lk)
-    rlab = rizer.factorize(rk)
+    if isinstance(lk, BaseMaskedArray):
+        assert isinstance(rk, BaseMaskedArray)
+        llab = rizer.factorize(lk._data, mask=lk._mask)
+        rlab = rizer.factorize(rk._data, mask=rk._mask)
+    else:
+        # Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type
+        # "Union[ndarray[Any, dtype[signedinteger[_64Bit]]],
+        # ndarray[Any, dtype[object_]]]"; expected "ndarray[Any, dtype[object_]]"
+        llab = rizer.factorize(lk)  # type: ignore[arg-type]
+        rlab = rizer.factorize(rk)  # type: ignore[arg-type]
     assert llab.dtype == np.dtype(np.intp), llab.dtype
     assert rlab.dtype == np.dtype(np.intp), rlab.dtype
 
diff --git a/pandas/tests/reshape/merge/test_merge.py b/pandas/tests/reshape/merge/test_merge.py
index e4638c43e5a66..946e7e48148b4 100644
--- a/pandas/tests/reshape/merge/test_merge.py
+++ b/pandas/tests/reshape/merge/test_merge.py
@@ -2714,3 +2714,27 @@ def test_merge_different_index_names():
     result = merge(left, right, left_on="c", right_on="d")
     expected = DataFrame({"a_x": [1], "a_y": 1})
     tm.assert_frame_equal(result, expected)
+
+
+def test_merge_ea(any_numeric_ea_dtype, join_type):
+    # GH#44240
+    left = DataFrame({"a": [1, 2, 3], "b": 1}, dtype=any_numeric_ea_dtype)
+    right = DataFrame({"a": [1, 2, 3], "c": 2}, dtype=any_numeric_ea_dtype)
+    result = left.merge(right, how=join_type)
+    expected = DataFrame({"a": [1, 2, 3], "b": 1, "c": 2}, dtype=any_numeric_ea_dtype)
+    tm.assert_frame_equal(result, expected)
+
+
+def test_merge_ea_and_non_ea(any_numeric_ea_dtype, join_type):
+    # GH#44240
+    left = DataFrame({"a": [1, 2, 3], "b": 1}, dtype=any_numeric_ea_dtype)
+    right = DataFrame({"a": [1, 2, 3], "c": 2}, dtype=any_numeric_ea_dtype.lower())
+    result = left.merge(right, how=join_type)
+    expected = DataFrame(
+        {
+            "a": Series([1, 2, 3], dtype=any_numeric_ea_dtype),
+            "b": Series([1, 1, 1], dtype=any_numeric_ea_dtype),
+            "c": Series([2, 2, 2], dtype=any_numeric_ea_dtype.lower()),
+        }
+    )
+    tm.assert_frame_equal(result, expected)