|
2 | 2 | Functions for arithmetic and comparison operations on NumPy arrays and
|
3 | 3 | ExtensionArrays.
|
4 | 4 | """
|
| 5 | +from functools import partial |
5 | 6 | import operator
|
6 |
| -from typing import Any, Union |
| 7 | +from typing import Any, Optional, Union |
7 | 8 |
|
8 | 9 | import numpy as np
|
9 | 10 |
|
@@ -51,10 +52,10 @@ def comp_method_OBJECT_ARRAY(op, x, y):
|
51 | 52 | if isinstance(y, (ABCSeries, ABCIndex)):
|
52 | 53 | y = y.values
|
53 | 54 |
|
54 |
| - result = libops.vec_compare(x, y, op) |
| 55 | + result = libops.vec_compare(x.ravel(), y, op) |
55 | 56 | else:
|
56 |
| - result = libops.scalar_compare(x, y, op) |
57 |
| - return result |
| 57 | + result = libops.scalar_compare(x.ravel(), y, op) |
| 58 | + return result.reshape(x.shape) |
58 | 59 |
|
59 | 60 |
|
60 | 61 | def masked_arith_op(x, y, op):
|
@@ -237,9 +238,9 @@ def comparison_op(
|
237 | 238 | elif is_scalar(rvalues) and isna(rvalues):
|
238 | 239 | # numpy does not like comparisons vs None
|
239 | 240 | if op is operator.ne:
|
240 |
| - res_values = np.ones(len(lvalues), dtype=bool) |
| 241 | + res_values = np.ones(lvalues.shape, dtype=bool) |
241 | 242 | else:
|
242 |
| - res_values = np.zeros(len(lvalues), dtype=bool) |
| 243 | + res_values = np.zeros(lvalues.shape, dtype=bool) |
243 | 244 |
|
244 | 245 | elif is_object_dtype(lvalues.dtype):
|
245 | 246 | res_values = comp_method_OBJECT_ARRAY(op, lvalues, rvalues)
|
@@ -367,3 +368,27 @@ def fill_bool(x, left=None):
|
367 | 368 | res_values = filler(res_values) # type: ignore
|
368 | 369 |
|
369 | 370 | return res_values
|
| 371 | + |
| 372 | + |
| 373 | +def get_array_op(op, str_rep: Optional[str] = None): |
| 374 | + """ |
| 375 | + Return a binary array operation corresponding to the given operator op. |
| 376 | +
|
| 377 | + Parameters |
| 378 | + ---------- |
| 379 | + op : function |
| 380 | + Binary operator from operator or roperator module. |
| 381 | + str_rep : str or None, default None |
| 382 | + str_rep to pass to arithmetic_op |
| 383 | +
|
| 384 | + Returns |
| 385 | + ------- |
| 386 | + function |
| 387 | + """ |
| 388 | + op_name = op.__name__.strip("_") |
| 389 | + if op_name in {"eq", "ne", "lt", "le", "gt", "ge"}: |
| 390 | + return partial(comparison_op, op=op) |
| 391 | + elif op_name in {"and", "or", "xor", "rand", "ror", "rxor"}: |
| 392 | + return partial(logical_op, op=op) |
| 393 | + else: |
| 394 | + return partial(arithmetic_op, op=op, str_rep=str_rep) |
0 commit comments