Skip to content

Commit 97478b0

Browse files
jbrockmendeljreback
authored andcommitted
implement masked_arith_op to de-duplicate ops code (#22182)
1 parent 1543a75 commit 97478b0

File tree

1 file changed

+56
-50
lines changed

1 file changed

+56
-50
lines changed

pandas/core/ops.py

+56-50
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pandas.core.dtypes.common import (
2828
needs_i8_conversion,
2929
is_datetimelike_v_numeric,
30+
is_period_dtype,
3031
is_integer_dtype, is_categorical_dtype,
3132
is_object_dtype, is_timedelta64_dtype,
3233
is_datetime64_dtype, is_datetime64tz_dtype,
@@ -41,7 +42,7 @@
4142
from pandas.core.dtypes.generic import (
4243
ABCSeries,
4344
ABCDataFrame, ABCPanel,
44-
ABCIndex,
45+
ABCIndex, ABCIndexClass,
4546
ABCSparseSeries, ABCSparseArray)
4647

4748

@@ -788,6 +789,57 @@ def mask_cmp_op(x, y, op, allowed_types):
788789
return result
789790

790791

792+
def masked_arith_op(x, y, op):
793+
"""
794+
If the given arithmetic operation fails, attempt it again on
795+
only the non-null elements of the input array(s).
796+
797+
Parameters
798+
----------
799+
x : np.ndarray
800+
y : np.ndarray, Series, Index
801+
op : binary operator
802+
"""
803+
# For Series `x` is 1D so ravel() is a no-op; calling it anyway makes
804+
# the logic valid for both Series and DataFrame ops.
805+
xrav = x.ravel()
806+
assert isinstance(x, (np.ndarray, ABCSeries)), type(x)
807+
if isinstance(y, (np.ndarray, ABCSeries, ABCIndexClass)):
808+
dtype = find_common_type([x.dtype, y.dtype])
809+
result = np.empty(x.size, dtype=dtype)
810+
811+
# PeriodIndex.ravel() returns int64 dtype, so we have
812+
# to work around that case. See GH#19956
813+
yrav = y if is_period_dtype(y) else y.ravel()
814+
mask = notna(xrav) & notna(yrav)
815+
816+
if yrav.shape != mask.shape:
817+
# FIXME: GH#5284, GH#5035, GH#19448
818+
# Without specifically raising here we get mismatched
819+
# errors in Py3 (TypeError) vs Py2 (ValueError)
820+
# Note: Only = an issue in DataFrame case
821+
raise ValueError('Cannot broadcast operands together.')
822+
823+
if mask.any():
824+
with np.errstate(all='ignore'):
825+
result[mask] = op(xrav[mask],
826+
com.values_from_object(yrav[mask]))
827+
828+
else:
829+
assert is_scalar(y), type(y)
830+
assert isinstance(x, np.ndarray), type(x)
831+
# mask is only meaningful for x
832+
result = np.empty(x.size, dtype=x.dtype)
833+
mask = notna(xrav)
834+
if mask.any():
835+
with np.errstate(all='ignore'):
836+
result[mask] = op(xrav[mask], y)
837+
838+
result, changed = maybe_upcast_putmask(result, ~mask, np.nan)
839+
result = result.reshape(x.shape) # 2D compat
840+
return result
841+
842+
791843
def invalid_comparison(left, right, op):
792844
"""
793845
If a comparison has mismatched types and is not necessarily meaningful,
@@ -880,8 +932,7 @@ def _get_method_wrappers(cls):
880932
return arith_flex, comp_flex, arith_special, comp_special, bool_special
881933

882934

883-
def _create_methods(cls, arith_method, comp_method, bool_method,
884-
special=False):
935+
def _create_methods(cls, arith_method, comp_method, bool_method, special):
885936
# creates actual methods based upon arithmetic, comp and bool method
886937
# constructors.
887938

@@ -1136,19 +1187,7 @@ def na_op(x, y):
11361187
try:
11371188
result = expressions.evaluate(op, str_rep, x, y, **eval_kwargs)
11381189
except TypeError:
1139-
if isinstance(y, (np.ndarray, ABCSeries, pd.Index)):
1140-
dtype = find_common_type([x.dtype, y.dtype])
1141-
result = np.empty(x.size, dtype=dtype)
1142-
mask = notna(x) & notna(y)
1143-
result[mask] = op(x[mask], com.values_from_object(y[mask]))
1144-
else:
1145-
assert isinstance(x, np.ndarray)
1146-
assert is_scalar(y)
1147-
result = np.empty(len(x), dtype=x.dtype)
1148-
mask = notna(x)
1149-
result[mask] = op(x[mask], y)
1150-
1151-
result, changed = maybe_upcast_putmask(result, ~mask, np.nan)
1190+
result = masked_arith_op(x, y, op)
11521191

11531192
result = missing.fill_zeros(result, x, y, op_name, fill_zeros)
11541193
return result
@@ -1675,40 +1714,7 @@ def na_op(x, y):
16751714
try:
16761715
result = expressions.evaluate(op, str_rep, x, y, **eval_kwargs)
16771716
except TypeError:
1678-
xrav = x.ravel()
1679-
if isinstance(y, (np.ndarray, ABCSeries)):
1680-
dtype = find_common_type([x.dtype, y.dtype])
1681-
result = np.empty(x.size, dtype=dtype)
1682-
yrav = y.ravel()
1683-
mask = notna(xrav) & notna(yrav)
1684-
xrav = xrav[mask]
1685-
1686-
if yrav.shape != mask.shape:
1687-
# FIXME: GH#5284, GH#5035, GH#19448
1688-
# Without specifically raising here we get mismatched
1689-
# errors in Py3 (TypeError) vs Py2 (ValueError)
1690-
raise ValueError('Cannot broadcast operands together.')
1691-
1692-
yrav = yrav[mask]
1693-
if xrav.size:
1694-
with np.errstate(all='ignore'):
1695-
result[mask] = op(xrav, yrav)
1696-
1697-
elif isinstance(x, np.ndarray):
1698-
# mask is only meaningful for x
1699-
result = np.empty(x.size, dtype=x.dtype)
1700-
mask = notna(xrav)
1701-
xrav = xrav[mask]
1702-
if xrav.size:
1703-
with np.errstate(all='ignore'):
1704-
result[mask] = op(xrav, y)
1705-
else:
1706-
raise TypeError("cannot perform operation {op} between "
1707-
"objects of type {x} and {y}"
1708-
.format(op=op_name, x=type(x), y=type(y)))
1709-
1710-
result, changed = maybe_upcast_putmask(result, ~mask, np.nan)
1711-
result = result.reshape(x.shape)
1717+
result = masked_arith_op(x, y, op)
17121718

17131719
result = missing.fill_zeros(result, x, y, op_name, fill_zeros)
17141720

0 commit comments

Comments
 (0)