Skip to content

Commit a55ad54

Browse files
mroeschke authored and yehoshuadimarsky committed
ENH: Use pyarrow.compute for unique, dropna (pandas-dev#46725)
1 parent dbd11ef commit a55ad54

File tree

11 files changed

+337
-108
lines changed

11 files changed

+337
-108
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,7 @@ Other enhancements
114114
- :meth:`pd.concat` now raises when ``levels`` is given but ``keys`` is None (:issue:`46653`)
115115
- :meth:`pd.concat` now raises when ``levels`` contains duplicate values (:issue:`46653`)
116116
- Added ``numeric_only`` argument to :meth:`DataFrame.corr`, :meth:`DataFrame.corrwith`, and :meth:`DataFrame.cov` (:issue:`46560`)
117+
- A :class:`errors.PerformanceWarning` is now thrown when using ``string[pyarrow]`` dtype with methods that don't dispatch to ``pyarrow.compute`` methods (:issue:`42613`, :issue:`46725`)
117118
- Added ``validate`` argument to :meth:`DataFrame.join` (:issue:`46622`)
118119
- A :class:`errors.PerformanceWarning` is now thrown when using ``string[pyarrow]`` dtype with methods that don't dispatch to ``pyarrow.compute`` methods (:issue:`42613`)
119120
- Added ``numeric_only`` argument to :meth:`Resampler.sum`, :meth:`Resampler.prod`, :meth:`Resampler.min`, :meth:`Resampler.max`, :meth:`Resampler.first`, and :meth:`Resampler.last` (:issue:`46442`)

pandas/_testing/__init__.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,10 @@
6464
rands_array,
6565
randu_array,
6666
)
67-
from pandas._testing._warnings import assert_produces_warning # noqa:F401
67+
from pandas._testing._warnings import ( # noqa:F401
68+
assert_produces_warning,
69+
maybe_produces_warning,
70+
)
6871
from pandas._testing.asserters import ( # noqa:F401
6972
assert_almost_equal,
7073
assert_attr_equal,

pandas/_testing/_warnings.py

+14-1
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3-
from contextlib import contextmanager
3+
from contextlib import (
4+
contextmanager,
5+
nullcontext,
6+
)
47
import re
58
import sys
69
from typing import (
@@ -97,6 +100,16 @@ class for all warnings. To check that no warning is returned,
97100
)
98101

99102

103+
def maybe_produces_warning(warning: type[Warning], condition: bool, **kwargs):
    """
    Conditionally build a warning-checking context manager.

    Parameters
    ----------
    warning : type[Warning]
        Warning class forwarded to ``assert_produces_warning`` when
        ``condition`` is truthy.
    condition : bool
        Selects which context manager is returned.
    **kwargs
        Forwarded unchanged to ``assert_produces_warning``; unused when
        ``condition`` is falsy.

    Returns
    -------
    context manager
        ``assert_produces_warning(warning, **kwargs)`` when ``condition`` is
        truthy, otherwise ``contextlib.nullcontext()``.
    """
    if not condition:
        # Nothing to check: hand back a do-nothing context manager.
        return nullcontext()
    return assert_produces_warning(warning, **kwargs)
111+
112+
100113
def _assert_caught_expected_warning(
101114
*,
102115
caught_warnings: Sequence[warnings.WarningMessage],

pandas/compat/__init__.py

+4
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@
2323
pa_version_under3p0,
2424
pa_version_under4p0,
2525
pa_version_under5p0,
26+
pa_version_under6p0,
27+
pa_version_under7p0,
2628
)
2729

2830
PY39 = sys.version_info >= (3, 9)
@@ -150,4 +152,6 @@ def get_lzma_file():
150152
"pa_version_under3p0",
151153
"pa_version_under4p0",
152154
"pa_version_under5p0",
155+
"pa_version_under6p0",
156+
"pa_version_under7p0",
153157
]

pandas/core/arrays/arrow/array.py

+31
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
pa_version_under1p01,
1717
pa_version_under2p0,
1818
pa_version_under5p0,
19+
pa_version_under6p0,
1920
)
2021
from pandas.util._decorators import doc
2122

@@ -37,6 +38,8 @@
3738
import pyarrow as pa
3839
import pyarrow.compute as pc
3940

41+
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
42+
4043
if TYPE_CHECKING:
4144
from pandas import Series
4245

