Skip to content

Commit a19509e

Browse files
jorisvandenbossche authored and PuneethaPai committed
CLN: deduplicate __setitem__ and _reduce on masked arrays (pandas-dev#34187)
1 parent 2522339 commit a19509e

File tree

3 files changed

+53
-73
lines changed

3 files changed

+53
-73
lines changed

pandas/core/arrays/boolean.py

+4-35
Original file line number | Diff line number | Diff line change
@@ -18,16 +18,13 @@
1818
is_integer_dtype,
1919
is_list_like,
2020
is_numeric_dtype,
21-
is_scalar,
2221
pandas_dtype,
2322
)
2423
from pandas.core.dtypes.dtypes import register_extension_dtype
2524
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
2625
from pandas.core.dtypes.missing import isna
2726

28-
from pandas.core import nanops, ops
29-
from pandas.core.array_algos import masked_reductions
30-
from pandas.core.indexers import check_array_indexer
27+
from pandas.core import ops
3128

3229
from .masked import BaseMaskedArray, BaseMaskedDtype
3330

@@ -347,19 +344,8 @@ def reconstruct(x):
347344
else:
348345
return reconstruct(result)
349346

350-
def __setitem__(self, key, value) -> None:
351-
_is_scalar = is_scalar(value)
352-
if _is_scalar:
353-
value = [value]
354-
value, mask = coerce_to_array(value)
355-
356-
if _is_scalar:
357-
value = value[0]
358-
mask = mask[0]
359-
360-
key = check_array_indexer(self, key)
361-
self._data[key] = value
362-
self._mask[key] = mask
347+
def _coerce_to_array(self, value) -> Tuple[np.ndarray, np.ndarray]:
348+
return coerce_to_array(value)
363349

364350
def astype(self, dtype, copy: bool = True) -> ArrayLike:
365351
"""
@@ -670,24 +656,7 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
670656
if name in {"any", "all"}:
671657
return getattr(self, name)(skipna=skipna, **kwargs)
672658

673-
data = self._data
674-
mask = self._mask
675-
676-
if name in {"sum", "prod", "min", "max"}:
677-
op = getattr(masked_reductions, name)
678-
return op(data, mask, skipna=skipna, **kwargs)
679-
680-
# coerce to a nan-aware float if needed
681-
if self._hasna:
682-
data = self.to_numpy("float64", na_value=np.nan)
683-
684-
op = getattr(nanops, "nan" + name)
685-
result = op(data, axis=0, skipna=skipna, mask=mask, **kwargs)
686-
687-
if np.isnan(result):
688-
return libmissing.NA
689-
690-
return result
659+
return super()._reduce(name, skipna, **kwargs)
691660

692661
def _maybe_mask_result(self, result, mask, other, op_name: str):
693662
"""

pandas/core/arrays/integer.py

+3-37
Original file line number | Diff line number | Diff line change
@@ -20,15 +20,13 @@
2020
is_integer_dtype,
2121
is_list_like,
2222
is_object_dtype,
23-
is_scalar,
2423
pandas_dtype,
2524
)
2625
from pandas.core.dtypes.dtypes import register_extension_dtype
2726
from pandas.core.dtypes.missing import isna
2827

29-
from pandas.core import nanops, ops
28+
from pandas.core import ops
3029
from pandas.core.array_algos import masked_reductions
31-
from pandas.core.indexers import check_array_indexer
3230
from pandas.core.ops import invalid_comparison
3331
from pandas.core.ops.common import unpack_zerodim_and_defer
3432
from pandas.core.tools.numeric import to_numeric
@@ -417,19 +415,8 @@ def reconstruct(x):
417415
else:
418416
return reconstruct(result)
419417

420-
def __setitem__(self, key, value) -> None:
421-
_is_scalar = is_scalar(value)
422-
if _is_scalar:
423-
value = [value]
424-
value, mask = coerce_to_array(value, dtype=self.dtype)
425-
426-
if _is_scalar:
427-
value = value[0]
428-
mask = mask[0]
429-
430-
key = check_array_indexer(self, key)
431-
self._data[key] = value
432-
self._mask[key] = mask
418+
def _coerce_to_array(self, value) -> Tuple[np.ndarray, np.ndarray]:
419+
return coerce_to_array(value, dtype=self.dtype)
433420

