Skip to content

Commit d9f57a4

Browse files
jbrockmendelharisbal
authored and
harisbal
committed
make ops.add_foo take just class (pandas-dev#19828)
1 parent e3c5467 commit d9f57a4

File tree

7 files changed

+127
-92
lines changed

7 files changed

+127
-92
lines changed

pandas/core/frame.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6131,8 +6131,8 @@ def isin(self, values):
61316131
DataFrame._add_numeric_operations()
61326132
DataFrame._add_series_or_dataframe_operations()
61336133

6134-
ops.add_flex_arithmetic_methods(DataFrame, **ops.frame_flex_funcs)
6135-
ops.add_special_arithmetic_methods(DataFrame, **ops.frame_special_funcs)
6134+
ops.add_flex_arithmetic_methods(DataFrame)
6135+
ops.add_special_arithmetic_methods(DataFrame)
61366136

61376137

61386138
def _arrays_to_mgr(arrays, arr_names, index, columns, dtype=None):

pandas/core/ops.py

+115-69
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
construct_1d_object_array_from_listlike)
3838
from pandas.core.dtypes.generic import (
3939
ABCSeries,
40-
ABCDataFrame,
40+
ABCDataFrame, ABCPanel,
4141
ABCIndex,
4242
ABCSparseSeries, ABCSparseArray)
4343

@@ -711,6 +711,64 @@ def mask_cmp_op(x, y, op, allowed_types):
711711
# Functions that add arithmetic methods to objects, given arithmetic factory
712712
# methods
713713

714+
def _get_method_wrappers(cls):
715+
"""
716+
Find the appropriate operation-wrappers to use when defining flex/special
717+
arithmetic, boolean, and comparison operations with the given class.
718+
719+
Parameters
720+
----------
721+
cls : class
722+
723+
Returns
724+
-------
725+
arith_flex : function or None
726+
comp_flex : function or None
727+
arith_special : function
728+
comp_special : function
729+
bool_special : function
730+
731+
Notes
732+
-----
733+
None is only returned for SparseArray
734+
"""
735+
if issubclass(cls, ABCSparseSeries):
736+
# Be sure to catch this before ABCSeries and ABCSparseArray,
737+
# as they will both come see SparseSeries as a subclass
738+
arith_flex = _flex_method_SERIES
739+
comp_flex = _flex_method_SERIES
740+
arith_special = _arith_method_SPARSE_SERIES
741+
comp_special = _arith_method_SPARSE_SERIES
742+
bool_special = _bool_method_SERIES
743+
# TODO: I don't think the functions defined by bool_method are tested
744+
elif issubclass(cls, ABCSeries):
745+
# Just Series; SparseSeries is caught above
746+
arith_flex = _flex_method_SERIES
747+
comp_flex = _flex_method_SERIES
748+
arith_special = _arith_method_SERIES
749+
comp_special = _comp_method_SERIES
750+
bool_special = _bool_method_SERIES
751+
elif issubclass(cls, ABCSparseArray):
752+
arith_flex = None
753+
comp_flex = None
754+
arith_special = _arith_method_SPARSE_ARRAY
755+
comp_special = _arith_method_SPARSE_ARRAY
756+
bool_special = _arith_method_SPARSE_ARRAY
757+
elif issubclass(cls, ABCPanel):
758+
arith_flex = _flex_method_PANEL
759+
comp_flex = _comp_method_PANEL
760+
arith_special = _arith_method_PANEL
761+
comp_special = _comp_method_PANEL
762+
bool_special = _arith_method_PANEL
763+
elif issubclass(cls, ABCDataFrame):
764+
# Same for DataFrame and SparseDataFrame
765+
arith_flex = _arith_method_FRAME
766+
comp_flex = _flex_comp_method_FRAME
767+
arith_special = _arith_method_FRAME
768+
comp_special = _comp_method_FRAME
769+
bool_special = _arith_method_FRAME
770+
return arith_flex, comp_flex, arith_special, comp_special, bool_special
771+
714772

