Skip to content

Commit 1fd31bd

Browse files
authored
REF: merge_asof dont use values_for_argsort (#45475)
1 parent 011f116 commit 1fd31bd

File tree

2 files changed

+69
-8
lines changed

2 files changed

+69
-8
lines changed

pandas/core/reshape/merge.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from pandas.core import groupby
7777
import pandas.core.algorithms as algos
7878
from pandas.core.arrays import ExtensionArray
79+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
7980
import pandas.core.common as com
8081
from pandas.core.construction import extract_array
8182
from pandas.core.frame import _merge_doc
@@ -1906,14 +1907,27 @@ def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]
19061907

19071908
def flip(xs) -> np.ndarray:
19081909
"""unlike np.transpose, this returns an array of tuples"""
1909-
# error: Item "ndarray" of "Union[Any, Union[ExtensionArray, ndarray]]" has
1910-
# no attribute "_values_for_argsort"
1911-
xs = [
1912-
x
1913-
if not is_extension_array_dtype(x)
1914-
else extract_array(x)._values_for_argsort() # type: ignore[union-attr]
1915-
for x in xs
1916-
]
1910+
1911+
def injection(obj):
1912+
if not is_extension_array_dtype(obj):
1913+
# ndarray
1914+
return obj
1915+
obj = extract_array(obj)
1916+
if isinstance(obj, NDArrayBackedExtensionArray):
1917+
# fastpath for e.g. dt64tz, categorical
1918+
return obj._ndarray
1919+
# FIXME: returning obj._values_for_argsort() here doesn't
1920+
# break in any existing test cases, but i (@jbrockmendel)
1921+
# am pretty sure it should!
1922+
# e.g.
1923+
# arr = pd.array([0, pd.NA, 255], dtype="UInt8")
1924+
# will have values_for_argsort (before GH#45434)
1925+
# np.array([0, 255, 255], dtype=np.uint8)
1926+
# and the non-injectivity should make a difference somehow
1927+
# shouldn't it?
1928+
return np.asarray(obj)
1929+
1930+
xs = [injection(x) for x in xs]
19171931
labels = list(string.ascii_lowercase[: len(xs)])
19181932
dtypes = [x.dtype for x in xs]
19191933
labeled_dtypes = list(zip(labels, dtypes))
@@ -1966,6 +1980,8 @@ def flip(xs) -> np.ndarray:
19661980
left_by_values = left_by_values[0]
19671981
right_by_values = right_by_values[0]
19681982
else:
1983+
# We get here with non-ndarrays in test_merge_by_col_tz_aware
1984+
# and test_merge_groupby_multiple_column_with_categorical_column
19691985
left_by_values = flip(left_by_values)
19701986
right_by_values = flip(right_by_values)
19711987

pandas/tests/reshape/merge/test_merge_asof.py

+45
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,50 @@ def test_merge_on_nans(self, func, side):
12241224
else:
12251225
merge_asof(df, df_null, on="a")
12261226

1227+
def test_by_nullable(self, any_numeric_ea_dtype):
1228+
# Note: this test passes if instead of using pd.array we use
1229+
# np.array([np.nan, 1]). Other than that, I (@jbrockmendel)
1230+
# have NO IDEA what the expected behavior is.
1231+
# TODO(GH#32306): may be relevant to the expected behavior here.
1232+
1233+
arr = pd.array([pd.NA, 0, 1], dtype=any_numeric_ea_dtype)
1234+
if arr.dtype.kind in ["i", "u"]:
1235+
max_val = np.iinfo(arr.dtype.numpy_dtype).max
1236+
else:
1237+
max_val = np.finfo(arr.dtype.numpy_dtype).max
1238+
# set value s.t. (at least for integer dtypes) arr._values_for_argsort
1239+
# is not an injection
1240+
arr[2] = max_val
1241+
1242+
left = pd.DataFrame(
1243+
{
1244+
"by_col1": arr,
1245+
"by_col2": ["HELLO", "To", "You"],
1246+
"on_col": [2, 4, 6],
1247+
"value": ["a", "c", "e"],
1248+
}
1249+
)
1250+
right = pd.DataFrame(
1251+
{
1252+
"by_col1": arr,
1253+
"by_col2": ["WORLD", "Wide", "Web"],
1254+
"on_col": [1, 2, 6],
1255+
"value": ["b", "d", "f"],
1256+
}
1257+
)
1258+
1259+
result = merge_asof(left, right, by=["by_col1", "by_col2"], on="on_col")
1260+
expected = pd.DataFrame(
1261+
{
1262+
"by_col1": arr,
1263+
"by_col2": ["HELLO", "To", "You"],
1264+
"on_col": [2, 4, 6],
1265+
"value_x": ["a", "c", "e"],
1266+
}
1267+
)
1268+
expected["value_y"] = np.array([np.nan, np.nan, np.nan], dtype=object)
1269+
tm.assert_frame_equal(result, expected)
1270+
12271271
def test_merge_by_col_tz_aware(self):
12281272
# GH 21184
12291273
left = pd.DataFrame(
@@ -1309,6 +1353,7 @@ def test_timedelta_tolerance_nearest(self):
13091353

13101354
tm.assert_frame_equal(result, expected)
13111355

1356+
# TODO: any_int_dtype; causes failures in _get_join_indexers
13121357
def test_int_type_tolerance(self, any_int_numpy_dtype):
13131358
# GH #28870
13141359

0 commit comments

Comments
 (0)