@@ -104,6 +107,20 @@ def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
104107
"""
105108
return type(self)(self._data)
106109

110+
def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
    """
    Build a copy of this array with its NA entries removed.

    When ``pa_version_under6p0`` is false the work is delegated to
    ``pyarrow.compute.drop_null``. Otherwise
    ``fallback_performancewarning(version="6")`` is called and the parent
    class implementation of ``dropna`` is used.

    Returns
    -------
    ArrowExtensionArray
    """
    if not pa_version_under6p0:
        # Fast path: let pyarrow drop the nulls and re-wrap the result.
        without_nulls = pc.drop_null(self._data)
        return type(self)(without_nulls)

    # Older pyarrow: signal the fallback, then defer to the parent class.
    fallback_performancewarning(version="6")
    return super().dropna()
123+
107124
@doc(ExtensionArray.factorize)
108125
def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
109126
encoded = self._data.dictionary_encode()
@@ -219,6 +236,20 @@ def take(
219236
indices_array[indices_array < 0] += len(self._data)
220237
return type(self)(self._data.take(indices_array))
221238

239+
def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
    """
    Build an array holding the distinct values of this array.

    When ``pa_version_under2p0`` is false the work is delegated to
    ``pyarrow.compute.unique``. Otherwise
    ``fallback_performancewarning(version="2")`` is called and the parent
    class implementation of ``unique`` is used.

    Returns
    -------
    ArrowExtensionArray
    """
    if not pa_version_under2p0:
        # Fast path: let pyarrow compute the uniques and re-wrap the result.
        distinct = pc.unique(self._data)
        return type(self)(distinct)

    # Older pyarrow: signal the fallback, then defer to the parent class.
    fallback_performancewarning(version="2")
    return super().unique()
252+
222253
def value_counts(self, dropna: bool = True) -> Series:
223254
"""
224255
Return a Series containing counts of each unique value.

pandas/tests/arrays/string_/test_string.py

