Skip to content

Commit 8103b8c

Browse files
committed
simplify EA._reduce
1 parent 6d40678 commit 8103b8c

File tree

4 files changed

+23
-36
lines changed

4 files changed

+23
-36
lines changed

pandas/core/arrays/base.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -712,23 +712,17 @@ def _add_comparison_ops(cls):
712712
cls.__le__ = cls._create_comparison_method(operator.le)
713713
cls.__ge__ = cls._create_comparison_method(operator.ge)
714714

715-
def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
716-
filter_type=None, **kwargs):
715+
def _reduce(self, name, skipna=True, **kwargs):
717716
"""Return a scalar result of performing the op
718717
719718
Parameters
720719
----------
721-
op : callable
722-
function to apply to the array
723720
name : str
724721
name of the function
725722
axis : int, default 0
726723
axis over which to apply, defined as 0 currently
727724
skipna : bool, default True
728725
if True, skip NaN values
729-
numeric_only : bool, optional
730-
if True, only perform numeric ops
731-
filter_type : str, optional
732726
kwargs : dict
733727
734728
Returns

pandas/core/arrays/integer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pandas.compat import u, range, string_types
99
from pandas.compat import set_function_name
1010

11+
from pandas.core import nanops
1112
from pandas.core.dtypes.cast import astype_nansafe
1213
from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass
1314
from pandas.core.dtypes.common import (
@@ -529,8 +530,7 @@ def cmp_method(self, other):
529530
name = '__{name}__'.format(name=op.__name__)
530531
return set_function_name(cmp_method, name, cls)
531532

532-
def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
533-
filter_type=None, **kwds):
533+
def _reduce(self, name, axis=0, skipna=True, **kwargs):
534534
data = self._data
535535
mask = self._mask
536536

@@ -539,6 +539,7 @@ def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
539539
data = self._data.astype('float64')
540540
data[mask] = self._na_value
541541

542+
op = getattr(nanops, 'nan' + name)
542543
result = op(data, axis=axis, skipna=skipna)
543544

544545
# if we have a boolean op, provide coercion back to a bool

pandas/core/series.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -3342,10 +3342,16 @@ def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
33423342
33433343
"""
33443344
delegate = self._values
3345-
if isinstance(delegate, np.ndarray):
3346-
# Validate that 'axis' is consistent with Series's single axis.
3347-
if axis is not None:
3348-
self._get_axis_number(axis)
3345+
3346+
if axis is not None:
3347+
self._get_axis_number(axis)
3348+
3349+
# dispatch to ExtensionArray interface
3350+
if isinstance(delegate, ExtensionArray):
3351+
return delegate._reduce(name, skipna=skipna, **kwds)
3352+
3353+
# dispatch to numpy arrays
3354+
elif isinstance(delegate, np.ndarray):
33493355
if numeric_only:
33503356
raise NotImplementedError('Series.{0} does not implement '
33513357
'numeric_only.'.format(name))

pandas/tests/extension/decimal/array.py

+9-23
Original file line numberDiff line numberDiff line change
@@ -137,30 +137,16 @@ def _na_value(self):
137137
def _concat_same_type(cls, to_concat):
138138
return cls(np.concatenate([x._data for x in to_concat]))
139139

140-
def _reduce(self, op, name, axis=0, skipna=True, numeric_only=None,
141-
filter_type=None, **kwds):
142-
"""Return a scalar result of performing the op
143-
144-
Parameters
145-
----------
146-
op : callable
147-
function to apply to the array
148-
name : str
149-
name of the function
150-
axis : int, default 0
151-
axis over which to apply, defined as 0 currently
152-
skipna : bool, default True
153-
if True, skip NaN values
154-
numeric_only : bool, optional
155-
if True, only perform numeric ops
156-
filter_type : str, optional
157-
kwds : dict
140+
def _reduce(self, name, axis=0, skipna=True, **kwargs):
158141

159-
Returns
160-
-------
161-
scalar
162-
"""
163-
return op(self.data, axis=axis, skipna=skipna)
142+
# select the nan* ops
143+
if skipna:
144+
op = getattr(self.data, 'nan' + name)
145+
146+
# don't skip nans
147+
else:
148+
op = getattr(self.data, name)
149+
return op(axis=axis)
164150

165151

166152
DecimalArray._add_arithmetic_ops()

0 commit comments

Comments
 (0)