Skip to content

Commit 2230bad

Browse files
authored
Auto Backport PR #50964 on branch 2.0.x (TST: Test ArrowExtensionArray with decimal types) (#51562)
1 parent db1bd72 commit 2230bad

File tree

6 files changed: +149 additions, -19 deletions (168 lines changed)

pandas/_testing/__init__.py

+2
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,7 @@
215215
FLOAT_PYARROW_DTYPES_STR_REPR = [
216216
str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
217217
]
218+
DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
218219
STRING_PYARROW_DTYPES = [pa.string()]
219220
BINARY_PYARROW_DTYPES = [pa.binary()]
220221

@@ -239,6 +240,7 @@
239240
ALL_PYARROW_DTYPES = (
240241
ALL_INT_PYARROW_DTYPES
241242
+ FLOAT_PYARROW_DTYPES
243+
+ DECIMAL_PYARROW_DTYPES
242244
+ STRING_PYARROW_DTYPES
243245
+ BINARY_PYARROW_DTYPES
244246
+ TIME_PYARROW_DTYPES

pandas/core/arrays/arrow/array.py

+1
Original file line number | Diff line number | Diff line change
@@ -1098,6 +1098,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
10981098
pa.types.is_integer(pa_type)
10991099
or pa.types.is_floating(pa_type)
11001100
or pa.types.is_duration(pa_type)
1101+
or pa.types.is_decimal(pa_type)
11011102
):
11021103
# pyarrow only supports any/all for boolean dtype, we allow
11031104
# for other dtypes, matching our non-pyarrow behavior

pandas/core/arrays/arrow/dtype.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
201201
try:
202202
pa_dtype = pa.type_for_alias(base_type)
203203
except ValueError as err:
204-
has_parameters = re.search(r"\[.*\]", base_type)
204+
has_parameters = re.search(r"[\[\(].*[\]\)]", base_type)
205205
if has_parameters:
206206
# Fallback to try common temporal types
207207
try:

pandas/core/indexes/base.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -810,7 +810,11 @@ def _engine(
810810
target_values = self._get_engine_target()
811811
if isinstance(target_values, ExtensionArray):
812812
if isinstance(target_values, (BaseMaskedArray, ArrowExtensionArray)):
813-
return _masked_engines[target_values.dtype.name](target_values)
813+
try:
814+
return _masked_engines[target_values.dtype.name](target_values)
815+
except KeyError:
816+
# Not supported yet e.g. decimal
817+
pass
814818
elif self._engine_type is libindex.ObjectEngine:
815819
return libindex.ExtensionEngine(target_values)
816820

@@ -4948,6 +4952,8 @@ def _get_engine_target(self) -> ArrayLike:
49484952
and not (
49494953
isinstance(self._values, ArrowExtensionArray)
49504954
and is_numeric_dtype(self.dtype)
4955+
# Exclude decimal
4956+
and self.dtype.kind != "O"
49514957
)
49524958
):
49534959
# TODO(ExtensionIndex): remove special-case, just use self._values

pandas/tests/extension/test_arrow.py

