REF/PERF: ArrowStringArray.__setitem__ #46400
Changes from 2 commits
@@ -6,7 +6,6 @@
    TYPE_CHECKING,
    Any,
    Union,
    cast,
    overload,
)
@@ -31,6 +30,7 @@
    pa_version_under2p0,
    pa_version_under3p0,
    pa_version_under4p0,
    pa_version_under5p0,
)
from pandas.util._decorators import doc
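The new pa_version_under5p0 import is needed because the pc.replace_with_mask kernel only exists from pyarrow 5.0 onward. As a standalone sketch of the gated-fallback pattern used later in this diff (the function name and the exact import path are assumptions, not part of the PR):

    import numpy as np
    import pyarrow as pa
    import pyarrow.compute as pc

    from pandas.compat import pa_version_under5p0  # import path assumed

    def replace_with_mask_compat(chunk, mask, replacements):
        # use the native kernel on pyarrow >= 5.0, otherwise fall back
        # to a round-trip through a numpy object array; mask is assumed
        # to be a numpy boolean array
        if not pa_version_under5p0:
            return pc.replace_with_mask(chunk, mask, replacements)
        out = chunk.to_numpy(zero_copy_only=False)
        out[mask] = replacements
        return pa.array(out, type=pa.string())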
@@ -362,49 +362,122 @@ def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
        None
        """
        key = check_array_indexer(self, key)
        value_is_scalar = is_scalar(value)

        if is_integer(key):
            key = cast(int, key)

            if not is_scalar(value):
                raise ValueError("Must pass scalars with scalar indexer")
            elif isna(value):
                # NA -> None
Review comment: i would create a helper method like
Reply: I refactored pretty extensively, this logic is now self contained
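The helper the reviewer suggested is not preserved in this capture. Purely as an illustration, a hypothetical helper mirroring the scalar-validation logic in this hunk might look like the following (the name and placement are invented):

    import pandas as pd

    def _validate_scalar_value(value):
        # hypothetical helper: normalize NA to None and reject
        # non-string scalars, mirroring the if/elif chain above
        if pd.isna(value):
            return None
        if not isinstance(value, str):
            raise ValueError("Scalar must be NA or str")
        return value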
        if value_is_scalar:
            if isna(value):
                value = None
            elif not isinstance(value, str):
                raise ValueError("Scalar must be NA or str")

            # Slice data and insert in-between
            new_data = [
                *self._data[0:key].chunks,
                pa.array([value], type=pa.string()),
                *self._data[(key + 1) :].chunks,
            ]
            self._data = pa.chunked_array(new_data)
        else:
            # Convert to integer indices and iteratively assign.
            # TODO: Make a faster variant of this in Arrow upstream.
            #       This is probably extremely slow.

            # Convert all possible input key types to an array of integers
            if isinstance(key, slice):
                key_array = np.array(range(len(self))[key])
            elif is_bool_dtype(key):
                # TODO(ARROW-9430): Directly support setitem(booleans)
                key_array = np.argwhere(key).flatten()
            else:
                # TODO(ARROW-9431): Directly support setitem(integers)
                key_array = np.asanyarray(key)
            value = np.asarray(value)
            value[isna(value)] = None

        # reorder values to align with the mask positions
        if is_bool_dtype(key):
            pass
        elif isinstance(key, slice):
            if not value_is_scalar and key.step is not None and key.step < 0:
                value = value[::-1]
        else:
            if not value_is_scalar:
                key = np.asarray(key)
                if len(key) != len(value):
                    raise ValueError("Length of indexer and values mismatch")

            if np.any(key < -len(self)):
                min_key = np.asarray(key).min()
                raise IndexError(
                    f"index {min_key} is out of bounds for array of length {len(self)}"
                )
            if np.any(key >= len(self)):
                max_key = np.asarray(key).max()
                raise IndexError(
                    f"index {max_key} is out of bounds for array of length {len(self)}"
                )

            if is_scalar(value):
                value = np.broadcast_to(value, len(key_array))
            # convert negative indices to positive before sorting
            if is_integer(key):
                if key < 0:
                    key += len(self)
            else:
                value = np.asarray(value)
                key[key < 0] += len(self)
                if not value_is_scalar:
                    value = value[np.argsort(key)]

        # fast path
Review comment: same thing here would do something along the lines of
Reply: done
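The code suggested in that comment is likewise not preserved here. As a hypothetical illustration, a single-position replacement helper built from the slice-and-concatenate fast path shown below might look like this (the name _replace_single_value is invented):

    import pyarrow as pa

    def _replace_single_value(arr, idx, value):
        # hypothetical helper: rebuild a pyarrow array with the element
        # at idx replaced, by slicing around it and concatenating
        return pa.concat_arrays(
            [
                arr[:idx],
                pa.array([value], type=pa.string()),
                arr[idx + 1 :],
            ]
        )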
        if is_integer(key) and value_is_scalar and self._data.num_chunks == 1:
            chunk = pa.concat_arrays(
                [
                    self._data.chunks[0][:key],
                    pa.array([value], type=pa.string()),
                    self._data.chunks[0][key + 1 :],
                ]
            )
            self._data = pa.chunked_array([chunk])
            return

            if len(key_array) != len(value):
        # create mask for positions to set
        if is_bool_dtype(key):
            mask = key
        else:
            mask = np.zeros(len(self), dtype=np.bool_)
            mask[key] = True

        if not value_is_scalar:
            if len(value) != np.sum(mask):
                raise ValueError("Length of indexer and values mismatch")

            for k, v in zip(key_array, value):
                self[k] = v
        indices = mask.nonzero()[0]

        # loop through the array chunks and set the new values while
        # leaving the chunking layout unchanged
        start = stop = 0
        new_data = []

        for chunk in self._data.iterchunks():
            start, stop = stop, stop + len(chunk)

            if len(indices) == 0 or indices[0] >= stop:
                new_data.append(chunk)
                continue

            n = np.searchsorted(indices, np.intp(stop), side="left")
            c_indices, indices = indices[:n], indices[n:]

            if value_is_scalar:
                c_value = value
            else:
                c_value, value = value[:n], value[n:]

            if n == 1:
                # fast path
                idx = c_indices[0] - start
                v = [c_value] if value_is_scalar else c_value
                chunk = pa.concat_arrays(
                    [
                        chunk[:idx],
                        pa.array(v, type=pa.string()),
                        chunk[idx + 1 :],
                    ]
                )

            elif n > 0:
                submask = mask[start:stop]
                if not pa_version_under5p0:
                    chunk = pc.replace_with_mask(chunk, submask, c_value)
                else:
                    # The replace_with_mask compute function was added in
                    # version 5.0. For prior versions we implement our own
                    # by converting to numpy and back.
                    chunk = chunk.to_numpy(zero_copy_only=False)
                    chunk[submask] = c_value
                    chunk = pa.array(chunk, type=pa.string())

            new_data.append(chunk)

        self._data = pa.chunked_array(new_data)
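For reference, a small usage sketch of the behavior this rewrite targets, with made-up values. Setting through a boolean mask aligns the replacement values with the True positions, and scalars broadcast across all selected positions:

    import numpy as np
    import pandas as pd

    arr = pd.array(["a", "b", "c", "d"], dtype="string[pyarrow]")

    # boolean-mask indexer with per-position replacement values
    arr[np.array([True, False, True, False])] = np.array(["x", "y"])

    # integer-array indexer with a broadcast scalar
    arr[np.array([1, 3])] = pd.NA

    print(arr.tolist())  # ['x', <NA>, 'y', <NA>]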
    def take(
        self,
Review comment: not sure if the policy should change, but AFAIK we guard the pyarrow import in the other benchmarks, as it's an optional dependency, and raise NotImplementedError so that the benchmarks get skipped when pyarrow is not installed.
Reply: Updated. Thanks for pointing that out.
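A minimal sketch of the guard the reviewer describes, assuming an asv-style benchmark class (the class, method, and sizes here are hypothetical):

    import pandas as pd

    class ArrowStringArraySetitem:
        def setup(self):
            try:
                import pyarrow  # noqa: F401
            except ImportError:
                # asv treats NotImplementedError raised in setup as "skip"
                raise NotImplementedError
            self.arr = pd.array(["a"] * 100_000, dtype="string[pyarrow]")

        def time_setitem_scalar(self):
            self.arr[50_000] = "x"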