Skip to content

Commit ec3eedd

Browse files
authored
REF/PERF: ArrowStringArray.__setitem__ (#46400)
1 parent abfb4b7 commit ec3eedd

File tree

5 files changed

+209
-34
lines changed

5 files changed

+209
-34
lines changed

asv_bench/benchmarks/array.py

+31
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import pandas as pd
44

5+
from .pandas_vb_common import tm
6+
57

68
class BooleanArray:
79
def setup(self):
@@ -39,3 +41,32 @@ def time_constructor(self):
3941

4042
def time_from_integer_array(self):
    """Time ``pd.array`` on ``self.values_integer`` with the ``"Int64"`` dtype."""
    source_values = self.values_integer
    pd.array(source_values, dtype="Int64")
44+
45+
46+
class ArrowStringArray:
    """
    Benchmarks for item assignment on ``pd.arrays.ArrowStringArray``.

    Each benchmark runs once with the data held in a single pyarrow array
    and once with the data split into multiple chunks.
    """

    params = [False, True]
    param_names = ["multiple_chunks"]

    def setup(self, multiple_chunks):
        """Build ``self.array`` from 10_000 random 3-character strings."""
        try:
            import pyarrow as pa
        except ImportError:
            # Same signal as before: without pyarrow there is nothing to time.
            raise NotImplementedError
        strings = tm.rands_array(3, 10_000)
        if multiple_chunks:
            # Split the values into consecutive pieces of 100 elements each.
            step = 100
            pieces = [strings[lo : lo + step] for lo in range(0, len(strings), step)]
            data = pa.chunked_array(pieces)
        else:
            data = pa.array(strings)
        self.array = pd.arrays.ArrowStringArray(data)

    def time_setitem(self, multiple_chunks):
        """Assign a scalar to the first 200 positions, one position at a time."""
        for position in range(200):
            self.array[position] = "foo"

    def time_setitem_list(self, multiple_chunks):
        """Assign a list of values to the first 50 and the last 50 positions."""
        indexer = [*range(0, 50), *range(-50, 0)]
        replacements = ["foo"] * len(indexer)
        self.array[indexer] = replacements

    def time_setitem_slice(self, multiple_chunks):
        """Assign a scalar to every tenth position through a slice."""
        self.array[::10] = "foo"

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ Performance improvements
323323
- Performance improvement in :func:`merge` when left and/or right are empty (:issue:`45838`)
324324
- Performance improvement in :meth:`DataFrame.join` when left and/or right are empty (:issue:`46015`)
325325
- Performance improvement in :meth:`DataFrame.reindex` and :meth:`Series.reindex` when target is a :class:`MultiIndex` (:issue:`46235`)
326+
- Performance improvement when setting values in a pyarrow-backed string array (:issue:`46400`)
326327
- Performance improvement in :func:`factorize` (:issue:`46109`)
327328
- Performance improvement in :class:`DataFrame` and :class:`Series` constructors for extension dtype scalars (:issue:`45854`)
328329

pandas/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
pa_version_under2p0,
2323
pa_version_under3p0,
2424
pa_version_under4p0,
25+
pa_version_under5p0,
2526
)
2627

2728
PY39 = sys.version_info >= (3, 9)
@@ -148,4 +149,5 @@ def get_lzma_file():
148149
"pa_version_under2p0",
149150
"pa_version_under3p0",
150151
"pa_version_under4p0",
152+
"pa_version_under5p0",
151153
]

pandas/core/arrays/string_arrow.py

+110-34
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
TYPE_CHECKING,
77
Any,
88
Union,
9-
cast,
109
overload,
1110
)
1211

@@ -31,6 +30,7 @@
3130
pa_version_under2p0,
3231
pa_version_under3p0,
3332
pa_version_under4p0,
33+
pa_version_under5p0,
3434
)
3535
from pandas.util._decorators import doc
3636

