Skip to content

Commit 0c24b20

Browse files
authored
REF (string): avoid copy in StringArray factorize (#59551)
* REF: avoid copy in StringArray factorize * mypy fixup * un-xfail
1 parent: 59bb3f4 · commit: 0c24b20

File tree

9 files changed

+34
-41
lines changed

9 files changed

+34
-41
lines changed

pandas/_libs/arrays.pyx

+4
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,10 @@ cdef class NDArrayBacked:
6767
"""
6868
Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.
6969
70+
The returned array has the same dtype as self.
71+
72+
Caller is responsible for ensuring `values.dtype == self._ndarray.dtype`.
73+
7074
This should round-trip:
7175
self == self._from_backing_data(self._ndarray)
7276
"""

pandas/_libs/hashtable.pyx

+4-1
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,10 @@ from pandas._libs.khash cimport (
3030
kh_python_hash_func,
3131
khiter_t,
3232
)
33-
from pandas._libs.missing cimport checknull
33+
from pandas._libs.missing cimport (
34+
checknull,
35+
is_matching_na,
36+
)
3437

3538

3639
def get_hashtable_trace_domain():

pandas/_libs/hashtable_class_helper.pxi.in

+15-3
Original file line number | Diff line number | Diff line change
@@ -1171,11 +1171,13 @@ cdef class StringHashTable(HashTable):
11711171
const char **vecs
11721172
khiter_t k
11731173
bint use_na_value
1174+
bint non_null_na_value
11741175

11751176
if return_inverse:
11761177
labels = np.zeros(n, dtype=np.intp)
11771178
uindexer = np.empty(n, dtype=np.int64)
11781179
use_na_value = na_value is not None
1180+
non_null_na_value = not checknull(na_value)
11791181

11801182
# assign pointers and pre-filter out missing (if ignore_na)
11811183
vecs = <const char **>malloc(n * sizeof(char *))
@@ -1186,7 +1188,12 @@ cdef class StringHashTable(HashTable):
11861188

11871189
if (ignore_na
11881190
and (not isinstance(val, str)
1189-
or (use_na_value and val == na_value))):
1191+
or (use_na_value and (
1192+
(non_null_na_value and val == na_value) or
1193+
(not non_null_na_value and is_matching_na(val, na_value)))
1194+
)
1195+
)
1196+
):
11901197
# if missing values do not count as unique values (i.e. if
11911198
# ignore_na is True), we can skip the actual value, and
11921199
# replace the label with na_sentinel directly
@@ -1452,18 +1459,23 @@ cdef class PyObjectHashTable(HashTable):
14521459
object val
14531460
khiter_t k
14541461
bint use_na_value
1455-
1462+
bint non_null_na_value
14561463
if return_inverse:
14571464
labels = np.empty(n, dtype=np.intp)
14581465
use_na_value = na_value is not None
1466+
non_null_na_value = not checknull(na_value)
14591467

14601468
for i in range(n):
14611469
val = values[i]
14621470
hash(val)
14631471

14641472
if ignore_na and (
14651473
checknull(val)
1466-
or (use_na_value and val == na_value)
1474+
or (use_na_value and (
1475+
(non_null_na_value and val == na_value) or
1476+
(not non_null_na_value and is_matching_na(val, na_value))
1477+
)
1478+
)
14671479
):
14681480
# if missing values do not count as unique values (i.e. if
14691481
# ignore_na is True), skip the hashtable entry for them, and

pandas/core/arrays/_mixins.py

