-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: Use pyarrow.compute for unique, dropna #46725
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
4c010aa
def3510
315f59a
3700867
2dc5918
ebf62e8
ea4e9e9
e254528
e2a093f
ecefbee
7a5d4fb
527c0b7
de815df
40bd857
26b4cdb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
pa_version_under1p01, | ||
pa_version_under2p0, | ||
pa_version_under5p0, | ||
pa_version_under6p0, | ||
) | ||
from pandas.util._decorators import doc | ||
|
||
|
@@ -37,6 +38,8 @@ | |
import pyarrow as pa | ||
import pyarrow.compute as pc | ||
|
||
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning | ||
|
||
if TYPE_CHECKING: | ||
from pandas import Series | ||
|
||
|
@@ -104,6 +107,20 @@ def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: | |
""" | ||
return type(self)(self._data) | ||
|
||
def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: | ||
""" | ||
Return ArrowExtensionArray without NA values. | ||
|
||
Returns | ||
------- | ||
ArrowExtensionArray | ||
""" | ||
if pa_version_under6p0: | ||
fallback_performancewarning(version="6") | ||
return super().dropna() | ||
else: | ||
return type(self)(pc.drop_null(self._data)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So we don't actually dispatch to this method from pandas? I wonder whether there would be any performance gain if we refactored to call this array method instead (from Series.dropna, for example). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, I'm not exactly sure what you mean here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, I see now. Yeah, hooking this up to dropna might be a good idea in a future PR. |
||
|
||
@doc(ExtensionArray.factorize) | ||
def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]: | ||
encoded = self._data.dictionary_encode() | ||
|
@@ -219,6 +236,20 @@ def take( | |
indices_array[indices_array < 0] += len(self._data) | ||
return type(self)(self._data.take(indices_array)) | ||
|
||
def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: | ||
""" | ||
Compute the ArrowExtensionArray of unique values. | ||
|
||
Returns | ||
------- | ||
ArrowExtensionArray | ||
""" | ||
if pa_version_under2p0: | ||
fallback_performancewarning(version="2") | ||
return super().unique() | ||
else: | ||
return type(self)(pc.unique(self._data)) | ||
|
||
def value_counts(self, dropna: bool = True) -> Series: | ||
""" | ||
Return a Series containing counts of each unique value. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pandas.compat import pa_version_under2p0 | ||
from pandas.errors import PerformanceWarning | ||
|
||
from pandas.core.dtypes.common import is_datetime64tz_dtype | ||
|
||
import pandas as pd | ||
|
@@ -12,7 +15,11 @@ | |
def test_unique(index_or_series_obj): | ||
obj = index_or_series_obj | ||
obj = np.repeat(obj, range(1, len(obj) + 1)) | ||
result = obj.unique() | ||
with tm.maybe_produces_warning( | ||
PerformanceWarning, | ||
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It appears that with a pyarrow-backed StringArray, we are only testing Index here, not Series? Also, we don't need the [inline code lost in extraction]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Looks so, based on the fixture in [inline code lost in extraction].
Fixed the comparison. |
||
): | ||
result = obj.unique() | ||
|
||
# dict.fromkeys preserves the order | ||
unique_values = list(dict.fromkeys(obj.values)) | ||
|
@@ -50,7 +57,11 @@ def test_unique_null(null_obj, index_or_series_obj): | |
klass = type(obj) | ||
repeated_values = np.repeat(values, range(1, len(values) + 1)) | ||
obj = klass(repeated_values, dtype=obj.dtype) | ||
result = obj.unique() | ||
with tm.maybe_produces_warning( | ||
PerformanceWarning, | ||
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]", | ||
): | ||
result = obj.unique() | ||
|
||
unique_values_raw = dict.fromkeys(obj.values) | ||
# because np.nan == np.nan is False, but None == None is True | ||
|
@@ -75,7 +86,11 @@ def test_unique_null(null_obj, index_or_series_obj): | |
def test_nunique(index_or_series_obj): | ||
obj = index_or_series_obj | ||
obj = np.repeat(obj, range(1, len(obj) + 1)) | ||
expected = len(obj.unique()) | ||
with tm.maybe_produces_warning( | ||
PerformanceWarning, | ||
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]", | ||
): | ||
expected = len(obj.unique()) | ||
assert obj.nunique(dropna=False) == expected | ||
|
||
|
||
|
@@ -99,9 +114,21 @@ def test_nunique_null(null_obj, index_or_series_obj): | |
assert obj.nunique() == len(obj.categories) | ||
assert obj.nunique(dropna=False) == len(obj.categories) + 1 | ||
else: | ||
num_unique_values = len(obj.unique()) | ||
assert obj.nunique() == max(0, num_unique_values - 1) | ||
assert obj.nunique(dropna=False) == max(0, num_unique_values) | ||
with tm.maybe_produces_warning( | ||
PerformanceWarning, | ||
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]", | ||
): | ||
num_unique_values = len(obj.unique()) | ||
with tm.maybe_produces_warning( | ||
PerformanceWarning, | ||
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]", | ||
): | ||
assert obj.nunique() == max(0, num_unique_values - 1) | ||
with tm.maybe_produces_warning( | ||
PerformanceWarning, | ||
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]", | ||
): | ||
assert obj.nunique(dropna=False) == max(0, num_unique_values) | ||
|
||
|
||
@pytest.mark.single_cpu | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this need to be guarded?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. `_arrow_utils` doesn't guard `import pyarrow`.