Skip to content

Commit 601f227

Browse files
authored
REF/PERF: ArrowExtensionArray.__setitem__ (#50632)
* REF/PERF: ArrowExtensionArray.__setitem__
* update asv
* whatsnew
* fixes
* fix min versions
* fix min versions
* more min version fixes
1 parent 3d89931 commit 601f227

File tree

6 files changed

+140
-143
lines changed

6 files changed

+140
-143
lines changed

asv_bench/benchmarks/array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def time_setitem(self, multiple_chunks):
9393
self.array[i] = "foo"
9494

9595
def time_setitem_list(self, multiple_chunks):
96-
indexer = list(range(0, 50)) + list(range(-50, 0))
96+
indexer = list(range(0, 50)) + list(range(-1000, 0, 50))
9797
self.array[indexer] = ["foo"] * len(indexer)
9898

9999
def time_setitem_slice(self, multiple_chunks):

doc/source/whatsnew/v2.0.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ Performance improvements
850850
- Performance improvement for :class:`~arrays.StringArray` constructor passing a numpy array with type ``np.str_`` (:issue:`49109`)
851851
- Performance improvement in :meth:`~arrays.IntervalArray.from_tuples` (:issue:`50620`)
852852
- Performance improvement in :meth:`~arrays.ArrowExtensionArray.factorize` (:issue:`49177`)
853-
- Performance improvement in :meth:`~arrays.ArrowExtensionArray.__setitem__` when key is a null slice (:issue:`50248`)
853+
- Performance improvement in :meth:`~arrays.ArrowExtensionArray.__setitem__` (:issue:`50248`, :issue:`50632`)
854854
- Performance improvement in :class:`~arrays.ArrowExtensionArray` comparison methods when array contains NA (:issue:`50524`)
855855
- Performance improvement in :meth:`~arrays.ArrowExtensionArray.to_numpy` (:issue:`49973`)
856856
- Performance improvement when parsing strings to :class:`BooleanDtype` (:issue:`50613`)

pandas/core/arrays/arrow/array.py

+133-135
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
Iterator,
2020
NpDtype,
2121
PositionalIndexer,
22+
Scalar,
2223
SortKind,
2324
TakeIndexer,
2425
npt,
2526
)
2627
from pandas.compat import (
2728
pa_version_under6p0,
2829
pa_version_under7p0,
30+
pa_version_under8p0,
2931
pa_version_under9p0,
3032
)
3133
from pandas.util._decorators import doc
@@ -36,6 +38,7 @@
3638
is_bool_dtype,
3739
is_integer,
3840
is_integer_dtype,
41+
is_list_like,
3942
is_object_dtype,
4043
is_scalar,
4144
)
@@ -1056,76 +1059,56 @@ def __setitem__(self, key, value) -> None:
10561059
key = check_array_indexer(self, key)
10571060
value = self._maybe_convert_setitem_value(value)
10581061

1059-
# fast path (GH50248)
10601062
if com.is_null_slice(key):
1061-
if is_scalar(value):
1062-
fill_value = pa.scalar(value, type=self._data.type, from_pandas=True)
1063-
try:
1064-
self._data = pc.if_else(True, fill_value, self._data)
1065-
return
1066-
except pa.ArrowNotImplementedError:
1067-
# ArrowNotImplementedError: Function 'if_else' has no kernel
1068-
# matching input types (bool, duration[ns], duration[ns])
1069-
# TODO: remove try/except wrapper if/when pyarrow implements
1070-
# a kernel for duration types.
1071-
pass
1072-
elif len(value) == len(self):
1073-
if isinstance(value, type(self)) and value.dtype == self.dtype:
1074-
self._data = value._data
1075-
else:
1076-
arr = pa.array(value, type=self._data.type, from_pandas=True)
1077-
self._data = pa.chunked_array([arr])
1078-
return
1079-
1080-
indices = self._indexing_key_to_indices(key)
1081-
argsort = np.argsort(indices)
1082-
indices = indices[argsort]
1083-
1084-
if is_scalar(value):
1085-
value = np.broadcast_to(value, len(self))
1086-
elif len(indices) != len(value):
1087-
raise ValueError("Length of indexer and values mismatch")
1088-
else:
1089-
value = np.asarray(value)[argsort]
1063+
# fast path (GH50248)
1064+
data = self._if_else(True, value, self._data)
10901065

