Skip to content

Commit fd645db

Browse files
committed
PERF: improve efficiency of BaseMaskedArray.__setitem__
This somewhat deals with #44172, though that won't be fully resolved until 2D `ExtensionArray`s are supported (per the comments there).
1 parent 7c00e0c commit fd645db

File tree

5 files changed

+29
-15
lines changed

5 files changed

+29
-15
lines changed

pandas/core/apply.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def frame_apply(
8585
args=None,
8686
kwargs=None,
8787
) -> FrameApply:
88-
"""construct and return a row or column based frame apply object"""
88+
"""Construct and return a row- or column-based frame apply object."""
8989
axis = obj._get_axis_number(axis)
9090
klass: type[FrameApply]
9191
if axis == 0:
@@ -693,7 +693,7 @@ def dtypes(self) -> Series:
693693
return self.obj.dtypes
694694

695695
def apply(self) -> DataFrame | Series:
696-
"""compute the results"""
696+
"""Compute the results."""
697697
# dispatch to agg
698698
if is_list_like(self.f):
699699
return self.apply_multiple()
@@ -1011,7 +1011,7 @@ def result_columns(self) -> Index:
10111011
def wrap_results_for_axis(
10121012
self, results: ResType, res_index: Index
10131013
) -> DataFrame | Series:
1014-
"""return the results for the columns"""
1014+
"""Return the results for the columns."""
10151015
result: DataFrame | Series
10161016

10171017
# we have requested to expand

pandas/core/arrays/boolean.py

+3
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,9 @@ def map_string(s):
367367
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
368368
return coerce_to_array(value)
369369

370+
def _validate_setitem_value(self, value):
371+
return lib.is_bool(value)
372+
370373
@overload
371374
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
372375
...

pandas/core/arrays/floating.py

+3
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ def _from_sequence_of_strings(
278278
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
279279
return coerce_to_array(value, dtype=self.dtype)
280280

281+
def _validate_setitem_value(self, value):
    """
    Return True if scalar ``value`` may be stored in a floating array.

    Integer scalars are accepted alongside floats because every integer
    assignment such as ``arr[0] = 1`` is a legitimate float assignment;
    checking ``lib.is_float`` alone would reject it. Booleans are still
    rejected, since ``lib.is_integer`` does not treat bools as integers.
    """
    return lib.is_float(value) or lib.is_integer(value)
283+
281284
@overload
282285
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
283286
...

pandas/core/arrays/integer.py

+3
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ def _from_sequence_of_strings(
345345
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
346346
return coerce_to_array(value, dtype=self.dtype)
347347

348+
def _validate_setitem_value(self, value):
349+
return lib.is_integer(value)
350+
348351
@overload
349352
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
350353
...

pandas/core/arrays/masked.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
)
4848
from pandas.core.dtypes.inference import is_array_like
4949
from pandas.core.dtypes.missing import (
50+
is_valid_na_for_dtype,
5051
isna,
5152
notna,
5253
)
@@ -77,7 +78,7 @@
7778

7879
class BaseMaskedDtype(ExtensionDtype):
7980
"""
80-
Base class for dtypes for BasedMaskedArray subclasses.
81+
Base class for dtypes for BaseMaskedArray subclasses.
8182
"""
8283

8384
name: str
@@ -208,19 +209,23 @@ def fillna(
208209
def _coerce_to_array(self, values) -> tuple[np.ndarray, np.ndarray]:
209210
raise AbstractMethodError(self)
210211

211-
def __setitem__(self, key, value) -> None:
212-
_is_scalar = is_scalar(value)
213-
if _is_scalar:
214-
value = [value]
215-
value, mask = self._coerce_to_array(value)
216-
217-
if _is_scalar:
218-
value = value[0]
219-
mask = mask[0]
212+
def _validate_setitem_value(self, value) -> bool:
213+
raise AbstractMethodError(self)
220214

215+
def __setitem__(self, key, value) -> None:
    """
    Assign ``value`` at ``key``, keeping ``_data`` and ``_mask`` in sync.

    Scalars take a fast path that avoids wrapping the value in a list and
    round-tripping it through ``_coerce_to_array``. Non-scalar values are
    coerced to a ``(data, mask)`` pair and assigned element-wise.

    Raises
    ------
    TypeError
        If ``value`` is a scalar that is neither a valid missing-value
        marker for ``self.dtype`` nor accepted by
        ``self._validate_setitem_value``.
    """
    key = check_array_indexer(self, key)

    if not is_scalar(value):
        value, mask = self._coerce_to_array(value)
        self._data[key] = value
        self._mask[key] = mask
        return

    # Check for NA before validating: NaN is itself a float, so a float
    # validator would otherwise accept it and store it as unmasked data.
    # ``is_valid_na_for_dtype`` requires the dtype as its second argument
    # and already verifies that the value is NA.
    if is_valid_na_for_dtype(value, self.dtype):
        self._mask[key] = True
    elif self._validate_setitem_value(value):
        self._data[key] = value
        self._mask[key] = False
    else:
        raise TypeError(f"Invalid value '{value}' for dtype {self.dtype}")
224229

225230
def __iter__(self):
226231
if self.ndim == 1:

0 commit comments

Comments
 (0)