Skip to content

Commit dcd17bf

Browse files
jorisvandenbosschewesm
authored andcommitted
ARROW-9445: [Python] Revert Array.equals changes + expose comparison ops in compute
Closes #7737 from jorisvandenbossche/ARROW-9445 Authored-by: Joris Van den Bossche <[email protected]> Signed-off-by: Wes McKinney <[email protected]>
1 parent 389b153 commit dcd17bf

File tree

7 files changed

+54
-47
lines changed

7 files changed

+54
-47
lines changed

python/pyarrow/array.pxi

+6-25
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,6 @@
1717

1818
import warnings
1919

20-
from cpython.object cimport Py_LT, Py_EQ, Py_GT, Py_LE, Py_NE, Py_GE
21-
22-
23-
cdef str _op_to_function_name(int op):
24-
cdef str function_name
25-
26-
if op == Py_EQ:
27-
function_name = "equal"
28-
elif op == Py_NE:
29-
function_name = "not_equal"
30-
elif op == Py_GT:
31-
function_name = "greater"
32-
elif op == Py_GE:
33-
function_name = "greater_equal"
34-
elif op == Py_LT:
35-
function_name = "less"
36-
elif op == Py_LE:
37-
function_name = "less_equal"
38-
39-
return function_name
40-
4120

4221
cdef _sequence_to_array(object sequence, object mask, object size,
4322
DataType type, CMemoryPool* pool, c_bool from_pandas):
@@ -773,10 +752,6 @@ cdef class Array(_PandasConvertible):
773752
with nogil:
774753
check_status(DebugPrint(deref(self.ap), 0))
775754

776-
def __richcmp__(self, other, int op):
777-
function_name = _op_to_function_name(op)
778-
return _pc().call_function(function_name, [self, other])
779-
780755
def diff(self, Array other):
781756
"""
782757
Compare contents of this array against another one.
@@ -999,6 +974,12 @@ cdef class Array(_PandasConvertible):
999974
def __str__(self):
1000975
return self.to_string()
1001976

977+
def __eq__(self, other):
978+
try:
979+
return self.equals(other)
980+
except TypeError:
981+
return NotImplemented
982+
1002983
def equals(Array self, Array other):
1003984
return self.ap.Equals(deref(other.ap))
1004985

python/pyarrow/compute.py

+7
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,13 @@ def func(left, right):
135135
subtract = _simple_binary_function('subtract')
136136
multiply = _simple_binary_function('multiply')
137137

138+
equal = _simple_binary_function('equal')
139+
not_equal = _simple_binary_function('not_equal')
140+
greater = _simple_binary_function('greater')
141+
greater_equal = _simple_binary_function('greater_equal')
142+
less = _simple_binary_function('less')
143+
less_equal = _simple_binary_function('less_equal')
144+
138145

139146
def binary_contains_exact(array, pattern):
140147
"""

python/pyarrow/table.pxi

+6-4
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,6 @@ cdef class ChunkedArray(_PandasConvertible):
4141
def __reduce__(self):
4242
return chunked_array, (self.chunks, self.type)
4343

44-
def __richcmp__(self, other, int op):
45-
function_name = _op_to_function_name(op)
46-
return _pc().call_function(function_name, [self, other])
47-
4844
@property
4945
def data(self):
5046
import warnings
@@ -189,6 +185,12 @@ cdef class ChunkedArray(_PandasConvertible):
189185
"""
190186
return _pc().is_valid(self)
191187

188+
def __eq__(self, other):
189+
try:
190+
return self.equals(other)
191+
except TypeError:
192+
return NotImplemented
193+
192194
def equals(self, ChunkedArray other):
193195
"""
194196
Return whether the contents of two chunked arrays are equal.

python/pyarrow/tests/test_array.py

+13
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,19 @@ def test_array_ref_to_ndarray_base():
484484
assert sys.getrefcount(arr) == (refcount + 1)
485485

486486

487+
def test_array_eq():
488+
# ARROW-2150 / ARROW-9445: we define the __eq__ behavior to be
489+
# data equality (not element-wise equality)
490+
arr1 = pa.array([1, 2, 3], type=pa.int32())
491+
arr2 = pa.array([1, 2, 3], type=pa.int32())
492+
arr3 = pa.array([1, 2, 3], type=pa.int64())
493+
494+
assert (arr1 == arr2) is True
495+
assert (arr1 != arr2) is False
496+
assert (arr1 == arr3) is False
497+
assert (arr1 != arr3) is True
498+
499+
487500
def test_array_from_buffers():
488501
values_buf = pa.py_buffer(np.int16([4, 5, 6, 7]))
489502
nulls_buf = pa.py_buffer(np.uint8([0b00001101]))

