
Commit 60175cc

jbrockmendel authored and jorisvandenbossche committed
REF (string): avoid copy in StringArray factorize (#59551)
* REF: avoid copy in StringArray factorize
* mypy fixup
* un-xfail
1 parent c9d4b1b commit 60175cc
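The user-facing behaviour of factorizing a python-backed StringArray is unchanged by this refactor; a quick behavioural sketch (not the implementation) of the operation it targets:

    import pandas as pd

    arr = pd.array(["a", pd.NA, "b", "a"], dtype="string")
    codes, uniques = arr.factorize()
    # missing entries map to the sentinel -1 and are excluded from uniques
    print(codes)    # [ 0 -1  1  0]
    print(uniques)  # <StringArray> ['a', 'b']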

File tree

9 files changed (+34, -41 lines)


pandas/_libs/arrays.pyx (+4)

@@ -67,6 +67,10 @@ cdef class NDArrayBacked:
         """
         Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.
 
+        The returned array has the same dtype as self.
+
+        Caller is responsible for ensuring `values.dtype == self._ndarray.dtype`.
+
         This should round-trip:
             self == self._from_backing_data(self._ndarray)
         """

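The added docstring lines spell out the contract the rest of this commit relies on: `_from_backing_data` re-wraps an ndarray of the same dtype as `self._ndarray`, without casting. A minimal round-trip sketch using a Categorical (an NDArrayBacked subclass; private attributes used purely for illustration):

    import pandas as pd

    cat = pd.Categorical(["a", "b", "a"])
    # the backing data are the integer codes; re-wrapping them round-trips
    restored = cat._from_backing_data(cat._ndarray)
    assert restored.dtype == cat.dtype
    assert (restored == cat).all()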
pandas/_libs/hashtable.pyx (+4 -1)

@@ -33,7 +33,10 @@ from pandas._libs.khash cimport (
     kh_python_hash_func,
     khiter_t,
 )
-from pandas._libs.missing cimport checknull
+from pandas._libs.missing cimport (
+    checknull,
+    is_matching_na,
+)
 
 
 def get_hashtable_trace_domain():

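The new cimport matters because `checknull` and `is_matching_na` answer different questions: the former asks whether a scalar is missing at all, the latter whether two missing scalars are the same kind of NA. A small sketch with the Python-level functions from the same module (assumed importable, as in pandas' own test helpers):

    import numpy as np
    import pandas as pd
    from pandas._libs.missing import checknull, is_matching_na

    assert checknull(np.nan) and checknull(pd.NA) and checknull(None)
    assert is_matching_na(pd.NA, pd.NA)
    assert not is_matching_na(np.nan, pd.NA)   # different NA kinds do not match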
pandas/_libs/hashtable_class_helper.pxi.in (+15 -3)

@@ -1121,11 +1121,13 @@ cdef class StringHashTable(HashTable):
             const char **vecs
             khiter_t k
             bint use_na_value
+            bint non_null_na_value
 
         if return_inverse:
             labels = np.zeros(n, dtype=np.intp)
         uindexer = np.empty(n, dtype=np.int64)
         use_na_value = na_value is not None
+        non_null_na_value = not checknull(na_value)
 
         # assign pointers and pre-filter out missing (if ignore_na)
         vecs = <const char **>malloc(n * sizeof(char *))
@@ -1134,7 +1136,12 @@ cdef class StringHashTable(HashTable):
 
             if (ignore_na
                     and (not isinstance(val, str)
-                         or (use_na_value and val == na_value))):
+                         or (use_na_value and (
+                             (non_null_na_value and val == na_value) or
+                             (not non_null_na_value and is_matching_na(val, na_value)))
+                             )
+                         )
+                    ):
                 # if missing values do not count as unique values (i.e. if
                 # ignore_na is True), we can skip the actual value, and
                 # replace the label with na_sentinel directly
@@ -1400,18 +1407,23 @@ cdef class PyObjectHashTable(HashTable):
             object val
             khiter_t k
             bint use_na_value
-
+            bint non_null_na_value
         if return_inverse:
             labels = np.empty(n, dtype=np.intp)
         use_na_value = na_value is not None
+        non_null_na_value = not checknull(na_value)
 
         for i in range(n):
             val = values[i]
             hash(val)
 
             if ignore_na and (
                 checknull(val)
-                or (use_na_value and val == na_value)
+                or (use_na_value and (
+                    (non_null_na_value and val == na_value) or
+                    (not non_null_na_value and is_matching_na(val, na_value))
+                    )
+                )
             ):
                 # if missing values do not count as unique values (i.e. if
                 # ignore_na is True), skip the hashtable entry for them, and

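The `non_null_na_value` branch exists because the `na_value` sentinel handed to these hashtables can now itself be a missing scalar (`pd.NA` or `np.nan`), and plain `==` cannot recognise those. A rough illustration of the comparison problem the added lines work around:

    import numpy as np
    import pandas as pd
    from pandas._libs.missing import is_matching_na

    print(np.nan == np.nan)                 # False
    print(pd.NA == pd.NA)                   # <NA> -- not a usable truth value
    print(is_matching_na(np.nan, np.nan))   # True
    print(is_matching_na(pd.NA, pd.NA))     # True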
pandas/core/arrays/_mixins.py (+8 -11)

@@ -515,17 +515,14 @@ def _quantile(
         fill_value = self._internal_fill_value
 
         res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
-
-        res_values = self._cast_quantile_result(res_values)
-        return self._from_backing_data(res_values)
-
-    # TODO: see if we can share this with other dispatch-wrapping methods
-    def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
-        """
-        Cast the result of quantile_with_mask to an appropriate dtype
-        to pass to _from_backing_data in _quantile.
-        """
-        return res_values
+        if res_values.dtype == self._ndarray.dtype:
+            return self._from_backing_data(res_values)
+        else:
+            # e.g. test_quantile_empty we are empty integer dtype and res_values
+            #  has floating dtype
+            # TODO: technically __init__ isn't defined here.
+            #  Should we raise NotImplementedError and handle this on NumpyEA?
+            return type(self)(res_values)  # type: ignore[call-arg]
 
     # ------------------------------------------------------------------------
     # numpy-like methods

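The `_cast_quantile_result` hook is replaced by a dtype check at the call site: when `quantile_with_mask` preserved the backing dtype, the result can go through `_from_backing_data`; otherwise the constructor is used. A hedged sketch of the case the new comment points at, an empty integer-backed array whose quantile result comes back as floats:

    import numpy as np

    # with no data to interpolate, the result is NaN-filled and therefore float64,
    # so it no longer satisfies _from_backing_data's same-dtype contract
    backing = np.array([], dtype=np.int64)
    qs = np.array([0.25, 0.5, 0.75])
    res_values = np.full(len(qs), np.nan)   # stand-in for what quantile_with_mask yields here
    print(backing.dtype, res_values.dtype)  # int64 float64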
pandas/core/arrays/categorical.py (-5)

@@ -2475,11 +2475,6 @@ def unique(self) -> Self:
         # pylint: disable=useless-parent-delegation
         return super().unique()
 
-    def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
-        # make sure we have correct itemsize for resulting codes
-        assert res_values.dtype == self._ndarray.dtype
-        return res_values
-
     def equals(self, other: object) -> bool:
         """
         Returns True if categorical arrays are equal.

pandas/core/arrays/numpy_.py (-3)

@@ -137,9 +137,6 @@ def _from_sequence(
             result = result.copy()
         return cls(result)
 
-    def _from_backing_data(self, arr: np.ndarray) -> NumpyExtensionArray:
-        return type(self)(arr)
-
     # ------------------------------------------------------------------------
     # Data

pandas/core/arrays/string_.py (+3 -9)

@@ -657,11 +657,10 @@ def __arrow_array__(self, type=None):
         values[self.isna()] = None
         return pa.array(values, type=type, from_pandas=True)
 
-    def _values_for_factorize(self):
+    def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]:  # type: ignore[override]
         arr = self._ndarray.copy()
-        mask = self.isna()
-        arr[mask] = None
-        return arr, None
+
+        return arr, self.dtype.na_value
 
     def __setitem__(self, key, value) -> None:
         value = extract_array(value, extract_numpy=True)
@@ -871,8 +870,3 @@ def _from_sequence(
         if dtype is None:
             dtype = StringDtype(storage="python", na_value=np.nan)
         return super()._from_sequence(scalars, dtype=dtype, copy=copy)
-
-    def _from_backing_data(self, arr: np.ndarray) -> StringArrayNumpySemantics:
-        # need to override NumpyExtensionArray._from_backing_data to ensure
-        #  we always preserve the dtype
-        return NDArrayBacked._from_backing_data(self, arr)

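With `_values_for_factorize` returning the dtype's own `na_value`, the hashtable sees the real missing scalar instead of `None` written into the masked positions. A quick check of what that sentinel is for the two python-storage string dtypes (constructor keywords as used in the hunk above):

    import numpy as np
    import pandas as pd

    print(pd.StringDtype(storage="python").na_value)                    # <NA>
    print(pd.StringDtype(storage="python", na_value=np.nan).na_value)   # nan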
pandas/tests/groupby/test_groupby_dropna.py (-3)

@@ -388,9 +388,6 @@ def test_groupby_dropna_with_multiindex_input(input_index, keys, series):
     tm.assert_equal(result, expected)
 
 
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
-)
 def test_groupby_nan_included():
     # GH 35646
     data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}

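The xfail comes off because NaN group keys now survive factorization under the python-backed string dtype without pyarrow. Roughly what the test exercises, using its own data:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]})
    # with dropna=False the NaN keys form their own group instead of being dropped
    print(df.groupby("group", dropna=False)["B"].sum())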
pandas/tests/window/test_rolling.py (-6)

@@ -6,10 +6,7 @@
 import numpy as np
 import pytest
 
-from pandas._config import using_string_dtype
-
 from pandas.compat import (
-    HAS_PYARROW,
     IS64,
     is_platform_arm,
     is_platform_power,
@@ -1423,9 +1420,6 @@ def test_rolling_corr_timedelta_index(index, window):
     tm.assert_almost_equal(result, expected)
 
 
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
-)
 def test_groupby_rolling_nan_included():
     # GH 35542
     data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}

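Same reasoning for the rolling variant: grouped rolling aggregations keep NaN group keys when dropna=False, so the string-dtype xfail is no longer needed. A sketch based on the test's data:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]})
    print(df.groupby("group", dropna=False).rolling(1, min_periods=1).mean())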