434421
def astype(self, dtype, copy: bool = True) -> ArrayLike:
435422
"""
@@ -553,27 +540,6 @@ def cmp_method(self, other):
553540
name = f"__{op.__name__}__"
554541
return set_function_name(cmp_method, name, cls)
555542

556-
def _reduce(self, name: str, skipna: bool = True, **kwargs):
557-
data = self._data
558-
mask = self._mask
559-
560-
if name in {"sum", "prod", "min", "max"}:
561-
op = getattr(masked_reductions, name)
562-
return op(data, mask, skipna=skipna, **kwargs)
563-
564-
# coerce to a nan-aware float if needed
565-
# (we explicitly use NaN within reductions)
566-
if self._hasna:
567-
data = self.to_numpy("float64", na_value=np.nan)
568-
569-
op = getattr(nanops, "nan" + name)
570-
result = op(data, axis=0, skipna=skipna, mask=mask, **kwargs)
571-
572-
if np.isnan(result):
573-
return libmissing.NA
574-
575-
return result
576-
577543
def sum(self, skipna=True, min_count=0, **kwargs):
578544
nv.validate_sum((), kwargs)
579545
result = masked_reductions.sum(

pandas/core/arrays/masked.py

+46-1
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,17 @@
88
from pandas.util._decorators import doc
99

1010
from pandas.core.dtypes.base import ExtensionDtype
11-
from pandas.core.dtypes.common import is_integer, is_object_dtype, is_string_dtype
11+
from pandas.core.dtypes.common import (
12+
is_integer,
13+
is_object_dtype,
14+
is_scalar,
15+
is_string_dtype,
16+
)
1217
from pandas.core.dtypes.missing import isna, notna
1318

19+
from pandas.core import nanops
1420
from pandas.core.algorithms import _factorize_array, take
21+
from pandas.core.array_algos import masked_reductions
1522
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
1623
from pandas.core.indexers import check_array_indexer
1724

@@ -77,6 +84,23 @@ def __getitem__(self, item):
7784

7885
return type(self)(self._data[item], self._mask[item])
7986

87+
def _coerce_to_array(self, values) -> Tuple[np.ndarray, np.ndarray]:
88+
raise AbstractMethodError(self)
89+
90+
def __setitem__(self, key, value) -> None:
91+
_is_scalar = is_scalar(value)
92+
if _is_scalar:
93+
value = [value]
94+
value, mask = self._coerce_to_array(value)
95+
96+
if _is_scalar:
97+
value = value[0]
98+
mask = mask[0]
99+
100+
key = check_array_indexer(self, key)
101+
self._data[key] = value
102+
self._mask[key] = mask
103+
80104
def __iter__(self):
81105
for i in range(len(self)):
82106
if self._mask[i]:
@@ -305,3 +329,24 @@ def value_counts(self, dropna: bool = True) -> "Series":
305329
counts = IntegerArray(counts, mask)
306330

307331
return Series(counts, index=index)
332+
333+
def _reduce(self, name: str, skipna: bool = True, **kwargs):
334+
data = self._data
335+
mask = self._mask
336+
337+
if name in {"sum", "prod", "min", "max"}:
338+
op = getattr(masked_reductions, name)
339+
return op(data, mask, skipna=skipna, **kwargs)
340+
341+
# coerce to a nan-aware float if needed
342+
# (we explicitly use NaN within reductions)
343+
if self._hasna:
344+
data = self.to_numpy("float64", na_value=np.nan)
345+
346+
op = getattr(nanops, "nan" + name)
347+
result = op(data, axis=0, skipna=skipna, mask=mask, **kwargs)
348+
349+
if np.isnan(result):
350+
return libmissing.NA
351+
352+
return result

0 commit comments

Comments
 (0)