+8-11
Original file line number | Diff line number | Diff line change
@@ -496,17 +496,14 @@ def _quantile(
496496
fill_value = self._internal_fill_value
497497

498498
res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
499-
500-
res_values = self._cast_quantile_result(res_values)
501-
return self._from_backing_data(res_values)
502-
503-
# TODO: see if we can share this with other dispatch-wrapping methods
504-
def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
505-
"""
506-
Cast the result of quantile_with_mask to an appropriate dtype
507-
to pass to _from_backing_data in _quantile.
508-
"""
509-
return res_values
499+
if res_values.dtype == self._ndarray.dtype:
500+
return self._from_backing_data(res_values)
501+
else:
502+
# e.g. test_quantile_empty we are empty integer dtype and res_values
503+
# has floating dtype
504+
# TODO: technically __init__ isn't defined here.
505+
# Should we raise NotImplementedError and handle this on NumpyEA?
506+
return type(self)(res_values) # type: ignore[call-arg]
510507

511508
# ------------------------------------------------------------------------
512509
# numpy-like methods

pandas/core/arrays/categorical.py

-5
Original file line number | Diff line number | Diff line change
@@ -2495,11 +2495,6 @@ def unique(self) -> Self:
24952495
"""
24962496
return super().unique()
24972497

2498-
def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
2499-
# make sure we have correct itemsize for resulting codes
2500-
assert res_values.dtype == self._ndarray.dtype
2501-
return res_values
2502-
25032498
def equals(self, other: object) -> bool:
25042499
"""
25052500
Returns True if categorical arrays are equal.

pandas/core/arrays/numpy_.py

-3
Original file line number | Diff line number | Diff line change
@@ -137,9 +137,6 @@ def _from_sequence(
137137
result = result.copy()
138138
return cls(result)
139139

140-
def _from_backing_data(self, arr: np.ndarray) -> NumpyExtensionArray:
141-
return type(self)(arr)
142-
143140
# ------------------------------------------------------------------------
144141
# Data
145142

pandas/core/arrays/string_.py

+3-9
Original file line number | Diff line number | Diff line change
@@ -659,11 +659,10 @@ def __arrow_array__(self, type=None):
659659
values[self.isna()] = None
660660
return pa.array(values, type=type, from_pandas=True)
661661

662-
def _values_for_factorize(self) -> tuple[np.ndarray, None]:
662+
def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]: # type: ignore[override]
663663
arr = self._ndarray.copy()
664-
mask = self.isna()
665-
arr[mask] = None
666-
return arr, None
664+
665+
return arr, self.dtype.na_value
667666

668667
def __setitem__(self, key, value) -> None:
669668
value = extract_array(value, extract_numpy=True)
@@ -873,8 +872,3 @@ def _from_sequence(
873872
if dtype is None:
874873
dtype = StringDtype(storage="python", na_value=np.nan)
875874
return super()._from_sequence(scalars, dtype=dtype, copy=copy)
876-
877-
def _from_backing_data(self, arr: np.ndarray) -> StringArrayNumpySemantics:
878-
# need to override NumpyExtensionArray._from_backing_data to ensure
879-
# we always preserve the dtype
880-
return NDArrayBacked._from_backing_data(self, arr)

pandas/tests/groupby/test_groupby_dropna.py

-3
Original file line number | Diff line number | Diff line change
@@ -387,9 +387,6 @@ def test_groupby_dropna_with_multiindex_input(input_index, keys, series):
387387
tm.assert_equal(result, expected)
388388

389389

390-
@pytest.mark.xfail(
391-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
392-
)
393390
def test_groupby_nan_included():
394391
# GH 35646
395392
data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}

pandas/tests/window/test_rolling.py

-6
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,7 @@
66
import numpy as np
77
import pytest
88

9-
from pandas._config import using_string_dtype
10-
119
from pandas.compat import (
12-
HAS_PYARROW,
1310
IS64,
1411
is_platform_arm,
1512
is_platform_power,
@@ -1329,9 +1326,6 @@ def test_rolling_corr_timedelta_index(index, window):
13291326
tm.assert_almost_equal(result, expected)
13301327

13311328

1332-
@pytest.mark.xfail(
1333-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
1334-
)
13351329
def test_groupby_rolling_nan_included():
13361330
# GH 35542
13371331
data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}

0 commit comments

Comments (0)