|
22 | 22 | pa_version_under4p0,
|
23 | 23 | pa_version_under5p0,
|
24 | 24 | pa_version_under6p0,
|
| 25 | + pa_version_under7p0, |
| 26 | +) |
| 27 | +from pandas.util._decorators import ( |
| 28 | + deprecate_nonkeyword_arguments, |
| 29 | + doc, |
25 | 30 | )
|
26 |
| -from pandas.util._decorators import doc |
27 | 31 |
|
28 | 32 | from pandas.core.dtypes.common import (
|
29 | 33 | is_array_like,
|
@@ -418,6 +422,58 @@ def isna(self) -> npt.NDArray[np.bool_]:
|
418 | 422 | else:
|
419 | 423 | return self._data.is_null().to_numpy()
|
420 | 424 |
|
| 425 | + @deprecate_nonkeyword_arguments(version=None, allowed_args=["self"]) |
| 426 | + def argsort( |
| 427 | + self, |
| 428 | + ascending: bool = True, |
| 429 | + kind: str = "quicksort", |
| 430 | + na_position: str = "last", |
| 431 | + *args, |
| 432 | + **kwargs, |
| 433 | + ) -> np.ndarray: |
| 434 | + order = "ascending" if ascending else "descending" |
| 435 | + null_placement = {"last": "at_end", "first": "at_start"}.get(na_position, None) |
| 436 | + if null_placement is None or pa_version_under7p0: |
| 437 | + # Although pc.array_sort_indices exists in version 6 |
| 438 | + # there's a bug that affects the pa.ChunkedArray backing |
| 439 | + # https://issues.apache.org/jira/browse/ARROW-12042 |
| 440 | + fallback_performancewarning("7") |
| 441 | + return super().argsort( |
| 442 | + ascending=ascending, kind=kind, na_position=na_position |
| 443 | + ) |
| 444 | + |
| 445 | + result = pc.array_sort_indices( |
| 446 | + self._data, order=order, null_placement=null_placement |
| 447 | + ) |
| 448 | + if pa_version_under2p0: |
| 449 | + np_result = result.to_pandas().values |
| 450 | + else: |
| 451 | + np_result = result.to_numpy() |
| 452 | + return np_result.astype(np.intp, copy=False) |
| 453 | + |
| 454 | + def _argmin_max(self, skipna: bool, method: str) -> int: |
| 455 | + if self._data.length() in (0, self._data.null_count) or ( |
| 456 | + self._hasna and not skipna |
| 457 | + ): |
| 458 | + # For empty or all null, pyarrow returns -1 but pandas expects TypeError |
| 459 | + # For skipna=False and data w/ null, pandas expects NotImplementedError |
| 460 | + # let ExtensionArray.arg{max|min} raise |
| 461 | + return getattr(super(), f"arg{method}")(skipna=skipna) |
| 462 | + |
| 463 | + if pa_version_under6p0: |
| 464 | + raise NotImplementedError( |
| 465 | + f"arg{method} only implemented for pyarrow version >= 6.0" |
| 466 | + ) |
| 467 | + |
| 468 | + value = getattr(pc, method)(self._data, skip_nulls=skipna) |
| 469 | + return pc.index(self._data, value).as_py() |
| 470 | + |
| 471 | + def argmin(self, skipna: bool = True) -> int: |
| 472 | + return self._argmin_max(skipna, "min") |
| 473 | + |
| 474 | + def argmax(self, skipna: bool = True) -> int: |
| 475 | + return self._argmin_max(skipna, "max") |
| 476 | + |
421 | 477 | def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
|
422 | 478 | """
|
423 | 479 | Return a shallow copy of the array.
|
|
0 commit comments