
REF/PERF: ArrowStringArray.__setitem__ #46400

Merged 7 commits on Mar 18, 2022
Changes from 2 commits
28 changes: 28 additions & 0 deletions asv_bench/benchmarks/array.py
@@ -1,7 +1,10 @@
import numpy as np
import pyarrow as pa
Review comment (Member):
Not sure if the policy should change, but AFAIK we guard the pyarrow import in the other benchmarks, since it's an optional dependency, and raise NotImplementedError so that the benchmarks get skipped when pyarrow is not installed.
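For reference, a minimal sketch of that guard pattern, relying on asv's convention that raising NotImplementedError in setup() skips the benchmark (the class name mirrors this file; the exact wording of the guard varies between benchmark modules):

try:
    import pyarrow as pa
except ImportError:
    # pyarrow is an optional dependency; leave a sentinel so setup() can skip
    pa = None


class ArrowStringArray:
    def setup(self, multiple_chunks):
        if pa is None:
            # asv treats NotImplementedError raised in setup() as "skip this benchmark"
            raise NotImplementedError
        ...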

Reply (Member Author):
Updated. Thanks for pointing that out.


import pandas as pd

from .pandas_vb_common import tm


class BooleanArray:
def setup(self):
@@ -39,3 +42,28 @@ def time_constructor(self):

def time_from_integer_array(self):
pd.array(self.values_integer, dtype="Int64")


class ArrowStringArray:

params = [False, True]
param_names = ["multiple_chunks"]

def setup(self, multiple_chunks):
strings = tm.rands_array(3, 10_000)
if multiple_chunks:
chunks = [strings[i : i + 100] for i in range(0, len(strings), 100)]
self.array = pd.arrays.ArrowStringArray(pa.chunked_array(chunks))
else:
self.array = pd.arrays.ArrowStringArray(pa.array(strings))

def time_setitem(self, multiple_chunks):
for i in range(200):
self.array[i] = "foo"

def time_setitem_list(self, multiple_chunks):
indexer = list(range(0, 50)) + list(range(-50, 0))
self.array[indexer] = ["foo"] * len(indexer)

def time_setitem_slice(self, multiple_chunks):
self.array[::10] = "foo"
2 changes: 2 additions & 0 deletions pandas/compat/__init__.py
@@ -22,6 +22,7 @@
pa_version_under2p0,
pa_version_under3p0,
pa_version_under4p0,
pa_version_under5p0,
)

PY39 = sys.version_info >= (3, 9)
@@ -148,4 +149,5 @@ def get_lzma_file():
"pa_version_under2p0",
"pa_version_under3p0",
"pa_version_under4p0",
"pa_version_under5p0",
]
141 changes: 107 additions & 34 deletions pandas/core/arrays/string_arrow.py
@@ -6,7 +6,6 @@
TYPE_CHECKING,
Any,
Union,
cast,
overload,
)

@@ -31,6 +30,7 @@
pa_version_under2p0,
pa_version_under3p0,
pa_version_under4p0,
pa_version_under5p0,
)
from pandas.util._decorators import doc

@@ -362,49 +362,122 @@ def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
None
"""
key = check_array_indexer(self, key)
value_is_scalar = is_scalar(value)

if is_integer(key):
key = cast(int, key)

if not is_scalar(value):
raise ValueError("Must pass scalars with scalar indexer")
elif isna(value):
# NA -> None
Review comment (Contributor):
I would create a helper method like _validate_key() to encapsulate all of this. It's OK on this class for now, but we likely want to push this to ArrowExtensionArray (or maybe we need an ArrowIndexingMixin or similar); that can be done later (or here if convenient).
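A hypothetical sketch of what such a helper could look like (illustrative only; the name _validate_key, the use of pandas.api.types.is_bool_dtype, and the error message are assumptions, and as the reply below notes the PR ultimately refactored this differently):

import numpy as np
from pandas.api.types import is_bool_dtype


def _validate_key(self, key):
    # Normalize slice / boolean-mask / integer indexers to positive positions.
    # (Scalar integer keys would be handled before reaching this helper.)
    if isinstance(key, slice):
        key = np.arange(len(self))[key]
    elif is_bool_dtype(key):
        key = np.flatnonzero(key)
    else:
        key = np.asanyarray(key)
        key = np.where(key < 0, key + len(self), key)
    if len(key) and (key.min() < 0 or key.max() >= len(self)):
        raise IndexError(f"index out of bounds for array of length {len(self)}")
    return key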

Reply (Member Author):
I refactored pretty extensively, this logic is now self contained

if value_is_scalar:
if isna(value):
value = None
elif not isinstance(value, str):
raise ValueError("Scalar must be NA or str")

# Slice data and insert in-between
new_data = [
*self._data[0:key].chunks,
pa.array([value], type=pa.string()),
*self._data[(key + 1) :].chunks,
]
self._data = pa.chunked_array(new_data)
else:
# Convert to integer indices and iteratively assign.
# TODO: Make a faster variant of this in Arrow upstream.
# This is probably extremely slow.

# Convert all possible input key types to an array of integers
if isinstance(key, slice):
key_array = np.array(range(len(self))[key])
elif is_bool_dtype(key):
# TODO(ARROW-9430): Directly support setitem(booleans)
key_array = np.argwhere(key).flatten()
else:
# TODO(ARROW-9431): Directly support setitem(integers)
key_array = np.asanyarray(key)
value = np.asarray(value)
value[isna(value)] = None

# reorder values to align with the mask positions
if is_bool_dtype(key):
pass
elif isinstance(key, slice):
if not value_is_scalar and key.step is not None and key.step < 0:
value = value[::-1]
else:
if not value_is_scalar:
key = np.asarray(key)
if len(key) != len(value):
raise ValueError("Length of indexer and values mismatch")

if np.any(key < -len(self)):
min_key = np.asarray(key).min()
raise IndexError(
f"index {min_key} is out of bounds for array of length {len(self)}"
)
if np.any(key >= len(self)):
max_key = np.asarray(key).max()
raise IndexError(
f"index {max_key} is out of bounds for array of length {len(self)}"
)

if is_scalar(value):
value = np.broadcast_to(value, len(key_array))
# convert negative indices to positive before sorting
if is_integer(key):
if key < 0:
key += len(self)
else:
value = np.asarray(value)
key[key < 0] += len(self)
if not value_is_scalar:
value = value[np.argsort(key)]

# fast path
Review comment (Contributor):
Same thing here; I would do something along the lines of:

if can_fast_path(key):
      return self.set_with_fast_path(....)

return self.set_via_chunk_iteration()

Reply (Member Author):
done

if is_integer(key) and value_is_scalar and self._data.num_chunks == 1:
chunk = pa.concat_arrays(
[
self._data.chunks[0][:key],
pa.array([value], type=pa.string()),
self._data.chunks[0][key + 1 :],
]
)
self._data = pa.chunked_array([chunk])
return

if len(key_array) != len(value):
# create mask for positions to set
if is_bool_dtype(key):
mask = key
else:
mask = np.zeros(len(self), dtype=np.bool_)
mask[key] = True

if not value_is_scalar:
if len(value) != np.sum(mask):
raise ValueError("Length of indexer and values mismatch")

for k, v in zip(key_array, value):
self[k] = v
indices = mask.nonzero()[0]

# loop through the array chunks and set the new values while
# leaving the chunking layout unchanged
start = stop = 0
new_data = []

for chunk in self._data.iterchunks():
start, stop = stop, stop + len(chunk)

if len(indices) == 0 or indices[0] >= stop:
new_data.append(chunk)
continue

n = np.searchsorted(indices, np.intp(stop), side="left")
c_indices, indices = indices[:n], indices[n:]

if value_is_scalar:
c_value = value
else:
c_value, value = value[:n], value[n:]

if n == 1:
# fast path
idx = c_indices[0] - start
v = [c_value] if value_is_scalar else c_value
chunk = pa.concat_arrays(
[
chunk[:idx],
pa.array(v, type=pa.string()),
chunk[idx + 1 :],
]
)

elif n > 0:
submask = mask[start:stop]
if not pa_version_under5p0:
chunk = pc.replace_with_mask(chunk, submask, c_value)
else:
# The replace_with_mask compute function was added in
# version 5.0. For prior versions we implement our own
# by converting to numpy and back.
chunk = chunk.to_numpy(zero_copy_only=False)
chunk[submask] = c_value
chunk = pa.array(chunk, type=pa.string())

new_data.append(chunk)

self._data = pa.chunked_array(new_data)

def take(
self,
65 changes: 65 additions & 0 deletions pandas/tests/arrays/string_/test_string_arrow.py
@@ -132,3 +132,68 @@ def test_pyarrow_not_installed_raises():

with pytest.raises(ImportError, match=msg):
ArrowStringArray._from_sequence(["a", None, "b"])


@skip_if_no_pyarrow
@pytest.mark.parametrize("multiple_chunks", [False, True])
@pytest.mark.parametrize(
"key, value, expected",
[
(-1, "XX", ["a", "b", "c", "d", "XX"]),
(1, "XX", ["a", "XX", "c", "d", "e"]),
(1, None, ["a", None, "c", "d", "e"]),
(1, pd.NA, ["a", None, "c", "d", "e"]),
([1, 3], "XX", ["a", "XX", "c", "XX", "e"]),
([1, 3], ["XX", "YY"], ["a", "XX", "c", "YY", "e"]),
([1, 3], ["XX", None], ["a", "XX", "c", None, "e"]),
([1, 3], ["XX", pd.NA], ["a", "XX", "c", None, "e"]),
([0, -1], ["XX", "YY"], ["XX", "b", "c", "d", "YY"]),
([-1, 0], ["XX", "YY"], ["YY", "b", "c", "d", "XX"]),
(slice(3, None), "XX", ["a", "b", "c", "XX", "XX"]),
(slice(2, 4), ["XX", "YY"], ["a", "b", "XX", "YY", "e"]),
(slice(3, 1, -1), ["XX", "YY"], ["a", "b", "YY", "XX", "e"]),
(slice(None), "XX", ["XX", "XX", "XX", "XX", "XX"]),
([False, True, False, True, False], ["XX", "YY"], ["a", "XX", "c", "YY", "e"]),
],
)
def test_setitem(multiple_chunks, key, value, expected):
import pyarrow as pa

result = pa.array(list("abcde"))
expected = pa.array(expected)

if multiple_chunks:
result = pa.chunked_array([result[:3], result[3:]])
expected = pa.chunked_array([expected[:3], expected[3:]])

result = ArrowStringArray(result)
expected = ArrowStringArray(expected)

result[key] = value
tm.assert_equal(result, expected)
assert result._data.num_chunks == expected._data.num_chunks


@skip_if_no_pyarrow
def test_setitem_invalid_indexer_raises():
import pyarrow as pa

arr = ArrowStringArray(pa.array(list("abcde")))

with pytest.raises(IndexError, match=None):
arr[5] = "foo"

with pytest.raises(IndexError, match=None):
arr[-6] = "foo"

with pytest.raises(IndexError, match=None):
arr[[0, 5]] = "foo"

with pytest.raises(IndexError, match=None):
arr[[0, -6]] = "foo"

with pytest.raises(IndexError, match=None):
arr[[True, True, False]] = "foo"

with pytest.raises(ValueError, match=None):
arr[[0, 1]] = ["foo", "bar", "baz"]