Skip to content

Commit 1cd62b8

Browse files
TomAugspurgerproost
authored andcommitted
API: Handle pow & rpow special cases (pandas-dev#30097)
1 parent 5d96fad commit 1cd62b8

File tree

3 files changed

+83
-3
lines changed

3 files changed

+83
-3
lines changed

doc/source/user_guide/missing_data.rst

+12
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,18 @@ For example, ``pd.NA`` propagates in arithmetic operations, similarly to
822822
pd.NA + 1
823823
"a" * pd.NA
824824
825+
There are a few special cases when the result is known, even when one of the
826+
operands is ``NA``.
827+
828+
829+
================ ======
830+
Operation Result
831+
================ ======
832+
``pd.NA ** 0`` 0
833+
``1 ** pd.NA`` 1
834+
``-1 ** pd.NA`` -1
835+
================ ======
836+
825837
In equality and comparison operations, ``pd.NA`` also propagates. This deviates
826838
from the behaviour of ``np.nan``, where comparisons with ``np.nan`` always
827839
return ``False``.

pandas/_libs/missing.pyx

+24-2
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,6 @@ class NAType(C_NAType):
365365
__rmod__ = _create_binary_propagating_op("__rmod__")
366366
__divmod__ = _create_binary_propagating_op("__divmod__", divmod=True)
367367
__rdivmod__ = _create_binary_propagating_op("__rdivmod__", divmod=True)
368-
__pow__ = _create_binary_propagating_op("__pow__")
369-
__rpow__ = _create_binary_propagating_op("__rpow__")
370368
# __lshift__ and __rshift__ are not implemented
371369

372370
__eq__ = _create_binary_propagating_op("__eq__")
@@ -383,6 +381,30 @@ class NAType(C_NAType):
383381
__abs__ = _create_unary_propagating_op("__abs__")
384382
__invert__ = _create_unary_propagating_op("__invert__")
385383

384+
# pow has special
385+
def __pow__(self, other):
386+
if other is C_NA:
387+
return NA
388+
elif isinstance(other, (numbers.Number, np.bool_)):
389+
if other == 0:
390+
# returning positive is correct for +/- 0.
391+
return type(other)(1)
392+
else:
393+
return NA
394+
395+
return NotImplemented
396+
397+
def __rpow__(self, other):
398+
if other is C_NA:
399+
return NA
400+
elif isinstance(other, (numbers.Number, np.bool_)):
401+
if other == 1 or other == -1:
402+
return other
403+
else:
404+
return NA
405+
406+
return NotImplemented
407+
386408
# Logical ops using Kleene logic
387409

388410
def __and__(self, other):

pandas/tests/scalar/test_na_scalar.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,14 @@ def test_arithmetic_ops(all_arithmetic_functions):
3838
op = all_arithmetic_functions
3939

4040
for other in [NA, 1, 1.0, "a", np.int64(1), np.nan]:
41-
if op.__name__ == "rmod" and isinstance(other, str):
41+
if op.__name__ in ("pow", "rpow", "rmod") and isinstance(other, str):
4242
continue
4343
if op.__name__ in ("divmod", "rdivmod"):
4444
assert op(NA, other) is (NA, NA)
4545
else:
46+
if op.__name__ == "rpow":
47+
# avoid special case
48+
other += 1
4649
assert op(NA, other) is NA
4750

4851

@@ -69,6 +72,49 @@ def test_comparison_ops():
6972
assert (other <= NA) is NA
7073

7174

75+
@pytest.mark.parametrize(
76+
"value",
77+
[
78+
0,
79+
0.0,
80+
-0,
81+
-0.0,
82+
False,
83+
np.bool_(False),
84+
np.int_(0),
85+
np.float_(0),
86+
np.int_(-0),
87+
np.float_(-0),
88+
],
89+
)
90+
def test_pow_special(value):
91+
result = pd.NA ** value
92+
assert isinstance(result, type(value))
93+
assert result == 1
94+
95+
96+
@pytest.mark.parametrize(
97+
"value",
98+
[
99+
1,
100+
1.0,
101+
-1,
102+
-1.0,
103+
True,
104+
np.bool_(True),
105+
np.int_(1),
106+
np.float_(1),
107+
np.int_(-1),
108+
np.float_(-1),
109+
],
110+
)
111+
def test_rpow_special(value):
112+
result = value ** pd.NA
113+
assert result == value
114+
if not isinstance(value, (np.float_, np.bool_, np.int_)):
115+
assert isinstance(result, type(value))
116+
117+
72118
def test_unary_ops():
73119
assert +NA is NA
74120
assert -NA is NA

0 commit comments

Comments
 (0)