ENH: Use pyarrow.compute for unique, dropna #46725

Merged 15 commits on Apr 27, 2022. Changes below are shown from 7 commits.
4 changes: 4 additions & 0 deletions pandas/compat/__init__.py
@@ -23,6 +23,8 @@
pa_version_under3p0,
pa_version_under4p0,
pa_version_under5p0,
pa_version_under6p0,
pa_version_under7p0,
)

PY39 = sys.version_info >= (3, 9)
@@ -150,4 +152,6 @@ def get_lzma_file():
"pa_version_under3p0",
"pa_version_under4p0",
"pa_version_under5p0",
"pa_version_under6p0",
"pa_version_under7p0",
]
31 changes: 31 additions & 0 deletions pandas/core/arrays/arrow/array.py
@@ -16,6 +16,7 @@
pa_version_under1p01,
pa_version_under2p0,
pa_version_under5p0,
pa_version_under6p0,
)
from pandas.util._decorators import doc

@@ -37,6 +38,8 @@
import pyarrow as pa
import pyarrow.compute as pc

from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
Member: does this need to be guarded?

Member Author: I think so. _arrow_utils doesn't guard import pyarrow.

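For context, a minimal sketch of the guarded-import pattern being discussed; the guard flag (pa_version_under1p01) and the exact layout are assumptions about how the optional pyarrow imports are gated, not a verbatim copy of array.py:

# Sketch: keep pyarrow-dependent imports behind an availability guard so that
# importing pandas still works when pyarrow is missing or too old.
from pandas.compat import pa_version_under1p01

if not pa_version_under1p01:
    import pyarrow as pa  # noqa: F401
    import pyarrow.compute as pc  # noqa: F401

    # _arrow_utils imports pyarrow unconditionally, so it must live inside
    # the same guard rather than at the top of the module.
    from pandas.core.arrays.arrow._arrow_utils import (  # noqa: F401
        fallback_performancewarning,
    )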

if TYPE_CHECKING:
from pandas import Series

@@ -104,6 +107,20 @@ def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
"""
return type(self)(self._data)

def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
"""
Return ArrowExtensionArray without NA values.
Returns
-------
valid : ArrowExtensionArray
"""
if pa_version_under6p0:
fallback_performancewarning(version="6")
return super().dropna()
else:
return type(self)(pc.drop_null(self._data))
Member: so we don't actually dispatch to this method from pandas?

I wonder whether there would be any performance gain if we refactored to call this array method instead (from Series.dropna, for example)?

Member Author: Hmm, not exactly sure what you mean here.

Member Author: Ah, I see now. Yeah, hooking this up to dropna might be a good idea in a future PR.

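To make the idea above concrete, here is a rough, hypothetical sketch of a Series-level dropna deferring to the array method; the helper name and the simplified index handling are illustrative only, not how pandas is actually wired up:

import pandas as pd


def dropna_via_array(ser: pd.Series) -> pd.Series:
    # Hypothetical dispatch: let the backing ExtensionArray drop its own nulls
    # (ArrowExtensionArray.dropna uses pc.drop_null on pyarrow >= 6.0) instead
    # of masking the values through a NumPy boolean array at the Series level.
    arr = ser.array
    valid = ~arr.isna()
    return pd.Series(arr.dropna(), index=ser.index[valid], name=ser.name)


# Usage (requires pyarrow for the string[pyarrow] dtype):
ser = pd.Series(["a", None, "b"], dtype="string[pyarrow]")
print(dropna_via_array(ser))  # keeps the "a" and "b" rows, drops the NA row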

@doc(ExtensionArray.factorize)
def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
encoded = self._data.dictionary_encode()
@@ -219,6 +236,20 @@ def take(
indices_array[indices_array < 0] += len(self._data)
return type(self)(self._data.take(indices_array))

def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
"""
Compute the ArrowExtensionArray of unique values.
Returns
-------
uniques : ArrowExtensionArray
"""
if pa_version_under2p0:
fallback_performancewarning(version="2")
return super().unique()
else:
return type(self)(pc.unique(self._data))

def value_counts(self, dropna: bool = True) -> Series:
"""
Return a Series containing counts of each unique value.
42 changes: 36 additions & 6 deletions pandas/tests/base/test_unique.py
@@ -1,6 +1,11 @@
from contextlib import nullcontext

import numpy as np
import pytest

from pandas.compat import pa_version_under2p0
from pandas.errors import PerformanceWarning

from pandas.core.dtypes.common import is_datetime64tz_dtype

import pandas as pd
@@ -9,10 +14,20 @@
from pandas.tests.base.common import allow_na_ops


def maybe_perf_warn(using_pyarrow):
Contributor: ideally move this to _test_decorators.py (or similar), since this is a general testing function.

if using_pyarrow:
return tm.assert_produces_warning(PerformanceWarning, match="Falling back")
else:
return nullcontext()

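Following up on that suggestion, a sketch of what a shared helper could look like if lifted into a common testing module; the module path and helper name are assumptions, not code from this PR:

# Hypothetical location: pandas/util/_test_decorators.py (or pandas/_testing)
from contextlib import nullcontext

import pandas._testing as tm
from pandas.errors import PerformanceWarning


def maybe_produces_fallback_warning(using_pyarrow: bool):
    # Expect the "Falling back" PerformanceWarning only when the slow,
    # pre-pyarrow code path is actually taken; otherwise no warning context.
    if using_pyarrow:
        return tm.assert_produces_warning(PerformanceWarning, match="Falling back")
    return nullcontext()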

def test_unique(index_or_series_obj):
obj = index_or_series_obj
obj = np.repeat(obj, range(1, len(obj) + 1))
result = obj.unique()
with maybe_perf_warn(
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]"
):
result = obj.unique()

# dict.fromkeys preserves the order
unique_values = list(dict.fromkeys(obj.values))
@@ -50,7 +65,10 @@ def test_unique_null(null_obj, index_or_series_obj):
klass = type(obj)
repeated_values = np.repeat(values, range(1, len(values) + 1))
obj = klass(repeated_values, dtype=obj.dtype)
result = obj.unique()
with maybe_perf_warn(
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]"
):
result = obj.unique()

unique_values_raw = dict.fromkeys(obj.values)
# because np.nan == np.nan is False, but None == None is True
@@ -75,7 +93,10 @@
def test_nunique(index_or_series_obj):
obj = index_or_series_obj
obj = np.repeat(obj, range(1, len(obj) + 1))
expected = len(obj.unique())
with maybe_perf_warn(
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]"
):
expected = len(obj.unique())
assert obj.nunique(dropna=False) == expected


@@ -99,9 +120,18 @@ def test_nunique_null(null_obj, index_or_series_obj):
assert obj.nunique() == len(obj.categories)
assert obj.nunique(dropna=False) == len(obj.categories) + 1
else:
num_unique_values = len(obj.unique())
assert obj.nunique() == max(0, num_unique_values - 1)
assert obj.nunique(dropna=False) == max(0, num_unique_values)
with maybe_perf_warn(
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]"
):
num_unique_values = len(obj.unique())
with maybe_perf_warn(
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]"
):
assert obj.nunique() == max(0, num_unique_values - 1)
with maybe_perf_warn(
pa_version_under2p0 and str(index_or_series_obj.dtype) == "string[pyarrow]"
):
assert obj.nunique(dropna=False) == max(0, num_unique_values)


@pytest.mark.single_cpu
20 changes: 19 additions & 1 deletion pandas/tests/extension/test_string.py
@@ -13,17 +13,29 @@
be added to the array-specific tests in `pandas/tests/arrays/`.

"""
from contextlib import nullcontext
import string

import numpy as np
import pytest

from pandas.compat import pa_version_under6p0
from pandas.errors import PerformanceWarning

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays import ArrowStringArray
from pandas.core.arrays.string_ import StringDtype
from pandas.tests.extension import base


def maybe_perf_warn(using_pyarrow):
if using_pyarrow:
return tm.assert_produces_warning(PerformanceWarning, match="Falling back")
else:
return nullcontext()


def split_array(arr):
if arr.dtype.storage != "pyarrow":
pytest.skip("only applicable for pyarrow chunked array n/a")
@@ -139,7 +151,13 @@ class TestIndex(base.BaseIndexTests):


class TestMissing(base.BaseMissingTests):
pass
def test_dropna_array(self, data_missing):
with maybe_perf_warn(
pa_version_under6p0 and data_missing.dtype.storage == "pyarrow"
):
result = data_missing.dropna()
expected = data_missing[[1]]
self.assert_extension_array_equal(result, expected)


class TestNoReduce(base.BaseNoReduceTests):
9 changes: 7 additions & 2 deletions pandas/tests/indexes/test_common.py
@@ -8,7 +8,10 @@
import numpy as np
import pytest

from pandas.compat import IS64
from pandas.compat import (
IS64,
pa_version_under2p0,
)

from pandas.core.dtypes.common import is_integer_dtype

@@ -395,7 +398,9 @@ def test_astype_preserves_name(self, index, dtype):

try:
# Some of these conversions cannot succeed so we use a try / except
with tm.assert_produces_warning(warn):
with tm.assert_produces_warning(
warn, raise_on_extra_warnings=not pa_version_under2p0
):
result = index.astype(dtype)
except (ValueError, TypeError, NotImplementedError, SystemError):
return