Skip to content

Commit fb2b5fe

Browse files
committed
PERF: improve efficiency of BaseMaskedArray.__setitem__
This somewhat deals with pandas-dev#44172, though that won't be fully resolved until 2D `ExtensionArray`s are supported (per the comments there).
1 parent 040f236 commit fb2b5fe

File tree

5 files changed

+29
-15
lines changed

5 files changed

+29
-15
lines changed

pandas/core/apply.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def frame_apply(
8585
args=None,
8686
kwargs=None,
8787
) -> FrameApply:
88-
"""construct and return a row or column based frame apply object"""
88+
"""Construct and return a row- or column-based frame apply object."""
8989
axis = obj._get_axis_number(axis)
9090
klass: type[FrameApply]
9191
if axis == 0:
@@ -693,7 +693,7 @@ def dtypes(self) -> Series:
693693
return self.obj.dtypes
694694

695695
def apply(self) -> DataFrame | Series:
696-
"""compute the results"""
696+
"""Compute the results."""
697697
# dispatch to agg
698698
if is_list_like(self.f):
699699
return self.apply_multiple()
@@ -1011,7 +1011,7 @@ def result_columns(self) -> Index:
10111011
def wrap_results_for_axis(
10121012
self, results: ResType, res_index: Index
10131013
) -> DataFrame | Series:
1014-
"""return the results for the columns"""
1014+
"""Return the results for the columns."""
10151015
result: DataFrame | Series
10161016

10171017
# we have requested to expand

pandas/core/arrays/boolean.py

+3
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,9 @@ def map_string(s):
365365
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
366366
return coerce_to_array(value)
367367

368+
def _validate_setitem_value(self, value):
369+
return lib.is_bool(value)
370+
368371
@overload
369372
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
370373
...

pandas/core/arrays/floating.py

+3
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ def _from_sequence_of_strings(
270270
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
271271
return coerce_to_array(value, dtype=self.dtype)
272272

273+
def _validate_setitem_value(self, value):
274+
return lib.is_float(value)
275+
273276
@overload
274277
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
275278
...

pandas/core/arrays/integer.py

+3
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,9 @@ def _from_sequence_of_strings(
338338
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
339339
return coerce_to_array(value, dtype=self.dtype)
340340

341+
def _validate_setitem_value(self, value):
342+
return lib.is_integer(value)
343+
341344
@overload
342345
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
343346
...

pandas/core/arrays/masked.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from pandas.core.dtypes.inference import is_array_like
5151
from pandas.core.dtypes.missing import (
5252
array_equivalent,
53+
is_valid_na_for_dtype,
5354
isna,
5455
notna,
5556
)
@@ -82,7 +83,7 @@
8283

8384
class BaseMaskedDtype(ExtensionDtype):
8485
"""
85-
Base class for dtypes for BasedMaskedArray subclasses.
86+
Base class for dtypes for BaseMaskedArray subclasses.
8687
"""
8788

8889
name: str
@@ -213,19 +214,23 @@ def fillna(
213214
def _coerce_to_array(self, values) -> tuple[np.ndarray, np.ndarray]:
214215
raise AbstractMethodError(self)
215216

216-
def __setitem__(self, key, value) -> None:
217-
_is_scalar = is_scalar(value)
218-
if _is_scalar:
219-
value = [value]
220-
value, mask = self._coerce_to_array(value)
221-
222-
if _is_scalar:
223-
value = value[0]
224-
mask = mask[0]
217+
def _validate_setitem_value(self, value) -> bool:
218+
raise AbstractMethodError(self)
225219

220+
def __setitem__(self, key, value) -> None:
226221
key = check_array_indexer(self, key)
227-
self._data[key] = value
228-
self._mask[key] = mask
222+
if is_scalar(value):
223+
if self._validate_setitem_value(value):
224+
self._data[key] = value
225+
self._mask[key] = False
226+
elif isna(value) and is_valid_na_for_dtype(value):
227+
self._mask[key] = True
228+
else:
229+
raise TypeError(f"Invalid value '{value}' for dtype {self.dtype}")
230+
else:
231+
value, mask = self._coerce_to_array(value)
232+
self._data[key] = value
233+
self._mask[key] = mask
229234

230235
def __iter__(self):
231236
if self.ndim == 1:

0 commit comments

Comments
 (0)