715773
def _create_methods(cls, arith_method, comp_method, bool_method,
716774
special=False):
@@ -743,16 +801,18 @@ def _create_methods(cls, arith_method, comp_method, bool_method,
743801
# yapf: enable
744802
new_methods['div'] = new_methods['truediv']
745803
new_methods['rdiv'] = new_methods['rtruediv']
804+
if have_divmod:
805+
# divmod doesn't have an op that is supported by numexpr
806+
new_methods['divmod'] = arith_method(cls, divmod, special)
807+
808+
new_methods.update(dict(
809+
eq=comp_method(cls, operator.eq, special),
810+
ne=comp_method(cls, operator.ne, special),
811+
lt=comp_method(cls, operator.lt, special),
812+
gt=comp_method(cls, operator.gt, special),
813+
le=comp_method(cls, operator.le, special),
814+
ge=comp_method(cls, operator.ge, special)))
746815

747-
# Comp methods never had a default axis set
748-
if comp_method:
749-
new_methods.update(dict(
750-
eq=comp_method(cls, operator.eq, special),
751-
ne=comp_method(cls, operator.ne, special),
752-
lt=comp_method(cls, operator.lt, special),
753-
gt=comp_method(cls, operator.gt, special),
754-
le=comp_method(cls, operator.le, special),
755-
ge=comp_method(cls, operator.ge, special)))
756816
if bool_method:
757817
new_methods.update(
758818
dict(and_=bool_method(cls, operator.and_, special),
@@ -762,9 +822,6 @@ def _create_methods(cls, arith_method, comp_method, bool_method,
762822
rand_=bool_method(cls, rand_, special),
763823
ror_=bool_method(cls, ror_, special),
764824
rxor=bool_method(cls, rxor, special)))
765-
if have_divmod:
766-
# divmod doesn't have an op that is supported by numexpr
767-
new_methods['divmod'] = arith_method(cls, divmod, special)
768825

769826
if special:
770827
dunderize = lambda x: '__{name}__'.format(name=x.strip('_'))
@@ -788,22 +845,17 @@ def add_methods(cls, new_methods):
788845

789846
# ----------------------------------------------------------------------
790847
# Arithmetic
791-
def add_special_arithmetic_methods(cls, arith_method=None,
792-
comp_method=None, bool_method=None):
848+
def add_special_arithmetic_methods(cls):
793849
"""
794850
Adds the full suite of special arithmetic methods (``__add__``,
795851
``__sub__``, etc.) to the class.
796852
797853
Parameters
798854
----------
799-
arith_method : function (optional)
800-
factory for special arithmetic methods:
801-
f(cls, op, special)
802-
comp_method : function (optional)
803-
factory for rich comparison - signature: f(cls, op, special)
804-
bool_method : function (optional)
805-
factory for boolean methods - signature: f(cls, op, special)
855+
cls : class
856+
special methods will be defined and pinned to this class
806857
"""
858+
_, _, arith_method, comp_method, bool_method = _get_method_wrappers(cls)
807859
new_methods = _create_methods(cls, arith_method, comp_method, bool_method,
808860
special=True)
809861
# inplace operators (I feel like these should get passed an `inplace=True`
@@ -836,28 +888,26 @@ def f(self, other):
836888
__ipow__=_wrap_inplace_method(new_methods["__pow__"])))
837889
if not compat.PY3:
838890
new_methods["__idiv__"] = _wrap_inplace_method(new_methods["__div__"])
839-
if bool_method:
840-
new_methods.update(
841-
dict(__iand__=_wrap_inplace_method(new_methods["__and__"]),
842-
__ior__=_wrap_inplace_method(new_methods["__or__"]),
843-
__ixor__=_wrap_inplace_method(new_methods["__xor__"])))
891+
892+
new_methods.update(
893+
dict(__iand__=_wrap_inplace_method(new_methods["__and__"]),
894+
__ior__=_wrap_inplace_method(new_methods["__or__"]),
895+
__ixor__=_wrap_inplace_method(new_methods["__xor__"])))
844896

845897
add_methods(cls, new_methods=new_methods)
846898

847899

848-
def add_flex_arithmetic_methods(cls, flex_arith_method, flex_comp_method=None):
900+
def add_flex_arithmetic_methods(cls):
849901
"""
850902
Adds the full suite of flex arithmetic methods (``pow``, ``mul``, ``add``)
851903
to the class.
852904
853905
Parameters
854906
----------
855-
flex_arith_method : function
856-
factory for flex arithmetic methods:
857-
f(cls, op, special)
858-
flex_comp_method : function, optional,
859-
factory for rich comparison - signature: f(cls, op, special)
907+
cls : class
908+
flex methods will be defined and pinned to this class
860909
"""
910+
flex_arith_method, flex_comp_method, _, _, _ = _get_method_wrappers(cls)
861911
new_methods = _create_methods(cls, flex_arith_method,
862912
flex_comp_method, bool_method=None,
863913
special=False)
@@ -1284,14 +1334,6 @@ def flex_wrapper(self, other, level=None, fill_value=None, axis=0):
12841334
return flex_wrapper
12851335