1091-
self._data = self._set_via_chunk_iteration(indices=indices, value=value)
1066+
elif is_integer(key):
1067+
# fast path
1068+
key = cast(int, key)
1069+
n = len(self)
1070+
if key < 0:
1071+
key += n
1072+
if not 0 <= key < n:
1073+
raise IndexError(
1074+
f"index {key} is out of bounds for axis 0 with size {n}"
1075+
)
1076+
if is_list_like(value):
1077+
raise ValueError("Length of indexer and values mismatch")
1078+
elif isinstance(value, pa.Scalar):
1079+
value = value.as_py()
1080+
chunks = [
1081+
*self._data[:key].chunks,
1082+
pa.array([value], type=self._data.type, from_pandas=True),
1083+
*self._data[key + 1 :].chunks,
1084+
]
1085+
data = pa.chunked_array(chunks).combine_chunks()
10921086

1093-
def _indexing_key_to_indices(
1094-
self, key: int | slice | np.ndarray
1095-
) -> npt.NDArray[np.intp]:
1096-
"""
1097-
Convert indexing key for self into positional indices.
1087+
elif is_bool_dtype(key):
1088+
key = np.asarray(key, dtype=np.bool_)
1089+
data = self._replace_with_mask(self._data, key, value)
10981090

1099-
Parameters
1100-
----------
1101-
key : int | slice | np.ndarray
1091+
elif is_scalar(value) or isinstance(value, pa.Scalar):
1092+
mask = np.zeros(len(self), dtype=np.bool_)
1093+
mask[key] = True
1094+
data = self._if_else(mask, value, self._data)
11021095

1103-
Returns
1104-
-------
1105-
npt.NDArray[np.intp]
1106-
"""
1107-
n = len(self)
1108-
if isinstance(key, slice):
1109-
indices = np.arange(n)[key]
1110-
elif is_integer(key):
1111-
# error: Invalid index type "List[Union[int, ndarray[Any, Any]]]"
1112-
# for "ndarray[Any, dtype[signedinteger[Any]]]"; expected type
1113-
# "Union[SupportsIndex, _SupportsArray[dtype[Union[bool_,
1114-
# integer[Any]]]], _NestedSequence[_SupportsArray[dtype[Union
1115-
# [bool_, integer[Any]]]]], _NestedSequence[Union[bool, int]]
1116-
# , Tuple[Union[SupportsIndex, _SupportsArray[dtype[Union[bool_
1117-
# , integer[Any]]]], _NestedSequence[_SupportsArray[dtype[Union
1118-
# [bool_, integer[Any]]]]], _NestedSequence[Union[bool, int]]], ...]]"
1119-
indices = np.arange(n)[[key]] # type: ignore[index]
1120-
elif is_bool_dtype(key):
1121-
key = np.asarray(key)
1122-
if len(key) != n:
1123-
raise ValueError("Length of indexer and values mismatch")
1124-
indices = key.nonzero()[0]
11251096
else:
1126-
key = np.asarray(key)
1127-
indices = np.arange(n)[key]
1128-
return indices
1097+
indices = np.arange(len(self))[key]
1098+
if len(indices) != len(value):
1099+
raise ValueError("Length of indexer and values mismatch")
1100+
if len(indices) == 0:
1101+
return
1102+
argsort = np.argsort(indices)
1103+
indices = indices[argsort]
1104+
value = value.take(argsort)
1105+
mask = np.zeros(len(self), dtype=np.bool_)
1106+
mask[indices] = True
1107+
data = self._replace_with_mask(self._data, mask, value)
1108+
1109+
if isinstance(data, pa.Array):
1110+
data = pa.chunked_array([data])
1111+
self._data = data
11291112

