Skip to content

Commit bf8680f

Browse files
committed
Implement NA.__array_ufunc__
This gives us consistent comparisons with NumPy scalars.
1 parent e28ebe3 commit bf8680f

File tree

4 files changed

+132
-101
lines changed

4 files changed

+132
-101
lines changed

pandas/_libs/missing.pyx

+115-1
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ def _create_binary_propagating_op(name, divmod=False):
290290

291291
def method(self, other):
292292
if (other is C_NA or isinstance(other, str)
293-
or isinstance(other, (numbers.Number, np.bool_))):
293+
or isinstance(other, (numbers.Number, np.bool_, np.int64, np.int_))
294+
or isinstance(other, np.ndarray) and not other.shape):
294295
if divmod:
295296
return NA, NA
296297
else:
@@ -310,6 +311,98 @@ def _create_unary_propagating_op(name):
310311
return method
311312

312313

314+
def maybe_dispatch_ufunc_to_dunder_op(self, ufunc, method: str, *inputs, **kwargs):
    """
    Dispatch a ufunc to the equivalent dunder method.

    Parameters
    ----------
    self : ArrayLike
        The array whose dunder method we dispatch to
    ufunc : Callable
        A NumPy ufunc. Only its ``__name__`` is inspected.
    method : {'reduce', 'accumulate', 'reduceat', 'outer', 'at', '__call__'}
        Only ``'__call__'`` is ever dispatched.
    inputs : ArrayLike
        The input arrays.
    kwargs : Any
        The additional keyword arguments, e.g. ``out``.

    Returns
    -------
    result : Any
        The result of calling the matching dunder method on ``self``, or
        ``NotImplemented`` when the ufunc is not one we dispatch, when
        ``method`` is not ``'__call__'``, when ``out`` is given, or when
        ``self`` does not define the dunder method.
    """
    # special has the ufuncs we dispatch to the dunder op on
    special = {
        "add",
        "sub",
        "mul",
        "pow",
        "mod",
        "floordiv",
        "truediv",
        "divmod",
        "eq",
        "ne",
        "lt",
        "gt",
        "le",
        "ge",
        "remainder",
        "matmul",
        "or",
        "xor",
        "and",
    }
    # NumPy ufunc name -> stem of the Python operator dunder
    aliases = {
        "subtract": "sub",
        "multiply": "mul",
        "floor_divide": "floordiv",
        "true_divide": "truediv",
        "power": "pow",
        "remainder": "mod",
        # "div" is not in ``special`` (there is no ``__div__`` in Python 3),
        # so mapping "divide" there would silently skip dispatch; a ufunc
        # named "divide" must go to ``__truediv__`` / ``__rtruediv__``.
        "divide": "truediv",
        "equal": "eq",
        "not_equal": "ne",
        "less": "lt",
        "less_equal": "le",
        "greater": "gt",
        "greater_equal": "ge",
        "bitwise_or": "or",
        "bitwise_and": "and",
        "bitwise_xor": "xor",
    }

    # For op(., Array) -> Array.__r{op}__
    # Comparisons have no reflected dunder; swap to the mirrored comparison.
    flipped = {
        "lt": "__gt__",
        "le": "__ge__",
        "gt": "__lt__",
        "ge": "__le__",
        "eq": "__eq__",
        "ne": "__ne__",
    }

    op_name = ufunc.__name__
    op_name = aliases.get(op_name, op_name)

    def not_implemented(*args, **kwargs):
        # Stand-in used when ``self`` lacks the requested dunder method.
        return NotImplemented

    if method == "__call__" and op_name in special and kwargs.get("out") is None:
        if isinstance(inputs[0], type(self)):
            # ``self``'s type is the left operand: use the forward dunder.
            name = f"__{op_name}__"
            return getattr(self, name, not_implemented)(inputs[1])
        else:
            # ``self``'s type is the right operand: use the reflected dunder.
            name = flipped.get(op_name, f"__r{op_name}__")
            return getattr(self, name, not_implemented)(inputs[0])
    return NotImplemented
404+
405+
313406
cdef class C_NAType:
314407
pass
315408

@@ -434,6 +527,27 @@ class NAType(C_NAType):
434527

435528
__rxor__ = __xor__
436529

530+
# What else to add here? datetime / Timestamp? Period? Interval?
531+
# Note: we only handle 0-d ndarrays.
532+
__array_priority__ = 1000
533+
_HANDLED_TYPES = (np.ndarray, numbers.Number, str, np.bool_, np.int64)
534+
535+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    """
    Handle NumPy ufuncs that are applied to NA.

    Returns ``NotImplemented`` if any input is not an instance of one of
    ``_HANDLED_TYPES`` or ``NAType``. Raises ``ValueError`` for any ufunc
    method other than ``"__call__"``. Otherwise the call is first offered
    to the matching dunder method of NA; if that is not implemented, NA is
    returned once for each output of the ufunc.
    """
    accepted = self._HANDLED_TYPES + (NAType,)
    if not all(isinstance(arg, accepted) for arg in inputs):
        return NotImplemented

    if method != "__call__":
        raise ValueError(f"ufunc method '{method}' not supported for NA")

    outcome = maybe_dispatch_ufunc_to_dunder_op(
        self, ufunc, method, *inputs, **kwargs
    )
    if outcome is not NotImplemented:
        return outcome

    # No dunder method applies: a bare NA for single-output ufuncs,
    # a tuple of NA for multi-output ones.
    return NA if ufunc.nout == 1 else (NA,) * ufunc.nout
550+
437551

438552
C_NA = NAType() # C-visible
439553
NA = C_NA # Python-visible

pandas/core/ops/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111

