Skip to content

Commit 2f4c93e

Browse files
authored
BUG/PERF: merge_asof raising TypeError for various "by" column dtypes (#55678)
* factorize by keys for merge_asof * use intp_t * add test * whatsnew * ensure int64
1 parent de50590 commit 2f4c93e

File tree

5 files changed

+46
-85
lines changed

5 files changed

+46
-85
lines changed

doc/source/whatsnew/v2.2.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ Other Deprecations
296296
Performance improvements
297297
~~~~~~~~~~~~~~~~~~~~~~~~
298298
- Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`)
299-
- Performance improvement in :func:`merge_asof` when ``by`` contains more than one key (:issue:`55580`)
299+
- Performance improvement in :func:`merge_asof` when ``by`` is not ``None`` (:issue:`55580`, :issue:`55678`)
300300
- Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`)
301301
- Performance improvement in :func:`to_dict` on converting DataFrame to dictionary (:issue:`50990`)
302302
- Performance improvement in :meth:`DataFrame.groupby` when aggregating pyarrow timestamp and duration dtypes (:issue:`55031`)
@@ -411,6 +411,7 @@ Groupby/resample/rolling
411411
Reshaping
412412
^^^^^^^^^
413413
- Bug in :func:`concat` ignoring ``sort`` parameter when passed :class:`DatetimeIndex` indexes (:issue:`54769`)
414+
- Bug in :func:`merge_asof` raising ``TypeError`` when ``by`` dtype is not ``object``, ``int64``, or ``uint64`` (:issue:`22794`)
414415
- Bug in :func:`merge` returning columns in incorrect order when left and/or right is empty (:issue:`51929`)
415416
- Bug in :meth:`pandas.DataFrame.melt` where it would not preserve the datetime (:issue:`55254`)
416417
-

pandas/_libs/join.pyi

