|
6 | 6 | TYPE_CHECKING,
|
7 | 7 | Any,
|
8 | 8 | Union,
|
9 |
| - cast, |
10 | 9 | overload,
|
11 | 10 | )
|
12 | 11 |
|
|
31 | 30 | pa_version_under2p0,
|
32 | 31 | pa_version_under3p0,
|
33 | 32 | pa_version_under4p0,
|
| 33 | + pa_version_under5p0, |
34 | 34 | )
|
35 | 35 | from pandas.util._decorators import doc
|
36 | 36 |
|
@@ -365,49 +365,125 @@ def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
|
365 | 365 | None
|
366 | 366 | """
|
367 | 367 | key = check_array_indexer(self, key)
|
| 368 | + indices = self._key_to_indices(key) |
368 | 369 |
|
369 |
| - if is_integer(key): |
370 |
| - key = cast(int, key) |
371 |
| - |
372 |
| - if not is_scalar(value): |
373 |
| - raise ValueError("Must pass scalars with scalar indexer") |
374 |
| - elif isna(value): |
| 370 | + if is_scalar(value): |
| 371 | + if isna(value): |
375 | 372 | value = None
|
376 | 373 | elif not isinstance(value, str):
|
377 | 374 | raise ValueError("Scalar must be NA or str")
|
| 375 | + value = np.broadcast_to(value, len(indices)) |
| 376 | + else: |
| 377 | + value = np.array(value, dtype=object, copy=True) |
| 378 | + for i, v in enumerate(value): |
| 379 | + if isna(v): |
| 380 | + value[i] = None |
| 381 | + elif not isinstance(v, str): |
| 382 | + raise ValueError("Scalar must be NA or str") |
| 383 | + |
| 384 | + if len(indices) != len(value): |
| 385 | + raise ValueError("Length of indexer and values mismatch") |
| 386 | + |
| 387 | + argsort = np.argsort(indices) |
| 388 | + indices = indices[argsort] |
| 389 | + value = value[argsort] |
| 390 | + |
| 391 | + self._data = self._set_via_chunk_iteration(indices=indices, value=value) |
| 392 | + |
| 393 | + def _key_to_indices(self, key: int | slice | np.ndarray) -> npt.NDArray[np.intp]: |
| 394 | + """Convert indexing key for self to positional indices.""" |
| 395 | + if isinstance(key, slice): |
| 396 | + indices = np.arange(len(self))[key] |
| 397 | + elif is_bool_dtype(key): |
| 398 | + key = np.asarray(key) |
| 399 | + if len(key) != len(self): |
| 400 | + raise ValueError("Length of indexer and values mismatch") |
| 401 | + indices = key.nonzero()[0] |
| 402 | + else: |
| 403 | + key_arr = np.array([key]) if is_integer(key) else np.asarray(key) |
| 404 | + indices = np.arange(len(self))[key_arr] |
| 405 | + return indices |
378 | 406 |
|
379 |
| - # Slice data and insert in-between |
380 |
| - new_data = [ |
381 |
| - *self._data[0:key].chunks, |
| 407 | + def _set_via_chunk_iteration( |
| 408 | + self, indices: npt.NDArray[np.intp], value: npt.NDArray[Any] |
| 409 | + ) -> pa.ChunkedArray: |
| 410 | + """ |
| 411 | + Loop through the array chunks and set the new values while |
| 412 | + leaving the chunking layout unchanged. |
| 413 | + """ |
| 414 | + |
| 415 | + chunk_indices = self._within_chunk_indices(indices) |
| 416 | + new_data = [] |
| 417 | + |
| 418 | + for i, chunk in enumerate(self._data.iterchunks()): |
| 419 | + |
| 420 | + c_ind = chunk_indices[i] |
| 421 | + n = len(c_ind) |
| 422 | + c_value, value = value[:n], value[n:] |
| 423 | + |
| 424 | + if n == 1: |
| 425 | + # fast path |
| 426 | + chunk = self._set_single_index_in_chunk(chunk, c_ind[0], c_value[0]) |
| 427 | + elif n > 0: |
| 428 | + mask = np.zeros(len(chunk), dtype=np.bool_) |
| 429 | + mask[c_ind] = True |
| 430 | + if not pa_version_under5p0: |
| 431 | + if c_value is None or isna(np.array(c_value)).all(): |
| 432 | + chunk = pc.if_else(mask, None, chunk) |
| 433 | + else: |
| 434 | + chunk = pc.replace_with_mask(chunk, mask, c_value) |
| 435 | + else: |
| 436 | + # The pyarrow compute functions were added in |
| 437 | + # version 5.0. For prior versions we implement |
| 438 | + # our own by converting to numpy and back. |
| 439 | + chunk = chunk.to_numpy(zero_copy_only=False) |
| 440 | + chunk[mask] = c_value |
| 441 | + chunk = pa.array(chunk, type=pa.string()) |
| 442 | + |
| 443 | + new_data.append(chunk) |
| 444 | + |
| 445 | + return pa.chunked_array(new_data) |
| 446 | + |
| 447 | + @staticmethod |
| 448 | + def _set_single_index_in_chunk(chunk: pa.Array, index: int, value: Any) -> pa.Array: |
| 449 | + """Set a single position in a pyarrow array.""" |
| 450 | + assert is_scalar(value) |
| 451 | + return pa.concat_arrays( |
| 452 | + [ |
| 453 | + chunk[:index], |
382 | 454 | pa.array([value], type=pa.string()),
|
383 |
| - *self._data[(key + 1) :].chunks, |
| 455 | + chunk[index + 1 :], |
384 | 456 | ]
|
385 |
| - self._data = pa.chunked_array(new_data) |
386 |
| - else: |
387 |
| - # Convert to integer indices and iteratively assign. |
388 |
| - # TODO: Make a faster variant of this in Arrow upstream. |
389 |
| - # This is probably extremely slow. |
390 |
| - |
391 |
| - # Convert all possible input key types to an array of integers |
392 |
| - if isinstance(key, slice): |
393 |
| - key_array = np.array(range(len(self))[key]) |
394 |
| - elif is_bool_dtype(key): |
395 |
| - # TODO(ARROW-9430): Directly support setitem(booleans) |
396 |
| - key_array = np.argwhere(key).flatten() |
397 |
| - else: |
398 |
| - # TODO(ARROW-9431): Directly support setitem(integers) |
399 |
| - key_array = np.asanyarray(key) |
| 457 | + ) |
400 | 458 |
|
401 |
| - if is_scalar(value): |
402 |
| - value = np.broadcast_to(value, len(key_array)) |
| 459 | + def _within_chunk_indices( |
| 460 | + self, indices: npt.NDArray[np.intp] |
| 461 | + ) -> list[npt.NDArray[np.intp]]: |
| 462 | + """ |
| 463 | + Convert indices for self into a list of ndarrays each containing |
| 464 | + the indices *within* each chunk of the chunked array. |
| 465 | + """ |
| 466 | + # indices must be sorted |
| 467 | + chunk_indices = [] |
| 468 | + for start, stop in self._chunk_ranges(): |
| 469 | + if len(indices) == 0 or indices[0] >= stop: |
| 470 | + c_ind = np.array([], dtype=np.intp) |
403 | 471 | else:
|
404 |
| - value = np.asarray(value) |
| 472 | + n = int(np.searchsorted(indices, stop, side="left")) |
| 473 | + c_ind = indices[:n] - start |
| 474 | + indices = indices[n:] |
| 475 | + chunk_indices.append(c_ind) |
| 476 | + return chunk_indices |
405 | 477 |
|
406 |
| - if len(key_array) != len(value): |
407 |
| - raise ValueError("Length of indexer and values mismatch") |
408 |
| - |
409 |
| - for k, v in zip(key_array, value): |
410 |
| - self[k] = v |
| 478 | + def _chunk_ranges(self) -> list[tuple]: |
| 479 | + """ |
| 480 | + Return a list of tuples each containing the left (inclusive) |
| 481 | + and right (exclusive) bounds of each chunk. |
| 482 | + """ |
| 483 | + lengths = [len(c) for c in self._data.iterchunks()] |
| 484 | + stops = np.cumsum(lengths) |
| 485 | + starts = np.concatenate([[0], stops[:-1]]) |
| 486 | + return list(zip(starts, stops)) |
411 | 487 |
|
412 | 488 | def take(
|
413 | 489 | self,
|
|
0 commit comments