Skip to content

ENH/TST: Add BaseGroupbyTests tests for ArrowExtensionArray #47515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 3, 2022
Merged
5 changes: 3 additions & 2 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,10 +754,11 @@ def isna_all(arr: ArrayLike) -> bool:
chunk_len = max(total_len // 40, 1000)

dtype = arr.dtype
if dtype.kind == "f":
is_np_missing_value = na_value_for_dtype(dtype) is not libmissing.NA
if dtype.kind == "f" and is_np_missing_value:
checker = nan_checker

elif dtype.kind in ["m", "M"] or dtype.type is Period:
elif (dtype.kind in ["m", "M"] or dtype.type is Period) and is_np_missing_value:
# error: Incompatible types in assignment (expression has type
# "Callable[[Any], Any]", variable has type "ufunc")
checker = lambda x: np.asarray(x.view("i8")) == iNaT # type: ignore[assignment]
Expand Down
178 changes: 176 additions & 2 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,53 @@ def all_data(request, data, data_missing):
return data_missing


@pytest.fixture
def data_for_grouping(dtype):
    """
    Data for factorization, grouping, and unique tests.

    Expected to be like [B, B, NA, NA, A, A, B, C]

    Where A < B < C and NA is missing

    Note that for boolean dtypes B and C are both ``True``, since only
    two distinct values are available.
    """
    pa_dtype = dtype.pyarrow_dtype
    # Ordered (type predicate, (A, B, C)) pairs; the first predicate that
    # accepts ``pa_dtype`` decides which sample values are used.
    samples_by_kind = (
        (pa.types.is_boolean, (False, True, True)),
        (pa.types.is_floating, (-1.1, 0.0, 1.1)),
        (pa.types.is_signed_integer, (-1, 0, 1)),
        (pa.types.is_unsigned_integer, (0, 1, 10)),
        (
            pa.types.is_date,
            (date(1999, 12, 31), date(2010, 1, 1), date(2022, 1, 1)),
        ),
        (
            pa.types.is_timestamp,
            (
                datetime(1999, 1, 1, 1, 1, 1, 1),
                datetime(2020, 1, 1),
                datetime(2020, 1, 1, 1),
            ),
        ),
        (
            pa.types.is_duration,
            (timedelta(-1), timedelta(0), timedelta(1, 4)),
        ),
        (pa.types.is_time, (time(0, 0), time(0, 12), time(12, 12))),
    )
    for accepts, (low, mid, high) in samples_by_kind:
        if accepts(pa_dtype):
            break
    else:
        # No predicate matched: this pyarrow type has no sample data defined.
        raise NotImplementedError
    return pd.array([mid, mid, None, None, low, low, mid, high], dtype=dtype)


@pytest.fixture
def na_value():
"""The scalar missing value for this type. Default 'None'"""
Expand Down Expand Up @@ -219,6 +266,133 @@ def test_loc_iloc_frame_single_dtype(self, request, using_array_manager, data):
super().test_loc_iloc_frame_single_dtype(data)


class TestBaseGroupby(base.BaseGroupbyTests):
    """
    Run the shared ``BaseGroupbyTests`` against ``data_for_grouping``.

    Each override inspects the pyarrow dtype of the fixture data, attaches an
    ``xfail`` marker to the current test node for the dtype-specific cases
    listed below, and then delegates to the base-class implementation.
    """

    def test_groupby_agg_extension(self, data_for_grouping, request):
        # Fall back to None when the pyarrow dtype has no ``tz`` attribute.
        tz = getattr(data_for_grouping.dtype.pyarrow_dtype, "tz", None)
        if pa_version_under2p0 and tz not in (None, "UTC"):
            marker = pytest.mark.xfail(
                reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}."
            )
            request.node.add_marker(marker)
        super().test_groupby_agg_extension(data_for_grouping)

    def test_groupby_extension_no_sort(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        date_or_tzless_timestamp = pa.types.is_date(pa_dtype) or (
            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
        )
        marker = None
        if pa.types.is_boolean(pa_dtype):
            marker = pytest.mark.xfail(
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
        elif pa.types.is_duration(pa_dtype):
            marker = pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"pyarrow doesn't support factorizing {pa_dtype}",
            )
        elif date_or_tzless_timestamp:
            marker = pytest.mark.xfail(
                raises=AttributeError,
                reason="GH 34986",
            )
        if marker is not None:
            request.node.add_marker(marker)
        super().test_groupby_extension_no_sort(data_for_grouping)

    def test_groupby_extension_transform(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        marker = None
        if pa.types.is_boolean(pa_dtype):
            marker = pytest.mark.xfail(
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
        elif pa.types.is_duration(pa_dtype):
            marker = pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"pyarrow doesn't support factorizing {pa_dtype}",
            )
        if marker is not None:
            request.node.add_marker(marker)
        super().test_groupby_extension_transform(data_for_grouping)

    def test_groupby_extension_apply(
        self, data_for_grouping, groupby_apply_op, request
    ):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        # Is there a better way to get the "series" ID for groupby_apply_op?
        node_id = request.node.nodeid
        is_series = "series" in node_id
        is_object = "object" in node_id
        date_or_tzless_timestamp = pa.types.is_date(pa_dtype) or (
            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
        )
        marker = None
        if pa.types.is_duration(pa_dtype):
            marker = pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"pyarrow doesn't support factorizing {pa_dtype}",
            )
        elif date_or_tzless_timestamp:
            # The "object" case takes precedence over the non-"series" case.
            if is_object:
                marker = pytest.mark.xfail(
                    raises=TypeError,
                    reason="GH 47514: _concat_datetime expects axis arg.",
                )
            elif not is_series:
                marker = pytest.mark.xfail(
                    raises=AttributeError,
                    reason="GH 34986",
                )
        if marker is not None:
            request.node.add_marker(marker)
        super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)

    def test_in_numeric_groupby(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        is_numeric = pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)
        if is_numeric:
            marker = pytest.mark.xfail(
                reason="ArrowExtensionArray doesn't support .sum() yet.",
            )
            request.node.add_marker(marker)
        super().test_in_numeric_groupby(data_for_grouping)

    @pytest.mark.parametrize("as_index", [True, False])
    def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        date_or_tzless_timestamp = pa.types.is_date(pa_dtype) or (
            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
        )
        marker = None
        if pa.types.is_boolean(pa_dtype):
            marker = pytest.mark.xfail(
                raises=ValueError,
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
        elif pa.types.is_duration(pa_dtype):
            marker = pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"pyarrow doesn't support factorizing {pa_dtype}",
            )
        elif as_index is True and date_or_tzless_timestamp:
            # Only the as_index=True parametrization is marked for these dtypes.
            marker = pytest.mark.xfail(
                raises=AttributeError,
                reason="GH 34986",
            )
        if marker is not None:
            request.node.add_marker(marker)
        super().test_groupby_extension_agg(as_index, data_for_grouping)


class TestBaseDtype(base.BaseDtypeTests):
def test_construct_from_string_own_name(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
Expand Down Expand Up @@ -736,8 +910,8 @@ def test_setitem_slice_array(self, data, request):
def test_setitem_with_expansion_dataframe_column(
self, data, full_indexer, using_array_manager, request
):
# Is there a way to get the full_indexer id "null_slice"?
is_null_slice = full_indexer(pd.Series(dtype=object)) == slice(None)
# Is there a better way to get the full_indexer id "null_slice"?
is_null_slice = "null_slice" in request.node.nodeid
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC") and not is_null_slice:
request.node.add_marker(
Expand Down