+137 -11
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
time,
1717
timedelta,
1818
)
19+
from decimal import Decimal
1920
from io import (
2021
BytesIO,
2122
StringIO,
@@ -79,6 +80,14 @@ def data(dtype):
7980
data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99]
8081
elif pa.types.is_unsigned_integer(pa_dtype):
8182
data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99]
83+
elif pa.types.is_decimal(pa_dtype):
84+
data = (
85+
[Decimal("1"), Decimal("0.0")] * 4
86+
+ [None]
87+
+ [Decimal("-2.0"), Decimal("-1.0")] * 44
88+
+ [None]
89+
+ [Decimal("0.5"), Decimal("33.123")]
90+
)
8291
elif pa.types.is_date(pa_dtype):
8392
data = (
8493
[date(2022, 1, 1), date(1999, 12, 31)] * 4
@@ -188,6 +197,10 @@ def data_for_grouping(dtype):
188197
A = b"a"
189198
B = b"b"
190199
C = b"c"
200+
elif pa.types.is_decimal(pa_dtype):
201+
A = Decimal("-1.1")
202+
B = Decimal("0.0")
203+
C = Decimal("1.1")
191204
else:
192205
raise NotImplementedError
193206
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
@@ -250,17 +263,20 @@ def test_astype_str(self, data, request):
250263
class TestConstructors(base.BaseConstructorsTests):
251264
def test_from_dtype(self, data, request):
252265
pa_dtype = data.dtype.pyarrow_dtype
266+
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
267+
if pa.types.is_string(pa_dtype):
268+
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
269+
else:
270+
reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"
253271

254-
if pa.types.is_string(pa_dtype):
255-
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
256272
request.node.add_marker(
257273
pytest.mark.xfail(
258274
reason=reason,
259275
)
260276
)
261277
super().test_from_dtype(data)
262278

263-
def test_from_sequence_pa_array(self, data, request):
279+
def test_from_sequence_pa_array(self, data):
264280
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
265281
# data._data = pa.ChunkedArray
266282
result = type(data)._from_sequence(data._data)
@@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
285301
reason="Nanosecond time parsing not supported.",
286302
)
287303
)
288-
elif pa_version_under11p0 and pa.types.is_duration(pa_dtype):
304+
elif pa_version_under11p0 and (
305+
pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype)
306+
):
289307
request.node.add_marker(
290308
pytest.mark.xfail(
291309
raises=pa.ArrowNotImplementedError,
@@ -392,7 +410,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
392410
raises=NotImplementedError,
393411
)
394412
)
395-
elif all_numeric_accumulations == "cumsum" and (pa.types.is_boolean(pa_type)):
413+
elif all_numeric_accumulations == "cumsum" and (
414+
pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type)
415+
):
396416
request.node.add_marker(
397417
pytest.mark.xfail(
398418
reason=f"{all_numeric_accumulations} not implemented for {pa_type}",
@@ -476,6 +496,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
476496
)
477497
if all_numeric_reductions in {"skew", "kurt"}:
478498
request.node.add_marker(xfail_mark)
499+
elif (
500+
all_numeric_reductions in {"var", "std", "median"}
501+
and pa_version_under7p0
502+
and pa.types.is_decimal(pa_dtype)
503+
):
504+
request.node.add_marker(xfail_mark)
479505
elif all_numeric_reductions == "sem" and pa_version_under8p0:
480506
request.node.add_marker(xfail_mark)
481507

@@ -598,8 +624,26 @@ def test_in_numeric_groupby(self, data_for_grouping):
598624

599625

600626
class TestBaseDtype(base.BaseDtypeTests):
627+
def test_check_dtype(self, data, request):
628+
pa_dtype = data.dtype.pyarrow_dtype
629+
if pa.types.is_decimal(pa_dtype) and pa_version_under8p0:
630+
request.node.add_marker(
631+
pytest.mark.xfail(
632+
raises=ValueError,
633+
reason="decimal string repr affects numpy comparison",
634+
)
635+
)
636+
super().test_check_dtype(data)
637+
601638
def test_construct_from_string_own_name(self, dtype, request):
602639
pa_dtype = dtype.pyarrow_dtype
640+
if pa.types.is_decimal(pa_dtype):
641+
request.node.add_marker(
642+
pytest.mark.xfail(
643+
raises=NotImplementedError,
644+
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
645+
)
646+
)
603647

604648
if pa.types.is_string(pa_dtype):
605649
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
@@ -617,6 +661,13 @@ def test_is_dtype_from_name(self, dtype, request):
617661
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
618662
assert not type(dtype).is_dtype(dtype.name)
619663
else:
664+
if pa.types.is_decimal(pa_dtype):
665+
request.node.add_marker(
666+
pytest.mark.xfail(
667+
raises=NotImplementedError,
668+
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
669+
)
670+
)
620671
super().test_is_dtype_from_name(dtype)
621672

622673
def test_construct_from_string_another_type_raises(self, dtype):
@@ -635,6 +686,7 @@ def test_get_common_dtype(self, dtype, request):
635686
)
636687
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
637688
or pa.types.is_binary(pa_dtype)
689+
or pa.types.is_decimal(pa_dtype)
638690
):
639691
request.node.add_marker(
640692
pytest.mark.xfail(
@@ -708,6 +760,13 @@ def test_EA_types(self, engine, data, request):
708760
request.node.add_marker(
709761
pytest.mark.xfail(raises=TypeError, reason="GH 47534")
710762
)
763+
elif pa.types.is_decimal(pa_dtype):
764+
request.node.add_marker(
765+
pytest.mark.xfail(
766+
raises=NotImplementedError,
767+
reason=f"Parameterized types {pa_dtype} not supported.",
768+
)
769+
)
711770
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"):
712771
request.node.add_marker(
713772
pytest.mark.xfail(
@@ -790,6 +849,13 @@ def test_argmin_argmax(
790849
reason=f"{pa_dtype} only has 2 unique possible values",
791850
)
792851
)
852+
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
853+
request.node.add_marker(
854+
pytest.mark.xfail(
855+
reason=f"No pyarrow kernel for {pa_dtype}",
856+
raises=pa.ArrowNotImplementedError,
857+
)
858+
)
793859
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)
794860

795861
@pytest.mark.parametrize(
@@ -808,6 +874,14 @@ def test_argmin_argmax(
808874
def test_argreduce_series(
809875
self, data_missing_for_sorting, op_name, skipna, expected, request
810876
):
877+
pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
878+
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna:
879+
request.node.add_marker(
880+
pytest.mark.xfail(
881+
reason=f"No pyarrow kernel for {pa_dtype}",
882+
raises=pa.ArrowNotImplementedError,
883+
)
884+
)
811885
super().test_argreduce_series(
812886
data_missing_for_sorting, op_name, skipna, expected
813887
)
@@ -906,6 +980,21 @@ def test_basic_equals(self, data):
906980
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
907981
divmod_exc = NotImplementedError
908982

983+
@classmethod
984+
def assert_equal(cls, left, right, **kwargs):
985+
if isinstance(left, pd.DataFrame):
986+
left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype
987+
right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype
988+
else:
989+
left_pa_type = left.dtype.pyarrow_dtype
990+
right_pa_type = right.dtype.pyarrow_dtype
991+
if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type):
992+
# decimal precision can resize in the result type depending on data
993+
# just compare the float values
994+
left = left.astype("float[pyarrow]")
995+
right = right.astype("float[pyarrow]")
996+
tm.assert_equal(left, right, **kwargs)
997+
909998
def get_op_from_name(self, op_name):
910999
short_opname = op_name.strip("_")
9111000
if short_opname == "rtruediv":
@@ -975,7 +1064,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
9751064
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
9761065
):
9771066
exc = None
978-
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
1067+
elif not (
1068+
pa.types.is_floating(pa_dtype)
1069+
or pa.types.is_integer(pa_dtype)
1070+
or pa.types.is_decimal(pa_dtype)
1071+
):
9791072
exc = pa.ArrowNotImplementedError
9801073
else:
9811074
exc = None
@@ -988,7 +1081,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
9881081

9891082
if (
9901083
opname == "__rpow__"
991-
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
1084+
and (
1085+
pa.types.is_floating(pa_dtype)
1086+
or pa.types.is_integer(pa_dtype)
1087+
or pa.types.is_decimal(pa_dtype)
1088+
)
9921089
and not pa_version_under7p0
9931090
):
9941091
mark = pytest.mark.xfail(
@@ -1006,14 +1103,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
10061103
),
10071104
)
10081105
elif (
1009-
opname in {"__rfloordiv__"}
1010-
and pa.types.is_integer(pa_dtype)
1106+
opname == "__rfloordiv__"
1107+
and (pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype))
10111108
and not pa_version_under7p0
10121109
):
10131110
mark = pytest.mark.xfail(
10141111
raises=pa.ArrowInvalid,
10151112
reason="divide by 0",
10161113
)
1114+
elif (
1115+
opname == "__rtruediv__"
1116+
and pa.types.is_decimal(pa_dtype)
1117+
and not pa_version_under7p0
1118+
):
1119+
mark = pytest.mark.xfail(
1120+
raises=pa.ArrowInvalid,
1121+
reason="divide by 0",
1122+
)
1123+
elif (
1124+
opname == "__pow__"
1125+
and pa.types.is_decimal(pa_dtype)
1126+
and pa_version_under7p0
1127+
):
1128+
mark = pytest.mark.xfail(
1129+
raises=pa.ArrowInvalid,
1130+
reason="Invalid decimal function: power_checked",
1131+
)
10171132

