Skip to content

Commit cf5d6a2

Browse files
authored
REF/TYP: define methods non-dynamically for SparseArray (#36943)
1 parent 9425d7c commit cf5d6a2

File tree

2 files changed

+37
-63
lines changed

2 files changed

+37
-63
lines changed

pandas/core/arrays/sparse/array.py

+37 -60
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
from pandas.core.dtypes.missing import isna, na_value_for_dtype, notna
4141

4242
import pandas.core.algorithms as algos
43+
from pandas.core.arraylike import OpsMixin
4344
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
4445
from pandas.core.arrays.sparse.dtype import SparseDtype
4546
from pandas.core.base import PandasObject
@@ -195,7 +196,7 @@ def _wrap_result(name, data, sparse_index, fill_value, dtype=None):
195196
)
196197

197198

198-
class SparseArray(PandasObject, ExtensionArray, ExtensionOpsMixin):
199+
class SparseArray(OpsMixin, PandasObject, ExtensionArray, ExtensionOpsMixin):
199200
"""
200201
An ExtensionArray for storing sparse data.
201202
@@ -762,8 +763,6 @@ def value_counts(self, dropna=True):
762763
# --------
763764

764765
def __getitem__(self, key):
765-
# avoid mypy issues when importing at the top-level
766-
from pandas.core.indexing import check_bool_indexer
767766

768767
if isinstance(key, tuple):
769768
if len(key) > 1:
@@ -796,7 +795,6 @@ def __getitem__(self, key):
796795
key = check_array_indexer(self, key)
797796

798797
if com.is_bool_indexer(key):
799-
key = check_bool_indexer(self, key)
800798

801799
return self.take(np.arange(len(key), dtype=np.int32)[key])
802800
elif hasattr(key, "__len__"):
@@ -1390,17 +1388,6 @@ def __abs__(self):
13901388
# Ops
13911389
# ------------------------------------------------------------------------
13921390

1393-
@classmethod
1394-
def _create_unary_method(cls, op) -> Callable[["SparseArray"], "SparseArray"]:
1395-
def sparse_unary_method(self) -> "SparseArray":
1396-
fill_value = op(np.array(self.fill_value)).item()
1397-
values = op(self.sp_values)
1398-
dtype = SparseDtype(values.dtype, fill_value)
1399-
return cls._simple_new(values, self.sp_index, dtype)
1400-
1401-
name = f"__{op.__name__}__"
1402-
return compat.set_function_name(sparse_unary_method, name, cls)
1403-
14041391
@classmethod
14051392
def _create_arithmetic_method(cls, op):
14061393
op_name = op.__name__
@@ -1444,56 +1431,48 @@ def sparse_arithmetic_method(self, other):
14441431
name = f"__{op.__name__}__"
14451432
return compat.set_function_name(sparse_arithmetic_method, name, cls)
14461433

1447-
@classmethod
1448-
def _create_comparison_method(cls, op):
1449-
op_name = op.__name__
1450-
if op_name in {"and_", "or_"}:
1451-
op_name = op_name[:-1]
1434+
def _cmp_method(self, other, op) -> "SparseArray":
1435+
if not is_scalar(other) and not isinstance(other, type(self)):
1436+
# convert list-like to ndarray
1437+
other = np.asarray(other)
14521438

1453-
@unpack_zerodim_and_defer(op_name)
1454-
def cmp_method(self, other):
1455-
1456-
if not is_scalar(other) and not isinstance(other, type(self)):
1457-
# convert list-like to ndarray
1458-
other = np.asarray(other)
1439+
if isinstance(other, np.ndarray):
1440+
# TODO: make this more flexible than just ndarray...
1441+
if len(self) != len(other):
1442+
raise AssertionError(f"length mismatch: {len(self)} vs. {len(other)}")
1443+
other = SparseArray(other, fill_value=self.fill_value)
14591444

1460-
if isinstance(other, np.ndarray):
1461-
# TODO: make this more flexible than just ndarray...
1462-
if len(self) != len(other):
1463-
raise AssertionError(
1464-
f"length mismatch: {len(self)} vs. {len(other)}"
1465-
)
1466-
other = SparseArray(other, fill_value=self.fill_value)
1445+
if isinstance(other, SparseArray):
1446+
op_name = op.__name__.strip("_")
1447+
return _sparse_array_op(self, other, op, op_name)
1448+
else:
1449+
with np.errstate(all="ignore"):
1450+
fill_value = op(self.fill_value, other)
1451+
result = op(self.sp_values, other)
1452+
1453+
return type(self)(
1454+
result,
1455+
sparse_index=self.sp_index,
1456+
fill_value=fill_value,
1457+
dtype=np.bool_,
1458+
)
14671459

1468-
if isinstance(other, SparseArray):
1469-
return _sparse_array_op(self, other, op, op_name)
1470-
else:
1471-
with np.errstate(all="ignore"):
1472-
fill_value = op(self.fill_value, other)
1473-
result = op(self.sp_values, other)
1460+
_logical_method = _cmp_method
14741461

1475-
return type(self)(
1476-
result,
1477-
sparse_index=self.sp_index,
1478-
fill_value=fill_value,
1479-
dtype=np.bool_,
1480-
)
1462+
def _unary_method(self, op) -> "SparseArray":
1463+
fill_value = op(np.array(self.fill_value)).item()
1464+
values = op(self.sp_values)
1465+
dtype = SparseDtype(values.dtype, fill_value)
1466+
return type(self)._simple_new(values, self.sp_index, dtype)
14811467

1482-
name = f"__{op.__name__}__"
1483-
return compat.set_function_name(cmp_method, name, cls)
1468+
def __pos__(self) -> "SparseArray":
1469+
return self._unary_method(operator.pos)
14841470

1485-
@classmethod
1486-
def _add_unary_ops(cls):
1487-
cls.__pos__ = cls._create_unary_method(operator.pos)
1488-
cls.__neg__ = cls._create_unary_method(operator.neg)
1489-
cls.__invert__ = cls._create_unary_method(operator.invert)
1471+
def __neg__(self) -> "SparseArray":
1472+
return self._unary_method(operator.neg)
14901473

1491-
@classmethod
1492-
def _add_comparison_ops(cls):
1493-
cls.__and__ = cls._create_comparison_method(operator.and_)
1494-
cls.__or__ = cls._create_comparison_method(operator.or_)
1495-
cls.__xor__ = cls._create_arithmetic_method(operator.xor)
1496-
super()._add_comparison_ops()
1474+
def __invert__(self) -> "SparseArray":
1475+
return self._unary_method(operator.invert)
14971476

14981477
# ----------
14991478
# Formatting
@@ -1511,8 +1490,6 @@ def _formatter(self, boxed=False):
15111490

15121491

15131492
SparseArray._add_arithmetic_ops()
1514-
SparseArray._add_comparison_ops()
1515-
SparseArray._add_unary_ops()
15161493

15171494

15181495
def make_sparse(arr: np.ndarray, kind="block", fill_value=None, dtype=None):

setup.cfg

-3
Original file line number | Diff line number | Diff line change
@@ -142,9 +142,6 @@ check_untyped_defs=False
142142
[mypy-pandas.core.arrays.datetimelike]
143143
check_untyped_defs=False
144144

145-
[mypy-pandas.core.arrays.sparse.array]
146-
check_untyped_defs=False
147-
148145
[mypy-pandas.core.arrays.string_]
149146
check_untyped_defs=False
150147

0 commit comments

Comments
 (0)