Skip to content

Commit f5405b5

Browse files
authored
TST: Test ArrowExtensionArray with decimal types (#50964)
* TST: Test ArrowExtensionArray with decimal types
* Version compat
* Add other xfails based on min version
* fix test
* fix typo
* another typo
* only for skipna
* Add comment
* Fix
* undo comments
* Bump version condition
* Skip masked indexing engine for decimal
* Some merge stuff
* Remove imaginary test
* Fix condition
* Fix another test
* Update condition
1 parent 3f2b18a commit f5405b5

File tree

6 files changed

+149
-19
lines changed

6 files changed

+149
-19
lines changed

pandas/_testing/__init__.py

+2
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,7 @@
215215
FLOAT_PYARROW_DTYPES_STR_REPR = [
216216
str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
217217
]
218+
DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
218219
STRING_PYARROW_DTYPES = [pa.string()]
219220
BINARY_PYARROW_DTYPES = [pa.binary()]
220221

@@ -239,6 +240,7 @@
239240
ALL_PYARROW_DTYPES = (
240241
ALL_INT_PYARROW_DTYPES
241242
+ FLOAT_PYARROW_DTYPES
243+
+ DECIMAL_PYARROW_DTYPES
242244
+ STRING_PYARROW_DTYPES
243245
+ BINARY_PYARROW_DTYPES
244246
+ TIME_PYARROW_DTYPES

pandas/core/arrays/arrow/array.py

+1
Original file line number | Diff line number | Diff line change
@@ -1098,6 +1098,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
10981098
pa.types.is_integer(pa_type)
10991099
or pa.types.is_floating(pa_type)
11001100
or pa.types.is_duration(pa_type)
1101+
or pa.types.is_decimal(pa_type)
11011102
):
11021103
# pyarrow only supports any/all for boolean dtype, we allow
11031104
# for other dtypes, matching our non-pyarrow behavior

pandas/core/arrays/arrow/dtype.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
201201
try:
202202
pa_dtype = pa.type_for_alias(base_type)
203203
except ValueError as err:
204-
has_parameters = re.search(r"\[.*\]", base_type)
204+
has_parameters = re.search(r"[\[\(].*[\]\)]", base_type)
205205
if has_parameters:
206206
# Fallback to try common temporal types
207207
try:

pandas/core/indexes/base.py

+7-1
Original file line number | Diff line number | Diff line change
@@ -810,7 +810,11 @@ def _engine(
810810
target_values = self._get_engine_target()
811811
if isinstance(target_values, ExtensionArray):
812812
if isinstance(target_values, (BaseMaskedArray, ArrowExtensionArray)):
813-
return _masked_engines[target_values.dtype.name](target_values)
813+
try:
814+
return _masked_engines[target_values.dtype.name](target_values)
815+
except KeyError:
816+
# Not supported yet e.g. decimal
817+
pass
814818
elif self._engine_type is libindex.ObjectEngine:
815819
return libindex.ExtensionEngine(target_values)
816820

@@ -4948,6 +4952,8 @@ def _get_engine_target(self) -> ArrayLike:
49484952
and not (
49494953
isinstance(self._values, ArrowExtensionArray)
49504954
and is_numeric_dtype(self.dtype)
4955+
# Exclude decimal
4956+
and self.dtype.kind != "O"
49514957
)
49524958
):
49534959
# TODO(ExtensionIndex): remove special-case, just use self._values

pandas/tests/extension/test_arrow.py

+137-11
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
time,
1717
timedelta,
1818
)
19+
from decimal import Decimal
1920
from io import (
2021
BytesIO,
2122
StringIO,
@@ -79,6 +80,14 @@ def data(dtype):
7980
data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99]
8081
elif pa.types.is_unsigned_integer(pa_dtype):
8182
data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99]
83+
elif pa.types.is_decimal(pa_dtype):
84+
data = (
85+
[Decimal("1"), Decimal("0.0")] * 4
86+
+ [None]
87+
+ [Decimal("-2.0"), Decimal("-1.0")] * 44
88+
+ [None]
89+
+ [Decimal("0.5"), Decimal("33.123")]
90+
)
8291
elif pa.types.is_date(pa_dtype):
8392
data = (
8493
[date(2022, 1, 1), date(1999, 12, 31)] * 4
@@ -188,6 +197,10 @@ def data_for_grouping(dtype):
188197
A = b"a"
189198
B = b"b"
190199
C = b"c"
200+
elif pa.types.is_decimal(pa_dtype):
201+
A = Decimal("-1.1")
202+
B = Decimal("0.0")
203+
C = Decimal("1.1")
191204
else:
192205
raise NotImplementedError
193206
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
@@ -250,17 +263,20 @@ def test_astype_str(self, data, request):
250263
class TestConstructors(base.BaseConstructorsTests):
251264
def test_from_dtype(self, data, request):
252265
pa_dtype = data.dtype.pyarrow_dtype
266+
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
267+
if pa.types.is_string(pa_dtype):
268+
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
269+
else:
270+
reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"
253271

254-
if pa.types.is_string(pa_dtype):
255-
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
256272
request.node.add_marker(
257273
pytest.mark.xfail(
258274
reason=reason,
259275
)
260276
)
261277
super().test_from_dtype(data)
262278

263-
def test_from_sequence_pa_array(self, data, request):
279+
def test_from_sequence_pa_array(self, data):
264280
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
265281
# data._data = pa.ChunkedArray
266282
result = type(data)._from_sequence(data._data)
@@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
285301
reason="Nanosecond time parsing not supported.",
286302
)
287303
)
288-
elif pa_version_under11p0 and pa.types.is_duration(pa_dtype):
304+
elif pa_version_under11p0 and (
305+
pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype)
306+
):
289307
request.node.add_marker(
290308
pytest.mark.xfail(
291309
raises=pa.ArrowNotImplementedError,
@@ -384,7 +402,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
384402
# renders the exception messages even when not showing them
385403
pytest.skip(f"{all_numeric_accumulations} not implemented for pyarrow < 9")
386404

387-
elif all_numeric_accumulations == "cumsum" and pa.types.is_boolean(pa_type):
405+
elif all_numeric_accumulations == "cumsum" and (
406+
pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type)
407+
):
388408
request.node.add_marker(
389409
pytest.mark.xfail(
390410
reason=f"{all_numeric_accumulations} not implemented for {pa_type}",
@@ -468,6 +488,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
468488
)
469489
if all_numeric_reductions in {"skew", "kurt"}:
470490
request.node.add_marker(xfail_mark)
491+
elif (
492+
all_numeric_reductions in {"var", "std", "median"}
493+
and pa_version_under7p0
494+
and pa.types.is_decimal(pa_dtype)
495+
):
496+
request.node.add_marker(xfail_mark)
471497
elif all_numeric_reductions == "sem" and pa_version_under8p0:
472498
request.node.add_marker(xfail_mark)
473499

@@ -590,8 +616,26 @@ def test_in_numeric_groupby(self, data_for_grouping):
590616

591617

592618
class TestBaseDtype(base.BaseDtypeTests):
619+
def test_check_dtype(self, data, request):
620+
pa_dtype = data.dtype.pyarrow_dtype
621+
if pa.types.is_decimal(pa_dtype) and pa_version_under8p0:
622+
request.node.add_marker(
623+
pytest.mark.xfail(
624+
raises=ValueError,
625+
reason="decimal string repr affects numpy comparison",
626+
)
627+
)
628+
super().test_check_dtype(data)
629+
593630
def test_construct_from_string_own_name(self, dtype, request):
594631
pa_dtype = dtype.pyarrow_dtype
632+
if pa.types.is_decimal(pa_dtype):
633+
request.node.add_marker(
634+
pytest.mark.xfail(
635+
raises=NotImplementedError,
636+
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
637+
)
638+
)
595639

596640
if pa.types.is_string(pa_dtype):
597641
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
@@ -609,6 +653,13 @@ def test_is_dtype_from_name(self, dtype, request):
609653
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
610654
assert not type(dtype).is_dtype(dtype.name)
611655
else:
656+
if pa.types.is_decimal(pa_dtype):
657+
request.node.add_marker(
658+
pytest.mark.xfail(
659+
raises=NotImplementedError,
660+
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
661+
)
662+
)
612663
super().test_is_dtype_from_name(dtype)
613664

614665
def test_construct_from_string_another_type_raises(self, dtype):
@@ -627,6 +678,7 @@ def test_get_common_dtype(self, dtype, request):
627678
)
628679
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
629680
or pa.types.is_binary(pa_dtype)
681+
or pa.types.is_decimal(pa_dtype)
630682
):
631683
request.node.add_marker(
632684
pytest.mark.xfail(
@@ -700,6 +752,13 @@ def test_EA_types(self, engine, data, request):
700752
request.node.add_marker(
701753
pytest.mark.xfail(raises=TypeError, reason="GH 47534")
702754
)
755+
elif pa.types.is_decimal(pa_dtype):
756+
request.node.add_marker(
757+
pytest.mark.xfail(
758+
raises=NotImplementedError,
759+
reason=f"Parameterized types {pa_dtype} not supported.",
760+
)
761+
)
703762
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"):
704763
request.node.add_marker(
705764
pytest.mark.xfail(
@@ -782,6 +841,13 @@ def test_argmin_argmax(
782841
reason=f"{pa_dtype} only has 2 unique possible values",
783842
)
784843
)
844+
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
845+
request.node.add_marker(
846+
pytest.mark.xfail(
847+
reason=f"No pyarrow kernel for {pa_dtype}",
848+
raises=pa.ArrowNotImplementedError,
849+
)
850+
)
785851
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)
786852

787853
@pytest.mark.parametrize(
@@ -800,6 +866,14 @@ def test_argmin_argmax(
800866
def test_argreduce_series(
801867
self, data_missing_for_sorting, op_name, skipna, expected, request
802868
):
869+
pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
870+
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna:
871+
request.node.add_marker(
872+
pytest.mark.xfail(
873+
reason=f"No pyarrow kernel for {pa_dtype}",
874+
raises=pa.ArrowNotImplementedError,
875+
)
876+
)
803877
super().test_argreduce_series(
804878
data_missing_for_sorting, op_name, skipna, expected
805879
)
@@ -888,6 +962,21 @@ def test_basic_equals(self, data):
888962
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
889963
divmod_exc = NotImplementedError
890964

965+
@classmethod
966+
def assert_equal(cls, left, right, **kwargs):
967+
if isinstance(left, pd.DataFrame):
968+
left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype
969+
right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype
970+
else:
971+
left_pa_type = left.dtype.pyarrow_dtype
972+
right_pa_type = right.dtype.pyarrow_dtype
973+
if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type):
974+
# decimal precision can resize in the result type depending on data
975+
# just compare the float values
976+
left = left.astype("float[pyarrow]")
977+
right = right.astype("float[pyarrow]")
978+
tm.assert_equal(left, right, **kwargs)
979+
891980
def get_op_from_name(self, op_name):
892981
short_opname = op_name.strip("_")
893982
if short_opname == "rtruediv":
@@ -967,7 +1056,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
9671056
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
9681057
):
9691058
exc = None
970-
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
1059+
elif not (
1060+
pa.types.is_floating(pa_dtype)
1061+
or pa.types.is_integer(pa_dtype)
1062+
or pa.types.is_decimal(pa_dtype)
1063+
):
9711064
exc = pa.ArrowNotImplementedError
9721065
else:
9731066
exc = None
@@ -980,7 +1073,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
9801073

9811074
if (
9821075
opname == "__rpow__"
983-
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
1076+
and (
1077+
pa.types.is_floating(pa_dtype)
1078+
or pa.types.is_integer(pa_dtype)
1079+
or pa.types.is_decimal(pa_dtype)
1080+
)
9841081
and not pa_version_under7p0
9851082
):
9861083
mark = pytest.mark.xfail(
@@ -998,14 +1095,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
9981095
),
9991096
)
10001097
elif (
1001-
opname in {"__rfloordiv__"}
1002-
and pa.types.is_integer(pa_dtype)
1098+
opname == "__rfloordiv__"
1099+
and (pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype))
10031100
and not pa_version_under7p0
10041101
):
10051102
mark = pytest.mark.xfail(
10061103
raises=pa.ArrowInvalid,
10071104
reason="divide by 0",
10081105
)
1106+
elif (
1107+
opname == "__rtruediv__"
1108+
and pa.types.is_decimal(pa_dtype)
1109+
and not pa_version_under7p0
1110+
):
1111+
mark = pytest.mark.xfail(
1112+
raises=pa.ArrowInvalid,
1113+
reason="divide by 0",
1114+
)
1115+
elif (
1116+
opname == "__pow__"
1117+
and pa.types.is_decimal(pa_dtype)
1118+
and pa_version_under7p0
1119+
):
1120+
mark = pytest.mark.xfail(
1121+
raises=pa.ArrowInvalid,
1122+
reason="Invalid decimal function: power_checked",
1123+
)
10091124

