28
28
ABCDatetimeArray ,
29
29
ABCExtensionArray ,
30
30
ABCIndex ,
31
- ABCIndexClass ,
32
31
ABCSeries ,
33
32
ABCTimedeltaArray ,
34
33
)
@@ -53,13 +52,15 @@ def comp_method_OBJECT_ARRAY(op, x, y):
53
52
if isinstance (y , (ABCSeries , ABCIndex )):
54
53
y = y .values
55
54
56
- result = libops .vec_compare (x .ravel (), y , op )
55
+ if x .shape != y .shape :
56
+ raise ValueError ("Shapes must match" , x .shape , y .shape )
57
+ result = libops .vec_compare (x .ravel (), y .ravel (), op )
57
58
else :
58
59
result = libops .scalar_compare (x .ravel (), y , op )
59
60
return result .reshape (x .shape )
60
61
61
62
62
- def masked_arith_op (x , y , op ):
63
+ def masked_arith_op (x : np . ndarray , y , op ):
63
64
"""
64
65
If the given arithmetic operation fails, attempt it again on
65
66
only the non-null elements of the input array(s).
@@ -78,10 +79,22 @@ def masked_arith_op(x, y, op):
78
79
dtype = find_common_type ([x .dtype , y .dtype ])
79
80
result = np .empty (x .size , dtype = dtype )
80
81
82
+ if len (x ) != len (y ):
83
+ if not _can_broadcast (x , y ):
84
+ raise ValueError (x .shape , y .shape )
85
+
86
+ # Call notna on pre-broadcasted y for performance
87
+ ymask = notna (y )
88
+ y = np .broadcast_to (y , x .shape )
89
+ ymask = np .broadcast_to (ymask , x .shape )
90
+
91
+ else :
92
+ ymask = notna (y )
93
+
81
94
# NB: ravel() is only safe since y is ndarray; for e.g. PeriodIndex
82
95
# we would get int64 dtype, see GH#19956
83
96
yrav = y .ravel ()
84
- mask = notna (xrav ) & notna ( yrav )
97
+ mask = notna (xrav ) & ymask . ravel ( )
85
98
86
99
if yrav .shape != mask .shape :
87
100
# FIXME: GH#5284, GH#5035, GH#19448
@@ -211,6 +224,51 @@ def arithmetic_op(left: ArrayLike, right: Any, op, str_rep: str):
211
224
return res_values
212
225
213
226
227
+ def _broadcast_comparison_op (lvalues , rvalues , op ) -> np .ndarray :
228
+ """
229
+ Broadcast a comparison operation between two 2D arrays.
230
+
231
+ Parameters
232
+ ----------
233
+ lvalues : np.ndarray or ExtensionArray
234
+ rvalues : np.ndarray or ExtensionArray
235
+
236
+ Returns
237
+ -------
238
+ np.ndarray[bool]
239
+ """
240
+ if isinstance (rvalues , np .ndarray ):
241
+ rvalues = np .broadcast_to (rvalues , lvalues .shape )
242
+ result = comparison_op (lvalues , rvalues , op )
243
+ else :
244
+ result = np .empty (lvalues .shape , dtype = bool )
245
+ for i in range (len (lvalues )):
246
+ result [i , :] = comparison_op (lvalues [i ], rvalues [:, 0 ], op )
247
+ return result
248
+
249
+
250
+ def _can_broadcast (lvalues , rvalues ) -> bool :
251
+ """
252
+ Check if we can broadcast rvalues to match the shape of lvalues.
253
+
254
+ Parameters
255
+ ----------
256
+ lvalues : np.ndarray or ExtensionArray
257
+ rvalues : np.ndarray or ExtensionArray
258
+
259
+ Returns
260
+ -------
261
+ bool
262
+ """
263
+ # We assume that lengths dont match
264
+ if lvalues .ndim == rvalues .ndim == 2 :
265
+ # See if we can broadcast unambiguously
266
+ if lvalues .shape [1 ] == rvalues .shape [- 1 ]:
267
+ if rvalues .shape [0 ] == 1 :
268
+ return True
269
+ return False
270
+
271
+
214
272
def comparison_op (
215
273
left : ArrayLike , right : Any , op , str_rep : Optional [str ] = None ,
216
274
) -> ArrayLike :
@@ -237,12 +295,16 @@ def comparison_op(
237
295
# TODO: same for tuples?
238
296
rvalues = np .asarray (rvalues )
239
297
240
- if isinstance (rvalues , (np .ndarray , ABCExtensionArray , ABCIndexClass )):
298
+ if isinstance (rvalues , (np .ndarray , ABCExtensionArray )):
241
299
# TODO: make this treatment consistent across ops and classes.
242
300
# We are not catching all listlikes here (e.g. frozenset, tuple)
243
301
# The ambiguous case is object-dtype. See GH#27803
244
302
if len (lvalues ) != len (rvalues ):
245
- raise ValueError ("Lengths must match to compare" )
303
+ if _can_broadcast (lvalues , rvalues ):
304
+ return _broadcast_comparison_op (lvalues , rvalues , op )
305
+ raise ValueError (
306
+ "Lengths must match to compare" , lvalues .shape , rvalues .shape
307
+ )
246
308
247
309
if should_extension_dispatch (lvalues , rvalues ):
248
310
res_values = dispatch_to_extension_op (op , lvalues , rvalues )
0 commit comments