Skip to content

Commit 857bf37

Browse files
authored
ENH: Add fast array equal function for indexers (#50592)
1 parent fa78ea8 commit 857bf37

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

pandas/_libs/lib.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,6 @@ def get_reverse_indexer(
240240
) -> npt.NDArray[np.intp]: ...
241241
def is_bool_list(obj: list) -> bool: ...
242242
def dtypes_all_equal(types: list[DtypeObj]) -> bool: ...
243+
# Element-by-element equality check for two 1-d integer arrays.
def array_equal_fast(
244+
left: np.ndarray, right: np.ndarray # np.ndarray[np.int64 | np.int32, ndim=1]; both arguments share one dtype
245+
) -> bool: ...

pandas/_libs/lib.pyx

+29
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ from numpy cimport (
5050
complex128_t,
5151
flatiter,
5252
float64_t,
53+
int32_t,
5354
int64_t,
5455
intp_t,
5556
ndarray,
@@ -642,6 +643,34 @@ def array_equivalent_object(ndarray left, ndarray right) -> bool:
642643
return True
643644

644645

646+
# Integer element types accepted by array_equal_fast.
ctypedef fused int6432_t:
    int64_t
    int32_t


@cython.boundscheck(False)
@cython.wraparound(False)
def array_equal_fast(
    ndarray[int6432_t, ndim=1] left, ndarray[int6432_t, ndim=1] right,
) -> bool:
    """
    Check whether two 1-d integer arrays hold the same values in the same order.

    Meant for indexer comparisons.

    Parameters
    ----------
    left : ndarray[int64 or int32, ndim=1]
    right : ndarray[int64 or int32, ndim=1]
        Shares the fused element type with ``left``.

    Returns
    -------
    bool
        False when the sizes differ or any position holds different values,
        True otherwise.
    """
    cdef:
        Py_ssize_t length = left.size
        Py_ssize_t pos

    # Arrays of different size can never be equal; skip the scan entirely.
    if right.size != length:
        return False

    # Stop at the first mismatching position.
    for pos in range(length):
        if left[pos] != right[pos]:
            return False

    return True
672+
673+
645674
ctypedef fused ndarr_object:
646675
ndarray[object, ndim=1]
647676
ndarray[object, ndim=2]

pandas/tests/libs/test_lib.py

+21
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,27 @@ def test_get_reverse_indexer(self):
243243
expected = np.array([4, 2, 3, 6, 7], dtype=np.intp)
244244
tm.assert_numpy_array_equal(result, expected)
245245

246+
@pytest.mark.parametrize("dtype", ["int64", "int32"])
def test_array_equal_fast(self, dtype):
    """Two separately built arrays with identical contents compare equal."""
    # GH#50592
    first, second = (np.arange(1, 100, dtype=dtype) for _ in range(2))
    assert lib.array_equal_fast(first, second)
252+
253+
@pytest.mark.parametrize("dtype", ["int64", "int32"])
def test_array_equal_fast_not_equal(self, dtype):
    """Same-length arrays that differ in one element compare unequal."""
    # GH#50592
    first = np.array([1, 2], dtype=dtype)
    second = np.array([2, 2], dtype=dtype)
    outcome = lib.array_equal_fast(first, second)
    assert not outcome
259+
260+
@pytest.mark.parametrize("dtype", ["int64", "int32"])
def test_array_equal_fast_not_equal_shape(self, dtype):
    """Arrays of different length compare unequal."""
    # GH#50592
    longer = np.array([1, 2, 3], dtype=dtype)
    shorter = np.array([2, 2], dtype=dtype)
    outcome = lib.array_equal_fast(longer, shorter)
    assert not outcome
266+
246267

247268
def test_cache_readonly_preserve_docstrings():
248269
# GH18197

0 commit comments

Comments
 (0)