10101125
return mark
10111126

@@ -1226,6 +1341,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
12261341
expected = ArrowDtype(pa.timestamp("s", "UTC"))
12271342
assert dtype == expected
12281343

1344+
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
1345+
ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")
1346+
12291347

12301348
@pytest.mark.parametrize(
12311349
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
@@ -1252,7 +1370,11 @@ def test_quantile(data, interpolation, quantile, request):
12521370
ser.quantile(q=quantile, interpolation=interpolation)
12531371
return
12541372

1255-
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
1373+
if (
1374+
pa.types.is_integer(pa_dtype)
1375+
or pa.types.is_floating(pa_dtype)
1376+
or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0)
1377+
):
12561378
pass
12571379
elif pa.types.is_temporal(data._data.type):
12581380
pass
@@ -1293,7 +1415,11 @@ def test_quantile(data, interpolation, quantile, request):
12931415
else:
12941416
# Just check the values
12951417
expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
1296-
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
1418+
if (
1419+
pa.types.is_integer(pa_dtype)
1420+
or pa.types.is_floating(pa_dtype)
1421+
or pa.types.is_decimal(pa_dtype)
1422+
):
12971423
expected = expected.astype("float64[pyarrow]")
12981424
result = result.astype("float64[pyarrow]")
12991425
tm.assert_series_equal(result, expected)

pandas/tests/indexes/test_common.py

+1-6
Original file line number | Diff line number | Diff line change
@@ -449,12 +449,7 @@ def test_hasnans_isnans(self, index_flat):
449449
@pytest.mark.parametrize("na_position", [None, "middle"])
450450
def test_sort_values_invalid_na_position(index_with_missing, na_position):
451451
with pytest.raises(ValueError, match=f"invalid na_position: {na_position}"):
452-
with tm.maybe_produces_warning(
453-
PerformanceWarning,
454-
getattr(index_with_missing.dtype, "storage", "") == "pyarrow",
455-
check_stacklevel=False,
456-
):
457-
index_with_missing.sort_values(na_position=na_position)
452+
index_with_missing.sort_values(na_position=na_position)
458453

459454

460455
@pytest.mark.parametrize("na_position", ["first", "last"])

0 commit comments

Comments
 (0)