12861336

1287-
series_flex_funcs = dict(flex_arith_method=_flex_method_SERIES,
1288-
flex_comp_method=_flex_method_SERIES)
1289-
1290-
series_special_funcs = dict(arith_method=_arith_method_SERIES,
1291-
comp_method=_comp_method_SERIES,
1292-
bool_method=_bool_method_SERIES)
1293-
1294-
12951337
# -----------------------------------------------------------------------------
12961338
# DataFrame
12971339

@@ -1533,14 +1575,6 @@ def f(self, other):
15331575
return f
15341576

15351577

1536-
frame_flex_funcs = dict(flex_arith_method=_arith_method_FRAME,
1537-
flex_comp_method=_flex_comp_method_FRAME)
1538-
1539-
frame_special_funcs = dict(arith_method=_arith_method_FRAME,
1540-
comp_method=_comp_method_FRAME,
1541-
bool_method=_arith_method_FRAME)
1542-
1543-
15441578
# -----------------------------------------------------------------------------
15451579
# Panel
15461580

@@ -1629,16 +1663,38 @@ def f(self, other, axis=0):
16291663
return f
16301664

16311665

1632-
panel_special_funcs = dict(arith_method=_arith_method_PANEL,
1633-
comp_method=_comp_method_PANEL,
1634-
bool_method=_arith_method_PANEL)
1635-
1636-
panel_flex_funcs = dict(flex_arith_method=_flex_method_PANEL,
1637-
flex_comp_method=_comp_method_PANEL)
1638-
16391666
# -----------------------------------------------------------------------------
16401667
# Sparse
16411668

1669+
def _cast_sparse_series_op(left, right, opname):
1670+
"""
1671+
For SparseSeries operation, coerce to float64 if the result is expected
1672+
to have NaN or inf values
1673+
1674+
Parameters
1675+
----------
1676+
left : SparseArray
1677+
right : SparseArray
1678+
opname : str
1679+
1680+
Returns
1681+
-------
1682+
left : SparseArray
1683+
right : SparseArray
1684+
"""
1685+
opname = opname.strip('_')
1686+
1687+
if is_integer_dtype(left) and is_integer_dtype(right):
1688+
# series coerces to float64 if result should have NaN/inf
1689+
if opname in ('floordiv', 'mod') and (right.values == 0).any():
1690+
left = left.astype(np.float64)
1691+
right = right.astype(np.float64)
1692+
elif opname in ('rfloordiv', 'rmod') and (left.values == 0).any():
1693+
left = left.astype(np.float64)
1694+
right = right.astype(np.float64)
1695+
1696+
return left, right
1697+
16421698