@@ -365,49 +365,125 @@ def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
365365
None
366366
"""
367367
key = check_array_indexer(self, key)
368+
indices = self._key_to_indices(key)
368369

369-
if is_integer(key):
370-
key = cast(int, key)
371-
372-
if not is_scalar(value):
373-
raise ValueError("Must pass scalars with scalar indexer")
374-
elif isna(value):
370+
if is_scalar(value):
371+
if isna(value):
375372
value = None
376373
elif not isinstance(value, str):
377374
raise ValueError("Scalar must be NA or str")
375+
value = np.broadcast_to(value, len(indices))
376+
else:
377+
value = np.array(value, dtype=object, copy=True)
378+
for i, v in enumerate(value):
379+
if isna(v):
380+
value[i] = None
381+
elif not isinstance(v, str):
382+
raise ValueError("Scalar must be NA or str")
383+
384+
if len(indices) != len(value):
385+
raise ValueError("Length of indexer and values mismatch")
386+
387+
argsort = np.argsort(indices)
388+
indices = indices[argsort]
389+
value = value[argsort]
390+
391+
self._data = self._set_via_chunk_iteration(indices=indices, value=value)
392+
393+
def _key_to_indices(self, key: int | slice | np.ndarray) -> npt.NDArray[np.intp]:
394+
"""Convert indexing key for self to positional indices."""
395+
if isinstance(key, slice):
396+
indices = np.arange(len(self))[key]
397+
elif is_bool_dtype(key):
398+
key = np.asarray(key)
399+
if len(key) != len(self):
400+
raise ValueError("Length of indexer and values mismatch")
401+
indices = key.nonzero()[0]
402+
else:
403+
key_arr = np.array([key]) if is_integer(key) else np.asarray(key)
404+
indices = np.arange(len(self))[key_arr]
405+
return indices
378406

379-
# Slice data and insert in-between
380-
new_data = [
381-
*self._data[0:key].chunks,
407+
def _set_via_chunk_iteration(
408+
self, indices: npt.NDArray[np.intp], value: npt.NDArray[Any]
409+
) -> pa.ChunkedArray:
410+
"""
411+
Loop through the array chunks and set the new values while
412+
leaving the chunking layout unchanged.
413+
"""
414+
415+
chunk_indices = self._within_chunk_indices(indices)
416+
new_data = []
417+
418+
for i, chunk in enumerate(self._data.iterchunks()):
419+
420+
c_ind = chunk_indices[i]
421+
n = len(c_ind)
422+
c_value, value = value[:n], value[n:]
423+
424+
if n == 1:
425+
# fast path
426+
chunk = self._set_single_index_in_chunk(chunk, c_ind[0], c_value[0])
427+
elif n > 0:
428+
mask = np.zeros(len(chunk), dtype=np.bool_)
429+
mask[c_ind] = True
430+
if not pa_version_under5p0:
431+
if c_value is None or isna(np.array(c_value)).all():
432+
chunk = pc.if_else(mask, None, chunk)
433+
else:
434+
chunk = pc.replace_with_mask(chunk, mask, c_value)
435+
else:
436+
# The pyarrow compute functions were added in
437+
# version 5.0. For prior versions we implement
438+
# our own by converting to numpy and back.
439+
chunk = chunk.to_numpy(zero_copy_only=False)
440+
chunk[mask] = c_value
441+
chunk = pa.array(chunk, type=pa.string())
442+
443+
new_data.append(chunk)
444+
445+
return pa.chunked_array(new_data)
446+
447+
@staticmethod
def _set_single_index_in_chunk(chunk: pa.Array, index: int, value: Any) -> pa.Array:
    """
    Set a single position in a pyarrow array.

    The result is the concatenation of the elements before ``index``, a
    one-element string array holding ``value``, and the elements after
    ``index``.
    """
    assert is_scalar(value)
    before = chunk[:index]
    replacement = pa.array([value], type=pa.string())
    after = chunk[index + 1 :]
    return pa.concat_arrays([before, replacement, after])
400458

401-
if is_scalar(value):
402-
value = np.broadcast_to(value, len(key_array))
459+
def _within_chunk_indices(
460+
self, indices: npt.NDArray[np.intp]
461+
) -> list[npt.NDArray[np.intp]]:
462+
"""
463+
Convert indices for self into a list of ndarrays each containing
464+
the indices *within* each chunk of the chunked array.
465+
"""
466+
# indices must be sorted
467+
chunk_indices = []
468+
for start, stop in self._chunk_ranges():
469+
if len(indices) == 0 or indices[0] >= stop:
470+
c_ind = np.array([], dtype=np.intp)
403471
else:
404-
value = np.asarray(value)
472+
n = int(np.searchsorted(indices, stop, side="left"))
473+
c_ind = indices[:n] - start
474+
indices = indices[n:]
475+
chunk_indices.append(c_ind)
476+
return chunk_indices
405477

406-
if len(key_array) != len(value):
407-
raise ValueError("Length of indexer and values mismatch")
408-
409-
for k, v in zip(key_array, value):
410-
self[k] = v
478+
def _chunk_ranges(self) -> list[tuple]:
479+
"""
480+
Return a list of tuples each containing the left (inclusive)
481+
and right (exclusive) bounds of each chunk.
482+
"""
483+
lengths = [len(c) for c in self._data.iterchunks()]
484+
stops = np.cumsum(lengths)
485+
starts = np.concatenate([[0], stops[:-1]])
486+
return list(zip(starts, stops))
411487

412488
def take(
413489
self,

pandas/tests/arrays/string_/test_string_arrow.py

+65
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,68 @@ def test_pyarrow_not_installed_raises():
132132

133133
with pytest.raises(ImportError, match=msg):
134134
ArrowStringArray._from_sequence(["a", None, "b"])
135+
136+
137+
@skip_if_no_pyarrow
@pytest.mark.parametrize("multiple_chunks", [False, True])
@pytest.mark.parametrize(
    "key, value, expected",
    [
        (-1, "XX", ["a", "b", "c", "d", "XX"]),
        (1, "XX", ["a", "XX", "c", "d", "e"]),
        (1, None, ["a", None, "c", "d", "e"]),
        (1, pd.NA, ["a", None, "c", "d", "e"]),
        ([1, 3], "XX", ["a", "XX", "c", "XX", "e"]),
        ([1, 3], ["XX", "YY"], ["a", "XX", "c", "YY", "e"]),
        ([1, 3], ["XX", None], ["a", "XX", "c", None, "e"]),
        ([1, 3], ["XX", pd.NA], ["a", "XX", "c", None, "e"]),
        ([0, -1], ["XX", "YY"], ["XX", "b", "c", "d", "YY"]),
        ([-1, 0], ["XX", "YY"], ["YY", "b", "c", "d", "XX"]),
        (slice(3, None), "XX", ["a", "b", "c", "XX", "XX"]),
        (slice(2, 4), ["XX", "YY"], ["a", "b", "XX", "YY", "e"]),
        (slice(3, 1, -1), ["XX", "YY"], ["a", "b", "YY", "XX", "e"]),
        (slice(None), "XX", ["XX", "XX", "XX", "XX", "XX"]),
        ([False, True, False, True, False], ["XX", "YY"], ["a", "XX", "c", "YY", "e"]),
    ],
)
def test_setitem(multiple_chunks, key, value, expected):
    """Setting ``key`` to ``value`` yields ``expected`` and keeps the chunk count."""
    import pyarrow as pa

    def build(values):
        # Wrap the values either as a single array or as two chunks
        # split after the third element.
        data = pa.array(values)
        if multiple_chunks:
            data = pa.chunked_array([data[:3], data[3:]])
        return ArrowStringArray(data)

    result = build(list("abcde"))
    expected = build(expected)

    result[key] = value
    tm.assert_equal(result, expected)
    assert result._data.num_chunks == expected._data.num_chunks
175+
176+
177+
@skip_if_no_pyarrow
def test_setitem_invalid_indexer_raises():
    """Out-of-bounds keys, short masks and length mismatches must raise."""
    import pyarrow as pa

    arr = ArrowStringArray(pa.array(list("abcde")))

    # Out-of-bounds scalars, out-of-bounds list entries and a boolean mask
    # shorter than the array are all expected to raise IndexError.
    index_error_keys = (5, -6, [0, 5], [0, -6], [True, True, False])
    for bad_key in index_error_keys:
        with pytest.raises(IndexError, match=None):
            arr[bad_key] = "foo"

    # More values than positions.
    with pytest.raises(ValueError, match=None):
        arr[[0, 1]] = ["foo", "bar", "baz"]

0 commit comments

Comments
 (0)