Skip to content

Commit 56767a0

Browse files
jbrockmendelmroeschke
authored andcommitted
REF: use updated patterns in pyarrow TestReduce (pandas-dev#54438)
* REF: use updated patterns in pyarrow TestReduce * Remove unnecessary skip * mypy fixup * update xfail
1 parent b71a817 commit 56767a0

File tree

1 file changed

+53
-61
lines changed

1 file changed

+53
-61
lines changed

pandas/tests/extension/test_arrow.py

+53-61
Original file line numberDiff line numberDiff line change
@@ -412,50 +412,25 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
412412

413413
class TestReduce(base.BaseReduceTests):
414414
def _supports_reduction(self, obj, op_name: str) -> bool:
415-
return True
416-
417-
def check_reduce(self, ser, op_name, skipna):
418-
pa_dtype = ser.dtype.pyarrow_dtype
419-
if op_name == "count":
420-
result = getattr(ser, op_name)()
421-
else:
422-
result = getattr(ser, op_name)(skipna=skipna)
423-
if pa.types.is_boolean(pa_dtype):
424-
# Can't convert if ser contains NA
425-
pytest.skip(
426-
"pandas boolean data with NA does not fully support all reductions"
427-
)
428-
elif pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
429-
ser = ser.astype("Float64")
430-
if op_name == "count":
431-
expected = getattr(ser, op_name)()
432-
else:
433-
expected = getattr(ser, op_name)(skipna=skipna)
434-
tm.assert_almost_equal(result, expected)
435-
436-
@pytest.mark.parametrize("skipna", [True, False])
437-
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
438-
pa_dtype = data.dtype.pyarrow_dtype
439-
opname = all_numeric_reductions
440-
441-
ser = pd.Series(data)
442-
443-
should_work = True
444-
if pa.types.is_temporal(pa_dtype) and opname in [
415+
dtype = tm.get_dtype(obj)
416+
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has
417+
# no attribute "pyarrow_dtype"
418+
pa_dtype = dtype.pyarrow_dtype # type: ignore[union-attr]
419+
if pa.types.is_temporal(pa_dtype) and op_name in [
445420
"sum",
446421
"var",
447422
"skew",
448423
"kurt",
449424
"prod",
450425
]:
451-
if pa.types.is_duration(pa_dtype) and opname in ["sum"]:
426+
if pa.types.is_duration(pa_dtype) and op_name in ["sum"]:
452427
# summing timedeltas is one case that *is* well-defined
453428
pass
454429
else:
455-
should_work = False
430+
return False
456431
elif (
457432
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
458-
) and opname in [
433+
) and op_name in [
459434
"sum",
460435
"mean",
461436
"median",
@@ -466,16 +441,40 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, reque
466441
"skew",
467442
"kurt",
468443
]:
469-
should_work = False
444+
return False
470445

471-
if not should_work:
472-
# matching the non-pyarrow versions, these operations *should* not
473-
# work for these dtypes
474-
msg = f"does not support reduction '{opname}'"
475-
with pytest.raises(TypeError, match=msg):
476-
getattr(ser, opname)(skipna=skipna)
446+
if (
447+
pa.types.is_temporal(pa_dtype)
448+
and not pa.types.is_duration(pa_dtype)
449+
and op_name in ["any", "all"]
450+
):
451+
# xref GH#34479 we support this in our non-pyarrow datetime64 dtypes,
452+
# but it isn't obvious we _should_. For now, we keep the pyarrow
453+
# behavior which does not support this.
454+
return False
477455

478-
return
456+
return True
457+
458+
def check_reduce(self, ser, op_name, skipna):
459+
pa_dtype = ser.dtype.pyarrow_dtype
460+
if op_name == "count":
461+
result = getattr(ser, op_name)()
462+
else:
463+
result = getattr(ser, op_name)(skipna=skipna)
464+
465+
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
466+
ser = ser.astype("Float64")
467+
# TODO: in the opposite case, aren't we testing... nothing?
468+
if op_name == "count":
469+
expected = getattr(ser, op_name)()
470+
else:
471+
expected = getattr(ser, op_name)(skipna=skipna)
472+
tm.assert_almost_equal(result, expected)
473+
474+
@pytest.mark.parametrize("skipna", [True, False])
475+
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
476+
dtype = data.dtype
477+
pa_dtype = dtype.pyarrow_dtype
479478

480479
xfail_mark = pytest.mark.xfail(
481480
raises=TypeError,
@@ -484,15 +483,21 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, reque
484483
f"pyarrow={pa.__version__} for {pa_dtype}"
485484
),
486485
)
487-
if all_numeric_reductions in {"skew", "kurt"}:
486+
if all_numeric_reductions in {"skew", "kurt"} and (
487+
dtype._is_numeric or dtype.kind == "b"
488+
):
488489
request.node.add_marker(xfail_mark)
489490
elif (
490491
all_numeric_reductions in {"var", "std", "median"}
491492
and pa_version_under7p0
492493
and pa.types.is_decimal(pa_dtype)
493494
):
494495
request.node.add_marker(xfail_mark)
495-
elif all_numeric_reductions == "sem" and pa_version_under8p0:
496+
elif (
497+
all_numeric_reductions == "sem"
498+
and pa_version_under8p0
499+
and (dtype._is_numeric or pa.types.is_temporal(pa_dtype))
500+
):
496501
request.node.add_marker(xfail_mark)
497502

498503
elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in {
@@ -521,21 +526,7 @@ def test_reduce_series_boolean(
521526
# but have not yet decided.
522527
request.node.add_marker(xfail_mark)
523528

524-
op_name = all_boolean_reductions
525-
ser = pd.Series(data)
526-
527-
if pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype):
528-
# xref GH#34479 we support this in our non-pyarrow datetime64 dtypes,
529-
# but it isn't obvious we _should_. For now, we keep the pyarrow
530-
# behavior which does not support this.
531-
532-
with pytest.raises(TypeError, match="does not support reduction"):
533-
getattr(ser, op_name)(skipna=skipna)
534-
535-
return
536-
537-
result = getattr(ser, op_name)(skipna=skipna)
538-
assert result is (op_name == "any")
529+
return super().test_reduce_series_boolean(data, all_boolean_reductions, skipna)
539530

540531
def _get_expected_reduction_dtype(self, arr, op_name: str):
541532
if op_name in ["max", "min"]:
@@ -556,11 +547,12 @@ def _get_expected_reduction_dtype(self, arr, op_name: str):
556547
return cmp_dtype
557548

558549
@pytest.mark.parametrize("skipna", [True, False])
559-
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
550+
def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
560551
op_name = all_numeric_reductions
561552
if op_name == "skew":
562-
assert not hasattr(data, op_name)
563-
return
553+
if data.dtype._is_numeric:
554+
mark = pytest.mark.xfail(reason="skew not implemented")
555+
request.node.add_marker(mark)
564556
return super().test_reduce_frame(data, all_numeric_reductions, skipna)
565557

566558
@pytest.mark.parametrize("typ", ["int64", "uint64", "float64"])

0 commit comments

Comments
 (0)