+12-13
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,6 @@
22
This module tests the functionality of StringArray and ArrowStringArray.
33
Tests for the str accessors are in pandas/tests/strings/test_string_array.py
44
"""
5-
from contextlib import nullcontext
6-
75
import numpy as np
86
import pytest
97

@@ -18,13 +16,6 @@
1816
from pandas.core.arrays.string_arrow import ArrowStringArray
1917

2018

21-
def maybe_perf_warn(using_pyarrow):
22-
if using_pyarrow:
23-
return tm.assert_produces_warning(PerformanceWarning, match="Falling back")
24-
else:
25-
return nullcontext()
26-
27-
2819
@pytest.fixture
2920
def dtype(string_storage):
3021
return pd.StringDtype(storage=string_storage)
@@ -568,22 +559,30 @@ def test_to_numpy_na_value(dtype, nulls_fixture):
568559
def test_isin(dtype, fixed_now_ts):
569560
s = pd.Series(["a", "b", None], dtype=dtype)
570561

571-
with maybe_perf_warn(dtype == "pyarrow" and pa_version_under2p0):
562+
with tm.maybe_produces_warning(
563+
PerformanceWarning, dtype == "pyarrow" and pa_version_under2p0
564+
):
572565
result = s.isin(["a", "c"])
573566
expected = pd.Series([True, False, False])
574567
tm.assert_series_equal(result, expected)
575568

576-
with maybe_perf_warn(dtype == "pyarrow" and pa_version_under2p0):
569+
with tm.maybe_produces_warning(
570+
PerformanceWarning, dtype == "pyarrow" and pa_version_under2p0
571+
):
577572
result = s.isin(["a", pd.NA])
578573
expected = pd.Series([True, False, True])
579574
tm.assert_series_equal(result, expected)
580575

581-
with maybe_perf_warn(dtype == "pyarrow" and pa_version_under2p0):
576+
with tm.maybe_produces_warning(
577+
PerformanceWarning, dtype == "pyarrow" and pa_version_under2p0
578+
):
582579
result = s.isin([])
583580
expected = pd.Series([False, False, False])
584581
tm.assert_series_equal(result, expected)
585582

586-
with maybe_perf_warn(dtype == "pyarrow" and pa_version_under2p0):
583+
with tm.maybe_produces_warning(
584+
PerformanceWarning, dtype == "pyarrow" and pa_version_under2p0
585+
):
587586
result = s.isin(["a", fixed_now_ts])
588587
expected = pd.Series([True, False, False])
589588
tm.assert_series_equal(result, expected)

pandas/tests/base/test_unique.py

+33-6
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22
import pytest
33

4+
from pandas.compat import pa_version_under2p0
5+
from pandas.errors import PerformanceWarning
6+
47
from pandas.core.dtypes.common import is_datetime64tz_dtype
58

69
import pandas as pd
@@ -12,7 +15,11 @@
1215
def test_unique(index_or_series_obj):
1316
obj = index_or_series_obj
1417
obj = np.repeat(obj, range(1, len(obj) + 1))
15-
result = obj.unique()
18+
with tm.maybe_produces_warning(
19+
PerformanceWarning,
20+
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]",
21+
):
22+
result = obj.unique()
1623

1724
# dict.fromkeys preserves the order
1825
unique_values = list(dict.fromkeys(obj.values))
@@ -50,7 +57,11 @@ def test_unique_null(null_obj, index_or_series_obj):
5057
klass = type(obj)
5158
repeated_values = np.repeat(values, range(1, len(values) + 1))
5259
obj = klass(repeated_values, dtype=obj.dtype)
53-
result = obj.unique()
60+
with tm.maybe_produces_warning(
61+
PerformanceWarning,
62+
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]",
63+
):
64+
result = obj.unique()
5465

5566
unique_values_raw = dict.fromkeys(obj.values)
5667
# because np.nan == np.nan is False, but None == None is True
@@ -75,7 +86,11 @@ def test_unique_null(null_obj, index_or_series_obj):
7586
def test_nunique(index_or_series_obj):
7687
obj = index_or_series_obj
7788
obj = np.repeat(obj, range(1, len(obj) + 1))
78-
expected = len(obj.unique())
89+
with tm.maybe_produces_warning(
90+
PerformanceWarning,
91+
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]",
92+
):
93+
expected = len(obj.unique())
7994
assert obj.nunique(dropna=False) == expected
8095

8196

@@ -99,9 +114,21 @@ def test_nunique_null(null_obj, index_or_series_obj):
99114
assert obj.nunique() == len(obj.categories)
100115
assert obj.nunique(dropna=False) == len(obj.categories) + 1
101116
else:
102-
num_unique_values = len(obj.unique())
103-
assert obj.nunique() == max(0, num_unique_values - 1)
104-
assert obj.nunique(dropna=False) == max(0, num_unique_values)
117+
with tm.maybe_produces_warning(
118+
PerformanceWarning,
119+
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]",
120+
):
121+
num_unique_values = len(obj.unique())
122+
with tm.maybe_produces_warning(
123+
PerformanceWarning,
124+
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]",
125+
):
126+
assert obj.nunique() == max(0, num_unique_values - 1)
127+
with tm.maybe_produces_warning(
128+
PerformanceWarning,
129+
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]",
130+
):
131+
assert obj.nunique(dropna=False) == max(0, num_unique_values)
105132

106133

107134
@pytest.mark.single_cpu

pandas/tests/extension/test_string.py

+12-1
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,11 @@
1818
import numpy as np
1919
import pytest
2020

21+
from pandas.compat import pa_version_under6p0
22+
from pandas.errors import PerformanceWarning
23+
2124
import pandas as pd
25+
import pandas._testing as tm
2226
from pandas.core.arrays import ArrowStringArray
2327
from pandas.core.arrays.string_ import StringDtype
2428
from pandas.tests.extension import base
@@ -139,7 +143,14 @@ class TestIndex(base.BaseIndexTests):
139143

140144

141145
class TestMissing(base.BaseMissingTests):
142-
pass
146+
def test_dropna_array(self, data_missing):
    # A PerformanceWarning is only expected for pyarrow-backed storage when
    # ``pa_version_under6p0`` is set; otherwise no warning check is made.
    expect_warning = (
        pa_version_under6p0 and data_missing.dtype.storage == "pyarrow"
    )
    with tm.maybe_produces_warning(PerformanceWarning, expect_warning):
        result = data_missing.dropna()
    # NOTE(review): the source view has lost its indentation; the lines below
    # are assumed to sit outside the ``with`` block - confirm against the
    # real file.
    expected = data_missing[[1]]
    self.assert_extension_array_equal(result, expected)
143154

144155

145156
class TestNoReduce(base.BaseNoReduceTests):

pandas/tests/indexes/test_common.py

+8-2
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,10 @@
88
import numpy as np
99
import pytest
1010

11-
from pandas.compat import IS64
11+
from pandas.compat import (
12+
IS64,
13+
pa_version_under2p0,
14+
)
1215

1316
from pandas.core.dtypes.common import is_integer_dtype
1417

@@ -395,7 +398,10 @@ def test_astype_preserves_name(self, index, dtype):
395398

396399
try:
397400
# Some of these conversions cannot succeed so we use a try / except
398-
with tm.assert_produces_warning(warn):
401+
with tm.assert_produces_warning(
402+
warn,
403+
raise_on_extra_warnings=not pa_version_under2p0,
404+
):
399405
result = index.astype(dtype)
400406
except (ValueError, TypeError, NotImplementedError, SystemError):
401407
return

0 commit comments

Comments
 (0)