Skip to content

Commit d468b36

Browse files
committed
PERF: improve efficiency of BaseMaskedArray.__setitem__
This somewhat deals with #44172, though that won't be fully resolved until 2D `ExtensionArray`s are supported (per the comments there).
1 parent 7c00e0c commit d468b36

File tree

5 files changed

+31
-15
lines changed

5 files changed

+31
-15
lines changed

pandas/core/apply.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def frame_apply(
8585
args=None,
8686
kwargs=None,
8787
) -> FrameApply:
88-
"""construct and return a row or column based frame apply object"""
88+
"""Construct and return a row- or column-based frame apply object."""
8989
axis = obj._get_axis_number(axis)
9090
klass: type[FrameApply]
9191
if axis == 0:
@@ -693,7 +693,7 @@ def dtypes(self) -> Series:
693693
return self.obj.dtypes
694694

695695
def apply(self) -> DataFrame | Series:
696-
"""compute the results"""
696+
"""Compute the results."""
697697
# dispatch to agg
698698
if is_list_like(self.f):
699699
return self.apply_multiple()
@@ -1011,7 +1011,7 @@ def result_columns(self) -> Index:
10111011
def wrap_results_for_axis(
10121012
self, results: ResType, res_index: Index
10131013
) -> DataFrame | Series:
1014-
"""return the results for the columns"""
1014+
"""Return the results for the columns."""
10151015
result: DataFrame | Series
10161016

10171017
# we have requested to expand

pandas/core/arrays/boolean.py

+4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
BaseMaskedArray,
4545
BaseMaskedDtype,
4646
)
47+
from pandas.core.indexers import check_array_indexer
4748

4849
if TYPE_CHECKING:
4950
import pyarrow
@@ -367,6 +368,9 @@ def map_string(s):
367368
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
368369
return coerce_to_array(value)
369370

371+
def _validate_setitem_value(self, value):
372+
return lib.is_bool(value)
373+
370374
@overload
371375
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
372376
...

pandas/core/arrays/floating.py

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
NumericArray,
4040
NumericDtype,
4141
)
42+
from pandas.core.indexers import check_array_indexer
4243
from pandas.core.ops import invalid_comparison
4344
from pandas.core.tools.numeric import to_numeric
4445

@@ -278,6 +279,9 @@ def _from_sequence_of_strings(
278279
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
279280
return coerce_to_array(value, dtype=self.dtype)
280281

282+
def _validate_setitem_value(self, value):
283+
return lib.is_float(value)
284+
281285
@overload
282286
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
283287
...

pandas/core/arrays/integer.py

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
NumericArray,
4747
NumericDtype,
4848
)
49+
from pandas.core.indexers import check_array_indexer
4950
from pandas.core.ops import invalid_comparison
5051
from pandas.core.tools.numeric import to_numeric
5152

@@ -345,6 +346,9 @@ def _from_sequence_of_strings(
345346
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
346347
return coerce_to_array(value, dtype=self.dtype)
347348

349+
def _validate_setitem_value(self, value):
350+
return lib.is_integer(value)
351+
348352
@overload
349353
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
350354
...

pandas/core/arrays/masked.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777

7878
class BaseMaskedDtype(ExtensionDtype):
7979
"""
80-
Base class for dtypes for BasedMaskedArray subclasses.
80+
Base class for dtypes for BaseMaskedArray subclasses.
8181
"""
8282

8383
name: str
@@ -208,19 +208,23 @@ def fillna(
208208
def _coerce_to_array(self, values) -> tuple[np.ndarray, np.ndarray]:
209209
raise AbstractMethodError(self)
210210

211-
def __setitem__(self, key, value) -> None:
212-
_is_scalar = is_scalar(value)
213-
if _is_scalar:
214-
value = [value]
215-
value, mask = self._coerce_to_array(value)
216-
217-
if _is_scalar:
218-
value = value[0]
219-
mask = mask[0]
211+
def _validate_setitem_value(self, value) -> bool:
212+
raise AbstractMethodError(self)
220213

214+
def __setitem__(self, key, value) -> None:
221215
key = check_array_indexer(self, key)
222-
self._data[key] = value
223-
self._mask[key] = mask
216+
if is_scalar(value):
217+
if self._validate_setitem_value(value):
218+
self._data[key] = value
219+
self._mask[key] = False
220+
elif value is libmissing.NA:
221+
self._mask[key] = True
222+
else:
223+
raise TypeError(f"Invalid value '{value}'")
224+
else:
225+
value, mask = self._coerce_to_array(value)
226+
self._data[key] = value
227+
self._mask[key] = mask
224228

225229
def __iter__(self):
226230
if self.ndim == 1:

0 commit comments

Comments
 (0)