16431699
def _arith_method_SPARSE_SERIES(cls, op, special):
16441700
"""
@@ -1674,8 +1730,8 @@ def _sparse_series_op(left, right, op, name):
16741730
new_name = get_op_result_name(left, right)
16751731

16761732
from pandas.core.sparse.array import _sparse_array_op
1677-
result = _sparse_array_op(left.values, right.values, op, name,
1678-
series=True)
1733+
lvalues, rvalues = _cast_sparse_series_op(left.values, right.values, name)
1734+
result = _sparse_array_op(lvalues, rvalues, op, name)
16791735
return left._constructor(result, index=new_index, name=new_name)
16801736

16811737

@@ -1697,7 +1753,7 @@ def wrapper(self, other):
16971753
dtype = getattr(other, 'dtype', None)
16981754
other = SparseArray(other, fill_value=self.fill_value,
16991755
dtype=dtype)
1700-
return _sparse_array_op(self, other, op, name, series=False)
1756+
return _sparse_array_op(self, other, op, name)
17011757
elif is_scalar(other):
17021758
with np.errstate(all='ignore'):
17031759
fill = op(_get_fill(self), np.asarray(other))
@@ -1710,13 +1766,3 @@ def wrapper(self, other):
17101766

17111767
wrapper.__name__ = name
17121768
return wrapper
1713-
1714-
1715-
sparse_array_special_funcs = dict(arith_method=_arith_method_SPARSE_ARRAY,
1716-
comp_method=_arith_method_SPARSE_ARRAY,
1717-
bool_method=_arith_method_SPARSE_ARRAY)
1718-
1719-
sparse_series_special_funcs = dict(arith_method=_arith_method_SPARSE_SERIES,
1720-
comp_method=_arith_method_SPARSE_SERIES,
1721-
bool_method=_bool_method_SERIES)
1722-
# TODO: I don't think the functions defined by bool_method are tested

pandas/core/panel.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1527,8 +1527,8 @@ def _extract_axis(self, data, axis=0, intersect=False):
15271527
slicers={'major_axis': 'index',
15281528
'minor_axis': 'columns'})
15291529

1530-
ops.add_special_arithmetic_methods(Panel, **ops.panel_special_funcs)
1531-
ops.add_flex_arithmetic_methods(Panel, **ops.panel_flex_funcs)
1530+
ops.add_special_arithmetic_methods(Panel)
1531+
ops.add_flex_arithmetic_methods(Panel)
15321532
Panel._add_numeric_operations()
15331533

15341534

pandas/core/series.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3100,8 +3100,8 @@ def to_period(self, freq=None, copy=True):
31003100
Series._add_series_or_dataframe_operations()
31013101

31023102
# Add arithmetic!
3103-
ops.add_flex_arithmetic_methods(Series, **ops.series_flex_funcs)
3104-
ops.add_special_arithmetic_methods(Series, **ops.series_special_funcs)
3103+
ops.add_flex_arithmetic_methods(Series)
3104+
ops.add_special_arithmetic_methods(Series)
31053105

31063106

31073107
# -----------------------------------------------------------------------------

pandas/core/sparse/array.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,11 @@ def _get_fill(arr):
5353
return np.asarray(arr.fill_value)
5454

5555

56-
def _sparse_array_op(left, right, op, name, series=False):
56+
def _sparse_array_op(left, right, op, name):
5757
if name.startswith('__'):
5858
# For lookups in _libs.sparse we need non-dunder op name
5959
name = name[2:-2]
6060

61-
if series and is_integer_dtype(left) and is_integer_dtype(right):
62-
# series coerces to float64 if result should have NaN/inf
63-
if name in ('floordiv', 'mod') and (right.values == 0).any():
64-
left = left.astype(np.float64)
65-
right = right.astype(np.float64)
66-
elif name in ('rfloordiv', 'rmod') and (left.values == 0).any():
67-
left = left.astype(np.float64)
68-
right = right.astype(np.float64)
69-
7061
# dtype used to find corresponding sparse method
7162
if not is_dtype_equal(left.dtype, right.dtype):
7263
dtype = find_common_type([left.dtype, right.dtype])
@@ -850,5 +841,4 @@ def _make_index(length, indices, kind):
850841
return index
851842

852843

853-
ops.add_special_arithmetic_methods(SparseArray,
854-
**ops.sparse_array_special_funcs)
844+
ops.add_special_arithmetic_methods(SparseArray)

pandas/core/sparse/frame.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1014,5 +1014,5 @@ def homogenize(series_dict):
10141014

10151015

10161016
# use unaccelerated ops for sparse objects
1017-
ops.add_flex_arithmetic_methods(SparseDataFrame, **ops.frame_flex_funcs)
1018-
ops.add_special_arithmetic_methods(SparseDataFrame, **ops.frame_special_funcs)
1017+
ops.add_flex_arithmetic_methods(SparseDataFrame)
1018+
ops.add_special_arithmetic_methods(SparseDataFrame)

pandas/core/sparse/series.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,5 @@ def from_coo(cls, A, dense_index=False):
812812

813813

814814
# overwrite series methods with unaccelerated Sparse-specific versions
815-
ops.add_flex_arithmetic_methods(SparseSeries, **ops.series_flex_funcs)
816-
ops.add_special_arithmetic_methods(SparseSeries,
817-
**ops.sparse_series_special_funcs)
815+
ops.add_flex_arithmetic_methods(SparseSeries)
816+
ops.add_special_arithmetic_methods(SparseSeries)

0 commit comments

Comments
 (0)