python/pyarrow/tests/test_compute.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -376,22 +376,22 @@ def con(values): return pa.chunked_array([values])
376376
arr1 = con([1, 2, 3, 4, None])
377377
arr2 = con([1, 1, 4, None, 4])
378378

379-
result = arr1 == arr2
379+
result = pc.equal(arr1, arr2)
380380
assert result.equals(con([True, False, False, None, None]))
381381

382-
result = arr1 != arr2
382+
result = pc.not_equal(arr1, arr2)
383383
assert result.equals(con([False, True, True, None, None]))
384384

385-
result = arr1 < arr2
385+
result = pc.less(arr1, arr2)
386386
assert result.equals(con([False, False, True, None, None]))
387387

388-
result = arr1 <= arr2
388+
result = pc.less_equal(arr1, arr2)
389389
assert result.equals(con([True, False, True, None, None]))
390390

391-
result = arr1 > arr2
391+
result = pc.greater(arr1, arr2)
392392
assert result.equals(con([False, True, False, None, None]))
393393

394-
result = arr1 >= arr2
394+
result = pc.greater_equal(arr1, arr2)
395395
assert result.equals(con([True, True, False, None, None]))
396396

397397

@@ -406,22 +406,22 @@ def con(values): return pa.chunked_array([values])
406406
# TODO this is a hacky way to construct a scalar ..
407407
scalar = pa.array([2]).sum()
408408

409-
result = arr == scalar
409+
result = pc.equal(arr, scalar)
410410
assert result.equals(con([False, True, False, None]))
411411

412-
result = arr != scalar
412+
result = pc.not_equal(arr, scalar)
413413
assert result.equals(con([True, False, True, None]))
414414

415-
result = arr < scalar
415+
result = pc.less(arr, scalar)
416416
assert result.equals(con([True, False, False, None]))
417417

418-
result = arr <= scalar
418+
result = pc.less_equal(arr, scalar)
419419
assert result.equals(con([True, True, False, None]))
420420

421-
result = arr > scalar
421+
result = pc.greater(arr, scalar)
422422
assert result.equals(con([False, False, True, None]))
423423

424-
result = arr >= scalar
424+
result = pc.greater_equal(arr, scalar)
425425
assert result.equals(con([False, True, True, None]))
426426

427427

@@ -432,11 +432,12 @@ def test_compare_chunked_array_mixed():
432432

433433
expected = pa.chunked_array([[True, True, True, True, None]])
434434

435-
for result in [
436-
arr == arr_chunked,
437-
arr_chunked == arr,
438-
arr_chunked == arr_chunked2,
435+
for left, right in [
436+
(arr, arr_chunked),
437+
(arr_chunked, arr),
438+
(arr_chunked, arr_chunked2),
439439
]:
440+
result = pc.equal(left, right)
440441
assert result.equals(expected)
441442

442443

python/pyarrow/tests/test_scalars.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def test_list(ty, klass):
346346
assert s.type == ty
347347
assert len(s) == 2
348348
assert isinstance(s.values, pa.Array)
349-
assert s.values == v
349+
assert s.values.to_pylist() == v
350350
assert isinstance(s, klass)
351351
assert repr(v) in repr(s)
352352
assert s.as_py() == v
@@ -496,7 +496,7 @@ def test_dictionary():
496496
assert s.as_py() == v
497497
assert s.value.as_py() == v
498498
assert s.index.as_py() == i
499-
assert s.dictionary == dictionary
499+
assert s.dictionary.to_pylist() == dictionary
500500

501501
with pytest.warns(FutureWarning):
502502
assert s.index_value.as_py() == i

python/pyarrow/tests/test_table.py

+3
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ def eq(xarrs, yarrs):
128128
y = pa.chunked_array(yarrs)
129129
assert x.equals(y)
130130
assert y.equals(x)
131+
assert x == y
132+
assert x != str(y)
131133

132134
def ne(xarrs, yarrs):
133135
if isinstance(xarrs, pa.ChunkedArray):
@@ -140,6 +142,7 @@ def ne(xarrs, yarrs):
140142
y = pa.chunked_array(yarrs)
141143
assert not x.equals(y)
142144
assert not y.equals(x)
145+
assert x != y
143146

144147
eq(pa.chunked_array([], type=pa.int32()),
145148
pa.chunked_array([], type=pa.int32()))

0 commit comments

Comments
 (0)