Skip to content

Commit b81f431

Browse files
authored
ENH/TST: Add BaseGroupbyTests tests for ArrowExtensionArray (#47515)
1 parent f538568 commit b81f431

File tree

2 files changed

+183
-4
lines changed

2 files changed

+183
-4
lines changed

pandas/core/dtypes/missing.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from pandas.core.dtypes.dtypes import (
4343
CategoricalDtype,
44+
DatetimeTZDtype,
4445
ExtensionDtype,
4546
IntervalDtype,
4647
PeriodDtype,
@@ -754,10 +755,14 @@ def isna_all(arr: ArrayLike) -> bool:
754755
chunk_len = max(total_len // 40, 1000)
755756

756757
dtype = arr.dtype
757-
if dtype.kind == "f":
758+
if dtype.kind == "f" and isinstance(dtype, np.dtype):
758759
checker = nan_checker
759760

760-
elif dtype.kind in ["m", "M"] or dtype.type is Period:
761+
elif (
762+
(isinstance(dtype, np.dtype) and dtype.kind in ["m", "M"])
763+
or isinstance(dtype, DatetimeTZDtype)
764+
or dtype.type is Period
765+
):
761766
# error: Incompatible types in assignment (expression has type
762767
# "Callable[[Any], Any]", variable has type "ufunc")
763768
checker = lambda x: np.asarray(x.view("i8")) == iNaT # type: ignore[assignment]

pandas/tests/extension/test_arrow.py

+176-2
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,53 @@ def all_data(request, data, data_missing):
106106
return data_missing
107107

108108

109+
@pytest.fixture
110+
def data_for_grouping(dtype):
111+
"""
112+
Data for factorization, grouping, and unique tests.
113+
114+
Expected to be like [B, B, NA, NA, A, A, B, C]
115+
116+
Where A < B < C and NA is missing
117+
"""
118+
pa_dtype = dtype.pyarrow_dtype
119+
if pa.types.is_boolean(pa_dtype):
120+
A = False
121+
B = True
122+
C = True
123+
elif pa.types.is_floating(pa_dtype):
124+
A = -1.1
125+
B = 0.0
126+
C = 1.1
127+
elif pa.types.is_signed_integer(pa_dtype):
128+
A = -1
129+
B = 0
130+
C = 1
131+
elif pa.types.is_unsigned_integer(pa_dtype):
132+
A = 0
133+
B = 1
134+
C = 10
135+
elif pa.types.is_date(pa_dtype):
136+
A = date(1999, 12, 31)
137+
B = date(2010, 1, 1)
138+
C = date(2022, 1, 1)
139+
elif pa.types.is_timestamp(pa_dtype):
140+
A = datetime(1999, 1, 1, 1, 1, 1, 1)
141+
B = datetime(2020, 1, 1)
142+
C = datetime(2020, 1, 1, 1)
143+
elif pa.types.is_duration(pa_dtype):
144+
A = timedelta(-1)
145+
B = timedelta(0)
146+
C = timedelta(1, 4)
147+
elif pa.types.is_time(pa_dtype):
148+
A = time(0, 0)
149+
B = time(0, 12)
150+
C = time(12, 12)
151+
else:
152+
raise NotImplementedError
153+
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
154+
155+
109156
@pytest.fixture
110157
def na_value():
111158
"""The scalar missing value for this type. Default 'None'"""
@@ -219,6 +266,133 @@ def test_loc_iloc_frame_single_dtype(self, request, using_array_manager, data):
219266
super().test_loc_iloc_frame_single_dtype(data)
220267

221268

269+
class TestBaseGroupby(base.BaseGroupbyTests):
270+
def test_groupby_agg_extension(self, data_for_grouping, request):
271+
tz = getattr(data_for_grouping.dtype.pyarrow_dtype, "tz", None)
272+
if pa_version_under2p0 and tz not in (None, "UTC"):
273+
request.node.add_marker(
274+
pytest.mark.xfail(
275+
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}."
276+
)
277+
)
278+
super().test_groupby_agg_extension(data_for_grouping)
279+
280+
def test_groupby_extension_no_sort(self, data_for_grouping, request):
281+
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
282+
if pa.types.is_boolean(pa_dtype):
283+
request.node.add_marker(
284+
pytest.mark.xfail(
285+
reason=f"{pa_dtype} only has 2 unique possible values",
286+
)
287+
)
288+
elif pa.types.is_duration(pa_dtype):
289+
request.node.add_marker(
290+
pytest.mark.xfail(
291+
raises=pa.ArrowNotImplementedError,
292+
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
293+
)
294+
)
295+
elif pa.types.is_date(pa_dtype) or (
296+
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
297+
):
298+
request.node.add_marker(
299+
pytest.mark.xfail(
300+
raises=AttributeError,
301+
reason="GH 34986",
302+
)
303+
)
304+
super().test_groupby_extension_no_sort(data_for_grouping)
305+
306+
def test_groupby_extension_transform(self, data_for_grouping, request):
307+
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
308+
if pa.types.is_boolean(pa_dtype):
309+
request.node.add_marker(
310+
pytest.mark.xfail(
311+
reason=f"{pa_dtype} only has 2 unique possible values",
312+
)
313+
)
314+
elif pa.types.is_duration(pa_dtype):
315+
request.node.add_marker(
316+
pytest.mark.xfail(
317+
raises=pa.ArrowNotImplementedError,
318+
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
319+
)
320+
)
321+
super().test_groupby_extension_transform(data_for_grouping)
322+
323+
def test_groupby_extension_apply(
324+
self, data_for_grouping, groupby_apply_op, request
325+
):
326+
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
327+
# Is there a better way to get the "series" ID for groupby_apply_op?
328+
is_series = "series" in request.node.nodeid
329+
is_object = "object" in request.node.nodeid
330+
if pa.types.is_duration(pa_dtype):
331+
request.node.add_marker(
332+
pytest.mark.xfail(
333+
raises=pa.ArrowNotImplementedError,
334+
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
335+
)
336+
)
337+
elif pa.types.is_date(pa_dtype) or (
338+
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
339+
):
340+
if is_object:
341+
request.node.add_marker(
342+
pytest.mark.xfail(
343+
raises=TypeError,
344+
reason="GH 47514: _concat_datetime expects axis arg.",
345+
)
346+
)
347+
elif not is_series:
348+
request.node.add_marker(
349+
pytest.mark.xfail(
350+
raises=AttributeError,
351+
reason="GH 34986",
352+
)
353+
)
354+
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)
355+
356+
def test_in_numeric_groupby(self, data_for_grouping, request):
357+
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
358+
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
359+
request.node.add_marker(
360+
pytest.mark.xfail(
361+
reason="ArrowExtensionArray doesn't support .sum() yet.",
362+
)
363+
)
364+
super().test_in_numeric_groupby(data_for_grouping)
365+
366+
@pytest.mark.parametrize("as_index", [True, False])
367+
def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
368+
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
369+
if pa.types.is_boolean(pa_dtype):
370+
request.node.add_marker(
371+
pytest.mark.xfail(
372+
raises=ValueError,
373+
reason=f"{pa_dtype} only has 2 unique possible values",
374+
)
375+
)
376+
elif pa.types.is_duration(pa_dtype):
377+
request.node.add_marker(
378+
pytest.mark.xfail(
379+
raises=pa.ArrowNotImplementedError,
380+
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
381+
)
382+
)
383+
elif as_index is True and (
384+
pa.types.is_date(pa_dtype)
385+
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
386+
):
387+
request.node.add_marker(
388+
pytest.mark.xfail(
389+
raises=AttributeError,
390+
reason="GH 34986",
391+
)
392+
)
393+
super().test_groupby_extension_agg(as_index, data_for_grouping)
394+
395+
222396
class TestBaseDtype(base.BaseDtypeTests):
223397
def test_construct_from_string_own_name(self, dtype, request):
224398
pa_dtype = dtype.pyarrow_dtype
@@ -736,8 +910,8 @@ def test_setitem_slice_array(self, data, request):
736910
def test_setitem_with_expansion_dataframe_column(
737911
self, data, full_indexer, using_array_manager, request
738912
):
739-
# Is there a way to get the full_indexer id "null_slice"?
740-
is_null_slice = full_indexer(pd.Series(dtype=object)) == slice(None)
913+
# Is there a better way to get the full_indexer id "null_slice"?
914+
is_null_slice = "null_slice" in request.node.nodeid
741915
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
742916
if pa_version_under2p0 and tz not in (None, "UTC") and not is_null_slice:
743917
request.node.add_marker(

0 commit comments

Comments
 (0)