11301113
def _rank(
11311114
self,
@@ -1241,95 +1224,110 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra
12411224

12421225
def _maybe_convert_setitem_value(self, value):
12431226
"""Maybe convert value to be pyarrow compatible."""
1244-
# TODO: Make more robust like ArrowStringArray._maybe_convert_setitem_value
1227+
if value is None:
1228+
return value
1229+
if isinstance(value, (pa.Scalar, pa.Array, pa.ChunkedArray)):
1230+
return value
1231+
if is_list_like(value):
1232+
pa_box = pa.array
1233+
else:
1234+
pa_box = pa.scalar
1235+
try:
1236+
value = pa_box(value, type=self._data.type, from_pandas=True)
1237+
except pa.ArrowTypeError as err:
1238+
msg = f"Invalid value '{str(value)}' for dtype {self.dtype}"
1239+
raise TypeError(msg) from err
12451240
return value
12461241

1247-
def _set_via_chunk_iteration(
1248-
self, indices: npt.NDArray[np.intp], value: npt.NDArray[Any]
1249-
) -> pa.ChunkedArray:
1242+
@classmethod
1243+
def _if_else(
1244+
cls,
1245+
cond: npt.NDArray[np.bool_] | bool,
1246+
left: ArrayLike | Scalar,
1247+
right: ArrayLike | Scalar,
1248+
):
12501249
"""
1251-
Loop through the array chunks and set the new values while
1252-
leaving the chunking layout unchanged.
1250+
Choose values based on a condition.
1251+
1252+
Analogous to pyarrow.compute.if_else, with logic
1253+
to fallback to numpy for unsupported types.
12531254
12541255
Parameters
12551256
----------
1256-
indices : npt.NDArray[np.intp]
1257-
Position indices for the underlying ChunkedArray.
1258-
1259-
value : ExtensionDtype.type, Sequence[ExtensionDtype.type], or object
1260-
value or values to be set of ``key``.
1257+
cond : npt.NDArray[np.bool_] or bool
1258+
left : ArrayLike | Scalar
1259+
right : ArrayLike | Scalar
12611260
1262-
Notes
1263-
-----
1264-
Assumes that indices is sorted. Caller is responsible for sorting.
1261+
Returns
1262+
-------
1263+
pa.Array
12651264
"""
1266-
new_data = []
1267-
stop = 0
1268-
for chunk in self._data.iterchunks():
1269-
start, stop = stop, stop + len(chunk)
1270-
if len(indices) == 0 or stop <= indices[0]:
1271-
new_data.append(chunk)
1265+
try:
1266+
return pc.if_else(cond, left, right)
1267+
except pa.ArrowNotImplementedError:
1268+
pass
1269+
1270+
def _to_numpy_and_type(value) -> tuple[np.ndarray, pa.DataType | None]:
1271+
if isinstance(value, (pa.Array, pa.ChunkedArray)):
1272+
pa_type = value.type
1273+
elif isinstance(value, pa.Scalar):
1274+
pa_type = value.type
1275+
value = value.as_py()
12721276
else:
1273-
n = int(np.searchsorted(indices, stop, side="left"))
1274-
c_ind = indices[:n] - start
1275-
indices = indices[n:]
1276-
n = len(c_ind)
1277-
c_value, value = value[:n], value[n:]
1278-
new_data.append(self._replace_with_indices(chunk, c_ind, c_value))
1279-
return pa.chunked_array(new_data)
1277+
pa_type = None
1278+
return np.array(value, dtype=object), pa_type
1279+
1280+
left, left_type = _to_numpy_and_type(left)
1281+
right, right_type = _to_numpy_and_type(right)
1282+
pa_type = left_type or right_type
1283+
result = np.where(cond, left, right)
1284+
return pa.array(result, type=pa_type, from_pandas=True)
12801285

12811286
@classmethod
1282-
def _replace_with_indices(
1287+
def _replace_with_mask(
12831288
cls,
1284-
chunk: pa.Array,
1285-
indices: npt.NDArray[np.intp],
1286-
value: npt.NDArray[Any],
1287-
) -> pa.Array:
1289+
values: pa.Array | pa.ChunkedArray,
1290+
mask: npt.NDArray[np.bool_] | bool,
1291+
replacements: ArrayLike | Scalar,
1292+
):
12881293
"""
1289-
Replace items selected with a set of positional indices.
1294+
Replace items selected with a mask.
12901295
1291-
Analogous to pyarrow.compute.replace_with_mask, except that replacement
1292-
positions are identified via indices rather than a mask.
1296+
Analogous to pyarrow.compute.replace_with_mask, with logic
1297+
to fallback to numpy for unsupported types.
12931298
12941299
Parameters
12951300
----------
1296-
chunk : pa.Array
1297-
indices : npt.NDArray[np.intp]
1298-
value : npt.NDArray[Any]
1299-
Replacement value(s).
1301+
values : pa.Array or pa.ChunkedArray
1302+
mask : npt.NDArray[np.bool_] or bool
1303+
replacements : ArrayLike or Scalar
1304+
Replacement value(s)
13001305
13011306
Returns
13021307
-------
1303-
pa.Array
1308+
pa.Array or pa.ChunkedArray
13041309
"""
1305-
n = len(indices)
1306-
1307-
if n == 0:
1308-
return chunk
1309-
1310-
start, stop = indices[[0, -1]]
1311-
1312-
if (stop - start) == (n - 1):
1313-
# fast path for a contiguous set of indices
1314-
arrays = [
1315-
chunk[:start],
1316-
pa.array(value, type=chunk.type, from_pandas=True),
1317-
chunk[stop + 1 :],
1318-
]
1319-
arrays = [arr for arr in arrays if len(arr)]
1320-
if len(arrays) == 1:
1321-
return arrays[0]
1322-
return pa.concat_arrays(arrays)
1323-
1324-
mask = np.zeros(len(chunk), dtype=np.bool_)
1325-
mask[indices] = True
1326-
1327-
if pa_version_under6p0:
1328-
arr = chunk.to_numpy(zero_copy_only=False)
1329-
arr[mask] = value
1330-
return pa.array(arr, type=chunk.type)
1331-
1332-
if isna(value).all():
1333-
return pc.if_else(mask, None, chunk)
1334-
1335-
return pc.replace_with_mask(chunk, mask, value)
1310+
if isinstance(replacements, pa.ChunkedArray):
1311+
# replacements must be array or scalar, not ChunkedArray
1312+
replacements = replacements.combine_chunks()
1313+
if pa_version_under8p0:
1314+
# pc.replace_with_mask seems to be a bit unreliable for versions < 8.0:
1315+
# version <= 7: segfaults with various types
1316+
# version <= 6: fails to replace nulls
1317+
if isinstance(replacements, pa.Array):
1318+
indices = np.full(len(values), None)
1319+
indices[mask] = np.arange(len(replacements))
1320+
indices = pa.array(indices, type=pa.int64())
1321+
replacements = replacements.take(indices)
1322+
return cls._if_else(mask, replacements, values)
1323+
try:
1324+
return pc.replace_with_mask(values, mask, replacements)
1325+
except pa.ArrowNotImplementedError:
1326+
pass
1327+
if isinstance(replacements, pa.Array):
1328+
replacements = np.array(replacements, dtype=object)
1329+
elif isinstance(replacements, pa.Scalar):
1330+
replacements = replacements.as_py()
1331+
result = np.array(values, dtype=object)
1332+
result[mask] = replacements
1333+
return pa.array(result, type=values.type, from_pandas=True)

pandas/core/arrays/string_arrow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def _maybe_convert_setitem_value(self, value):
170170
for v in value:
171171
if not (v is None or isinstance(v, str)):
172172
raise ValueError("Scalar must be NA or str")
173-
return value
173+
return super()._maybe_convert_setitem_value(value)
174174

175175
def isin(self, values) -> npt.NDArray[np.bool_]:
176176
value_set = [

pandas/tests/arrays/string_/test_string_arrow.py

-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ def test_setitem(multiple_chunks, key, value, expected):
172172

173173
result[key] = value
174174
tm.assert_equal(result, expected)
175-
assert result._data.num_chunks == expected._data.num_chunks
176175

177176

178177
@skip_if_no_pyarrow

pandas/tests/extension/test_arrow.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1526,8 +1526,8 @@ def test_setitem_invalid_dtype(data):
15261526
pa_type = data._data.type
15271527
if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
15281528
fill_value = 123
1529-
err = pa.ArrowTypeError
1530-
msg = "Expected bytes"
1529+
err = TypeError
1530+
msg = "Invalid value '123' for dtype"
15311531
elif (
15321532
pa.types.is_integer(pa_type)
15331533
or pa.types.is_floating(pa_type)
@@ -1538,8 +1538,8 @@ def test_setitem_invalid_dtype(data):
15381538
msg = "Could not convert"
15391539
else:
15401540
fill_value = "foo"
1541-
err = pa.ArrowTypeError
1542-
msg = "cannot be converted"
1541+
err = TypeError
1542+
msg = "Invalid value 'foo' for dtype"
15431543
with pytest.raises(err, match=msg):
15441544
data[:] = fill_value
15451545

0 commit comments

Comments (0)