Skip to content

REF (string): avoid copy in StringArray factorize #59551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pandas/_libs/arrays.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ cdef class NDArrayBacked:
"""
Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.

The returned array has the same dtype as self.

Caller is responsible for ensuring `values.dtype == self._ndarray.dtype`.

This should round-trip:
self == self._from_backing_data(self._ndarray)
"""
Expand Down
5 changes: 4 additions & 1 deletion pandas/_libs/hashtable.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ from pandas._libs.khash cimport (
kh_python_hash_func,
khiter_t,
)
from pandas._libs.missing cimport checknull
from pandas._libs.missing cimport (
checknull,
is_matching_na,
)


def get_hashtable_trace_domain():
Expand Down
18 changes: 15 additions & 3 deletions pandas/_libs/hashtable_class_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -1171,11 +1171,13 @@ cdef class StringHashTable(HashTable):
const char **vecs
khiter_t k
bint use_na_value
bint non_null_na_value

if return_inverse:
labels = np.zeros(n, dtype=np.intp)
uindexer = np.empty(n, dtype=np.int64)
use_na_value = na_value is not None
non_null_na_value = not checknull(na_value)

# assign pointers and pre-filter out missing (if ignore_na)
vecs = <const char **>malloc(n * sizeof(char *))
Expand All @@ -1186,7 +1188,12 @@ cdef class StringHashTable(HashTable):

if (ignore_na
and (not isinstance(val, str)
or (use_na_value and val == na_value))):
or (use_na_value and (
(non_null_na_value and val == na_value) or
(not non_null_na_value and is_matching_na(val, na_value)))
)
)
):
# if missing values do not count as unique values (i.e. if
# ignore_na is True), we can skip the actual value, and
# replace the label with na_sentinel directly
Expand Down Expand Up @@ -1452,18 +1459,23 @@ cdef class PyObjectHashTable(HashTable):
object val
khiter_t k
bint use_na_value

bint non_null_na_value
if return_inverse:
labels = np.empty(n, dtype=np.intp)
use_na_value = na_value is not None
non_null_na_value = not checknull(na_value)

for i in range(n):
val = values[i]
hash(val)

if ignore_na and (
checknull(val)
or (use_na_value and val == na_value)
or (use_na_value and (
(non_null_na_value and val == na_value) or
(not non_null_na_value and is_matching_na(val, na_value))
)
)
):
# if missing values do not count as unique values (i.e. if
# ignore_na is True), skip the hashtable entry for them, and
Expand Down
19 changes: 8 additions & 11 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,17 +496,14 @@ def _quantile(
fill_value = self._internal_fill_value

res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)

res_values = self._cast_quantile_result(res_values)
return self._from_backing_data(res_values)

# TODO: see if we can share this with other dispatch-wrapping methods
def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
"""
Cast the result of quantile_with_mask to an appropriate dtype
to pass to _from_backing_data in _quantile.
"""
return res_values
if res_values.dtype == self._ndarray.dtype:
return self._from_backing_data(res_values)
else:
# e.g. test_quantile_empty we are empty integer dtype and res_values
# has floating dtype
# TODO: technically __init__ isn't defined here.
# Should we raise NotImplementedError and handle this on NumpyEA?
return type(self)(res_values) # type: ignore[call-arg]

# ------------------------------------------------------------------------
# numpy-like methods
Expand Down
5 changes: 0 additions & 5 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2495,11 +2495,6 @@ def unique(self) -> Self:
"""
return super().unique()

def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
# make sure we have correct itemsize for resulting codes
assert res_values.dtype == self._ndarray.dtype
return res_values

def equals(self, other: object) -> bool:
"""
Returns True if categorical arrays are equal.
Expand Down
3 changes: 0 additions & 3 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ def _from_sequence(
result = result.copy()
return cls(result)

def _from_backing_data(self, arr: np.ndarray) -> NumpyExtensionArray:
return type(self)(arr)

# ------------------------------------------------------------------------
# Data

Expand Down
12 changes: 3 additions & 9 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,11 +659,10 @@ def __arrow_array__(self, type=None):
values[self.isna()] = None
return pa.array(values, type=type, from_pandas=True)

def _values_for_factorize(self) -> tuple[np.ndarray, None]:
def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]: # type: ignore[override]
arr = self._ndarray.copy()
mask = self.isna()
arr[mask] = None
return arr, None

return arr, self.dtype.na_value

def __setitem__(self, key, value) -> None:
value = extract_array(value, extract_numpy=True)
Expand Down Expand Up @@ -873,8 +872,3 @@ def _from_sequence(
if dtype is None:
dtype = StringDtype(storage="python", na_value=np.nan)
return super()._from_sequence(scalars, dtype=dtype, copy=copy)

def _from_backing_data(self, arr: np.ndarray) -> StringArrayNumpySemantics:
# need to override NumpyExtensionArray._from_backing_data to ensure
# we always preserve the dtype
return NDArrayBacked._from_backing_data(self, arr)
3 changes: 0 additions & 3 deletions pandas/tests/groupby/test_groupby_dropna.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,6 @@ def test_groupby_dropna_with_multiindex_input(input_index, keys, series):
tm.assert_equal(result, expected)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_groupby_nan_included():
# GH 35646
data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}
Expand Down
6 changes: 0 additions & 6 deletions pandas/tests/window/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import (
HAS_PYARROW,
IS64,
is_platform_arm,
is_platform_power,
Expand Down Expand Up @@ -1329,9 +1326,6 @@ def test_rolling_corr_timedelta_index(index, window):
tm.assert_almost_equal(result, expected)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_groupby_rolling_nan_included():
# GH 35542
data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}
Expand Down