10181133
return mark
10191134

@@ -1231,6 +1346,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
12311346
expected = ArrowDtype(pa.timestamp("s", "UTC"))
12321347
assert dtype == expected
12331348

1349+
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
1350+
ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")
1351+
12341352

12351353
@pytest.mark.parametrize(
12361354
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
@@ -1257,7 +1375,11 @@ def test_quantile(data, interpolation, quantile, request):
12571375
ser.quantile(q=quantile, interpolation=interpolation)
12581376
return
12591377

1260-
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
1378+
if (
1379+
pa.types.is_integer(pa_dtype)
1380+
or pa.types.is_floating(pa_dtype)
1381+
or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0)
1382+
):
12611383
pass
12621384
elif pa.types.is_temporal(data._data.type):
12631385
pass
@@ -1298,7 +1420,11 @@ def test_quantile(data, interpolation, quantile, request):
12981420
else:
12991421
# Just check the values
13001422
expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
1301-
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
1423+
if (
1424+
pa.types.is_integer(pa_dtype)
1425+
or pa.types.is_floating(pa_dtype)
1426+
or pa.types.is_decimal(pa_dtype)
1427+
):
13021428
expected = expected.astype("float64[pyarrow]")
13031429
result = result.astype("float64[pyarrow]")
13041430
tm.assert_series_equal(result, expected)

pandas/tests/indexes/test_common.py

+1 -6
Original file line number | Diff line number | Diff line change
@@ -449,12 +449,7 @@ def test_hasnans_isnans(self, index_flat):
449449
@pytest.mark.parametrize("na_position", [None, "middle"])
450450
def test_sort_values_invalid_na_position(index_with_missing, na_position):
451451
with pytest.raises(ValueError, match=f"invalid na_position: {na_position}"):
452-
with tm.maybe_produces_warning(
453-
PerformanceWarning,
454-
getattr(index_with_missing.dtype, "storage", "") == "pyarrow",
455-
check_stacklevel=False,
456-
):
457-
index_with_missing.sort_values(na_position=na_position)
452+
index_with_missing.sort_values(na_position=na_position)
458453

459454

460455
@pytest.mark.parametrize("na_position", ["first", "last"])

0 commit comments

Comments
 (0)