diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index 0122c84ba2a8e..4fda2cd11ce12 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -20,7 +20,7 @@ Fixed regressions Bug fixes ~~~~~~~~~ -- +- Fixed bug in :func:`merge` when merging with ``ArrowDtype`` one one and a NumPy dtype on the other side (:issue:`52406`) .. --------------------------------------------------------------------------- .. _whatsnew_201.other: diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index bfaf403491801..6b8e72c503de3 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -80,6 +80,7 @@ ) from pandas import ( + ArrowDtype, Categorical, Index, MultiIndex, @@ -87,6 +88,7 @@ ) import pandas.core.algorithms as algos from pandas.core.arrays import ( + ArrowExtensionArray, BaseMaskedArray, ExtensionArray, ) @@ -2377,7 +2379,11 @@ def _factorize_keys( rk = ensure_int64(rk.codes) elif isinstance(lk, ExtensionArray) and is_dtype_equal(lk.dtype, rk.dtype): - if not isinstance(lk, BaseMaskedArray): + if not isinstance(lk, BaseMaskedArray) and not ( + # exclude arrow dtypes that would get cast to object + isinstance(lk.dtype, ArrowDtype) + and is_numeric_dtype(lk.dtype.numpy_dtype) + ): lk, _ = lk._values_for_factorize() # error: Item "ndarray" of "Union[Any, ndarray]" has no attribute @@ -2392,6 +2398,16 @@ def _factorize_keys( assert isinstance(rk, BaseMaskedArray) llab = rizer.factorize(lk._data, mask=lk._mask) rlab = rizer.factorize(rk._data, mask=rk._mask) + elif isinstance(lk, ArrowExtensionArray): + assert isinstance(rk, ArrowExtensionArray) + # we can only get here with numeric dtypes + # TODO: Remove when we have a Factorizer for Arrow + llab = rizer.factorize( + lk.to_numpy(na_value=1, dtype=lk.dtype.numpy_dtype), mask=lk.isna() + ) + rlab = rizer.factorize( + rk.to_numpy(na_value=1, dtype=lk.dtype.numpy_dtype), mask=rk.isna() + ) else: # Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type # "Union[ndarray[Any, dtype[signedinteger[_64Bit]]], @@ -2450,6 +2466,8 @@ def _convert_arrays_and_get_rizer_klass( # Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]"; # expected type "Type[object]" klass = _factorizers[lk.dtype.type] # type: ignore[index] + elif isinstance(lk.dtype, ArrowDtype): + klass = _factorizers[lk.dtype.numpy_dtype.type] else: klass = _factorizers[lk.dtype.type] diff --git a/pandas/tests/reshape/merge/test_merge.py b/pandas/tests/reshape/merge/test_merge.py index 6f2b327c37067..3a822c8134eb4 100644 --- a/pandas/tests/reshape/merge/test_merge.py +++ b/pandas/tests/reshape/merge/test_merge.py @@ -2761,3 +2761,18 @@ def test_merge_ea_and_non_ea(any_numeric_ea_dtype, join_type): } ) tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["int64", "int64[pyarrow]"]) +def test_merge_arrow_and_numpy_dtypes(dtype): + # GH#52406 + pytest.importorskip("pyarrow") + df = DataFrame({"a": [1, 2]}, dtype=dtype) + df2 = DataFrame({"a": [1, 2]}, dtype="int64[pyarrow]") + result = df.merge(df2) + expected = df.copy() + tm.assert_frame_equal(result, expected) + + result = df2.merge(df) + expected = df2.copy() + tm.assert_frame_equal(result, expected)