1212
from pandas._libs import Timedelta, Timestamp, lib
13+
from pandas._libs.missing import maybe_dispatch_ufunc_to_dunder_op # noqa:F401
1314
from pandas.util._decorators import Appender
1415

1516
from pandas.core.dtypes.common import is_list_like, is_timedelta64_dtype
@@ -30,7 +31,6 @@
3031
)
3132
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY # noqa:F401
3233
from pandas.core.ops.common import unpack_zerodim_and_defer
33-
from pandas.core.ops.dispatch import maybe_dispatch_ufunc_to_dunder_op # noqa:F401
3434
from pandas.core.ops.dispatch import should_series_dispatch
3535
from pandas.core.ops.docstrings import (
3636
_arith_doc_FRAME,

pandas/core/ops/dispatch.py

+1-93
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Functions for defining unary operations.
33
"""
4-
from typing import Any, Callable, Union
4+
from typing import Any, Union
55

66
import numpy as np
77

@@ -17,7 +17,6 @@
1717
)
1818
from pandas.core.dtypes.generic import ABCExtensionArray, ABCSeries
1919

20-
from pandas._typing import ArrayLike
2120
from pandas.core.construction import array
2221

2322

@@ -146,94 +145,3 @@ def dispatch_to_extension_op(
146145
"operation [{name}]".format(name=op.__name__)
147146
)
148147
return res_values
149-
150-
151-
def maybe_dispatch_ufunc_to_dunder_op(
152-
self: ArrayLike, ufunc: Callable, method: str, *inputs: ArrayLike, **kwargs: Any
153-
):
154-
"""
155-
Dispatch a ufunc to the equivalent dunder method.
156-
157-
Parameters
158-
----------
159-
self : ArrayLike
160-
The array whose dunder method we dispatch to
161-
ufunc : Callable
162-
A NumPy ufunc
163-
method : {'reduce', 'accumulate', 'reduceat', 'outer', 'at', '__call__'}
164-
inputs : ArrayLike
165-
The input arrays.
166-
kwargs : Any
167-
The additional keyword arguments, e.g. ``out``.
168-
169-
Returns
170-
-------
171-
result : Any
172-
The result of applying the ufunc
173-
"""
174-
# special has the ufuncs we dispatch to the dunder op on
175-
special = {
176-
"add",
177-
"sub",
178-
"mul",
179-
"pow",
180-
"mod",
181-
"floordiv",
182-
"truediv",
183-
"divmod",
184-
"eq",
185-
"ne",
186-
"lt",
187-
"gt",
188-
"le",
189-
"ge",
190-
"remainder",
191-
"matmul",
192-
"or",
193-
"xor",
194-
"and",
195-
}
196-
aliases = {
197-
"subtract": "sub",
198-
"multiply": "mul",
199-
"floor_divide": "floordiv",
200-
"true_divide": "truediv",
201-
"power": "pow",
202-
"remainder": "mod",
203-
"divide": "div",
204-
"equal": "eq",
205-
"not_equal": "ne",
206-
"less": "lt",
207-
"less_equal": "le",
208-
"greater": "gt",
209-
"greater_equal": "ge",
210-
"bitwise_or": "or",
211-
"bitwise_and": "and",
212-
"bitwise_xor": "xor",
213-
}
214-
215-
# For op(., Array) -> Array.__r{op}__
216-
flipped = {
217-
"lt": "__gt__",
218-
"le": "__ge__",
219-
"gt": "__lt__",
220-
"ge": "__le__",
221-
"eq": "__eq__",
222-
"ne": "__ne__",
223-
}
224-
225-
op_name = ufunc.__name__
226-
op_name = aliases.get(op_name, op_name)
227-
228-
def not_implemented(*args, **kwargs):
229-
return NotImplemented
230-
231-
if method == "__call__" and op_name in special and kwargs.get("out") is None:
232-
if isinstance(inputs[0], type(self)):
233-
name = "__{}__".format(op_name)
234-
return getattr(self, name, not_implemented)(inputs[1])
235-
else:
236-
name = flipped.get(op_name, "__r{}__".format(op_name))
237-
return getattr(self, name, not_implemented)(inputs[0])
238-
else:
239-
return NotImplemented

pandas/tests/scalar/test_na_scalar.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,6 @@ def test_comparison_ops():
5858
assert (NA >= other) is NA
5959
assert (NA < other) is NA
6060
assert (NA <= other) is NA
61-
62-
if isinstance(other, (np.int64, np.bool_)):
63-
# for numpy scalars we get a deprecation warning and False as result
64-
# for equality or error for larger/lesser than
65-
continue
66-
6761
assert (other == NA) is NA
6862
assert (other != NA) is NA
6963
assert (other > NA) is NA
@@ -175,3 +169,18 @@ def test_series_isna():
175169
s = pd.Series([1, NA], dtype=object)
176170
expected = pd.Series([False, True])
177171
tm.assert_series_equal(s.isna(), expected)
172+
173+
174+
def test_ufunc():
    """Ufuncs called on ``pd.NA`` return ``pd.NA`` for every output."""
    # Single-output ufuncs: one unary, one binary.
    assert np.log(pd.NA) is pd.NA
    assert np.add(pd.NA, 1) is pd.NA

    # Two-output ufuncs: both elements of the result must be NA.
    quotient, remainder = np.divmod(pd.NA, 1)
    assert quotient is pd.NA and remainder is pd.NA

    mantissa, exponent = np.frexp(pd.NA)
    assert mantissa is pd.NA and exponent is pd.NA
182+
183+
184+
def test_ufunc_raises():
    """A ufunc method other than ``__call__`` on ``pd.NA`` raises ValueError."""
    expected_message = "ufunc method 'at'"
    with pytest.raises(ValueError, match=expected_message):
        np.log.at(pd.NA, 0)

0 commit comments

Comments
 (0)