+6-6
Original file line numberDiff line numberDiff line change
@@ -53,26 +53,26 @@ def outer_join_indexer(
5353
def asof_join_backward_on_X_by_Y(
5454
left_values: np.ndarray, # ndarray[numeric_t]
5555
right_values: np.ndarray, # ndarray[numeric_t]
56-
left_by_values: np.ndarray, # ndarray[by_t]
57-
right_by_values: np.ndarray, # ndarray[by_t]
56+
left_by_values: np.ndarray, # const int64_t[:]
57+
right_by_values: np.ndarray, # const int64_t[:]
5858
allow_exact_matches: bool = ...,
5959
tolerance: np.number | float | None = ...,
6060
use_hashtable: bool = ...,
6161
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ...
6262
def asof_join_forward_on_X_by_Y(
6363
left_values: np.ndarray, # ndarray[numeric_t]
6464
right_values: np.ndarray, # ndarray[numeric_t]
65-
left_by_values: np.ndarray, # ndarray[by_t]
66-
right_by_values: np.ndarray, # ndarray[by_t]
65+
left_by_values: np.ndarray, # const int64_t[:]
66+
right_by_values: np.ndarray, # const int64_t[:]
6767
allow_exact_matches: bool = ...,
6868
tolerance: np.number | float | None = ...,
6969
use_hashtable: bool = ...,
7070
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ...
7171
def asof_join_nearest_on_X_by_Y(
7272
left_values: np.ndarray, # ndarray[numeric_t]
7373
right_values: np.ndarray, # ndarray[numeric_t]
74-
left_by_values: np.ndarray, # ndarray[by_t]
75-
right_by_values: np.ndarray, # ndarray[by_t]
74+
left_by_values: np.ndarray, # const int64_t[:]
75+
right_by_values: np.ndarray, # const int64_t[:]
7676
allow_exact_matches: bool = ...,
7777
tolerance: np.number | float | None = ...,
7878
use_hashtable: bool = ...,

pandas/_libs/join.pyx

+11-34
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ from numpy cimport (
77
int64_t,
88
intp_t,
99
ndarray,
10-
uint64_t,
1110
)
1211

1312
cnp.import_array()
@@ -679,23 +678,13 @@ def outer_join_indexer(ndarray[numeric_object_t] left, ndarray[numeric_object_t]
679678
# asof_join_by
680679
# ----------------------------------------------------------------------
681680

682-
from pandas._libs.hashtable cimport (
683-
HashTable,
684-
Int64HashTable,
685-
PyObjectHashTable,
686-
UInt64HashTable,
687-
)
688-
689-
ctypedef fused by_t:
690-
object
691-
int64_t
692-
uint64_t
681+
from pandas._libs.hashtable cimport Int64HashTable
693682

694683

695684
def asof_join_backward_on_X_by_Y(ndarray[numeric_t] left_values,
696685
ndarray[numeric_t] right_values,
697-
ndarray[by_t] left_by_values,
698-
ndarray[by_t] right_by_values,
686+
const int64_t[:] left_by_values,
687+
const int64_t[:] right_by_values,
699688
bint allow_exact_matches=True,
700689
tolerance=None,
701690
bint use_hashtable=True):
@@ -706,8 +695,7 @@ def asof_join_backward_on_X_by_Y(ndarray[numeric_t] left_values,
706695
bint has_tolerance = False
707696
numeric_t tolerance_ = 0
708697
numeric_t diff = 0
709-
HashTable hash_table
710-
by_t by_value
698+
Int64HashTable hash_table
711699

712700
# if we are using tolerance, set our objects
713701
if tolerance is not None:
@@ -721,12 +709,7 @@ def asof_join_backward_on_X_by_Y(ndarray[numeric_t] left_values,
721709
right_indexer = np.empty(left_size, dtype=np.intp)
722710

723711
if use_hashtable:
724-
if by_t is object:
725-
hash_table = PyObjectHashTable(right_size)
726-
elif by_t is int64_t:
727-
hash_table = Int64HashTable(right_size)
728-
elif by_t is uint64_t:
729-
hash_table = UInt64HashTable(right_size)
712+
hash_table = Int64HashTable(right_size)
730713

731714
right_pos = 0
732715
for left_pos in range(left_size):
@@ -771,8 +754,8 @@ def asof_join_backward_on_X_by_Y(ndarray[numeric_t] left_values,
771754

772755
def asof_join_forward_on_X_by_Y(ndarray[numeric_t] left_values,
773756
ndarray[numeric_t] right_values,
774-
ndarray[by_t] left_by_values,
775-
ndarray[by_t] right_by_values,
757+
const int64_t[:] left_by_values,
758+
const int64_t[:] right_by_values,
776759
bint allow_exact_matches=1,
777760
tolerance=None,
778761
bint use_hashtable=True):
@@ -783,8 +766,7 @@ def asof_join_forward_on_X_by_Y(ndarray[numeric_t] left_values,
783766
bint has_tolerance = False
784767
numeric_t tolerance_ = 0
785768
numeric_t diff = 0
786-
HashTable hash_table
787-
by_t by_value
769+
Int64HashTable hash_table
788770

789771
# if we are using tolerance, set our objects
790772
if tolerance is not None:
@@ -798,12 +780,7 @@ def asof_join_forward_on_X_by_Y(ndarray[numeric_t] left_values,
798780
right_indexer = np.empty(left_size, dtype=np.intp)
799781

800782
if use_hashtable:
801-
if by_t is object:
802-
hash_table = PyObjectHashTable(right_size)
803-
elif by_t is int64_t:
804-
hash_table = Int64HashTable(right_size)
805-
elif by_t is uint64_t:
806-
hash_table = UInt64HashTable(right_size)
783+
hash_table = Int64HashTable(right_size)
807784

808785
right_pos = right_size - 1
809786
for left_pos in range(left_size - 1, -1, -1):
@@ -849,8 +826,8 @@ def asof_join_forward_on_X_by_Y(ndarray[numeric_t] left_values,
849826

850827
def asof_join_nearest_on_X_by_Y(ndarray[numeric_t] left_values,
851828
ndarray[numeric_t] right_values,
852-
ndarray[by_t] left_by_values,
853-
ndarray[by_t] right_by_values,
829+
const int64_t[:] left_by_values,
830+
const int64_t[:] right_by_values,
854831
bint allow_exact_matches=True,
855832
tolerance=None,
856833
bint use_hashtable=True):

pandas/core/reshape/merge.py

+24-41
Original file line numberDiff line numberDiff line change
@@ -2153,54 +2153,37 @@ def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]
21532153
if self.left_by is not None:
21542154
# remove 'on' parameter from values if one existed
21552155
if self.left_index and self.right_index:
2156-
left_by_values = self.left_join_keys
2157-
right_by_values = self.right_join_keys
2156+
left_join_keys = self.left_join_keys
2157+
right_join_keys = self.right_join_keys
21582158
else:
2159-
left_by_values = self.left_join_keys[0:-1]
2160-
right_by_values = self.right_join_keys[0:-1]
2161-
2162-
# get tuple representation of values if more than one
2163-
if len(left_by_values) == 1:
2164-
lbv = left_by_values[0]
2165-
rbv = right_by_values[0]
2166-
2167-
# TODO: conversions for EAs that can be no-copy.
2168-
lbv = np.asarray(lbv)
2169-
rbv = np.asarray(rbv)
2170-
if needs_i8_conversion(lbv.dtype):
2171-
lbv = lbv.view("i8")
2172-
if needs_i8_conversion(rbv.dtype):
2173-
rbv = rbv.view("i8")
2159+
left_join_keys = self.left_join_keys[0:-1]
2160+
right_join_keys = self.right_join_keys[0:-1]
2161+
2162+
mapped = [
2163+
_factorize_keys(
2164+
left_join_keys[n],
2165+
right_join_keys[n],
2166+
sort=False,
2167+
how="left",
2168+
)
2169+
for n in range(len(left_join_keys))
2170+
]
2171+
2172+
if len(left_join_keys) == 1:
2173+
left_by_values = mapped[0][0]
2174+
right_by_values = mapped[0][1]
21742175
else:
2175-
# We get here with non-ndarrays in test_merge_by_col_tz_aware
2176-
# and test_merge_groupby_multiple_column_with_categorical_column
2177-
mapped = [
2178-
_factorize_keys(
2179-
left_by_values[n],
2180-
right_by_values[n],
2181-
sort=False,
2182-
how="left",
2183-
)
2184-
for n in range(len(left_by_values))
2185-
]
21862176
arrs = [np.concatenate(m[:2]) for m in mapped]
21872177
shape = tuple(m[2] for m in mapped)
21882178
group_index = get_group_index(
21892179
arrs, shape=shape, sort=False, xnull=False
21902180
)
2191-
left_len = len(left_by_values[0])
2192-
lbv = group_index[:left_len]
2193-
rbv = group_index[left_len:]
2194-
# error: Incompatible types in assignment (expression has type
2195-
# "Union[ndarray[Any, dtype[Any]], ndarray[Any, dtype[object_]]]",
2196-
# variable has type "List[Union[Union[ExtensionArray,
2197-
# ndarray[Any, Any]], Index, Series]]")
2198-
right_by_values = rbv # type: ignore[assignment]
2199-
# error: Incompatible types in assignment (expression has type
2200-
# "Union[ndarray[Any, dtype[Any]], ndarray[Any, dtype[object_]]]",
2201-
# variable has type "List[Union[Union[ExtensionArray,
2202-
# ndarray[Any, Any]], Index, Series]]")
2203-
left_by_values = lbv # type: ignore[assignment]
2181+
left_len = len(left_join_keys[0])
2182+
left_by_values = group_index[:left_len]
2183+
right_by_values = group_index[left_len:]
2184+
2185+
left_by_values = ensure_int64(left_by_values)
2186+
right_by_values = ensure_int64(right_by_values)
22042187

22052188
# choose appropriate function by type
22062189
func = _asof_by_function(self.direction)

pandas/tests/reshape/merge/test_merge_asof.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1345,9 +1345,9 @@ def test_by_mixed_tz_aware(self):
13451345
expected["value_y"] = np.array([np.nan], dtype=object)
13461346
tm.assert_frame_equal(result, expected)
13471347

1348-
@pytest.mark.parametrize("dtype", ["m8[ns]", "M8[us]"])
1349-
def test_by_datelike(self, dtype):
1350-
# GH 55453
1348+
@pytest.mark.parametrize("dtype", ["float64", "int16", "m8[ns]", "M8[us]"])
1349+
def test_by_dtype(self, dtype):
1350+
# GH 55453, GH 22794
13511351
left = pd.DataFrame(
13521352
{
13531353
"by_col": np.array([1], dtype=dtype),

0 commit comments

Comments
 (0)