Skip to content

TST: Test ArrowExtensionArray with decimal types #50964

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
01f689e
TST: Test ArrowExtensionArray with decimal types
mroeschke Jan 24, 2023
e90cc24
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 25, 2023
5c5e28b
Version compat
mroeschke Jan 25, 2023
68e769b
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 25, 2023
f799c74
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 25, 2023
c2f108d
Add other xfails based on min version
mroeschke Jan 25, 2023
b67a13e
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 26, 2023
60ccab1
fix test
mroeschke Jan 26, 2023
e9a1b4e
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 26, 2023
d63b1ea
fix typo
mroeschke Jan 26, 2023
c2a7610
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 26, 2023
f944e84
another typo
mroeschke Jan 26, 2023
7758f61
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 26, 2023
5492d11
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 28, 2023
7f42ad3
only for skipna
mroeschke Jan 28, 2023
806abe9
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 1, 2023
a16ee3b
Add comment
mroeschke Feb 1, 2023
fe5800a
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 2, 2023
44f2eb8
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 2, 2023
129ac08
Fix
mroeschke Feb 2, 2023
13128ac
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 7, 2023
c44f260
undo comments
mroeschke Feb 7, 2023
f86666d
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 7, 2023
c6f72d3
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 8, 2023
406bb69
Bump version condition
mroeschke Feb 8, 2023
fb6992b
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 9, 2023
47fc864
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 10, 2023
6a2923c
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 17, 2023
28006b2
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 17, 2023
24e52e2
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 17, 2023
dfe4bb0
Skip masked indexing engine for decimal
mroeschke Feb 17, 2023
798f9a3
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 21, 2023
ff53f6d
Some merge stuff
mroeschke Feb 21, 2023
a4143df
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 22, 2023
6114391
Remove imaginary test
mroeschke Feb 22, 2023
4f8fadc
Fix condition
mroeschke Feb 22, 2023
deaa1db
Fix another test
mroeschke Feb 22, 2023
3ce8c9d
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 22, 2023
142c771
Update condition
mroeschke Feb 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
FLOAT_PYARROW_DTYPES_STR_REPR = [
str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
]
DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
STRING_PYARROW_DTYPES = [pa.string()]
BINARY_PYARROW_DTYPES = [pa.binary()]

Expand All @@ -237,6 +238,7 @@
ALL_PYARROW_DTYPES = (
ALL_INT_PYARROW_DTYPES
+ FLOAT_PYARROW_DTYPES
+ DECIMAL_PYARROW_DTYPES
+ STRING_PYARROW_DTYPES
+ BINARY_PYARROW_DTYPES
+ TIME_PYARROW_DTYPES
Expand Down
1 change: 1 addition & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
pa.types.is_integer(pa_type)
or pa.types.is_floating(pa_type)
or pa.types.is_duration(pa_type)
or pa.types.is_decimal(pa_type)
):
# pyarrow only supports any/all for boolean dtype, we allow
# for other dtypes, matching our non-pyarrow behavior
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/arrow/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
try:
pa_dtype = pa.type_for_alias(base_type)
except ValueError as err:
has_parameters = re.search(r"\[.*\]", base_type)
has_parameters = re.search(r"[\[\(].*[\]\)]", base_type)
if has_parameters:
raise NotImplementedError(
"Passing pyarrow type specific parameters "
Expand Down
144 changes: 129 additions & 15 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
time,
timedelta,
)
from decimal import Decimal
from io import (
BytesIO,
StringIO,
Expand Down Expand Up @@ -77,6 +78,14 @@ def data(dtype):
data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99]
elif pa.types.is_unsigned_integer(pa_dtype):
data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99]
elif pa.types.is_decimal(pa_dtype):
data = (
[Decimal("1"), Decimal("0.0")] * 4
+ [None]
+ [Decimal("-2.0"), Decimal("-1.0")] * 44
+ [None]
+ [Decimal("0.5"), Decimal("33.123")]
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like pyarrow raises if you pass it Decimal("nan"). Is that going to be relevant for us?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we coerce to NA

In [6]: pd.Series([decimal.Decimal("nan")], dtype=pd.ArrowDtype(pa.decimal128(4, 1)))
Out[6]:
0    <NA>
dtype: decimal128(4, 1)[pyarrow]

Which, I guess, gets fuzzy given the NA vs. NaN debate.

elif pa.types.is_date(pa_dtype):
data = (
[date(2022, 1, 1), date(1999, 12, 31)] * 4
Expand Down Expand Up @@ -186,6 +195,10 @@ def data_for_grouping(dtype):
A = b"a"
B = b"b"
C = b"c"
elif pa.types.is_decimal(pa_dtype):
A = Decimal("-1.1")
B = Decimal("0.0")
C = Decimal("1.1")
else:
raise NotImplementedError
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
Expand Down Expand Up @@ -248,8 +261,10 @@ def test_astype_str(self, data, request):
class TestConstructors(base.BaseConstructorsTests):
def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz) or pa.types.is_string(
pa_dtype
if (
(pa.types.is_timestamp(pa_dtype) and pa_dtype.tz)
or pa.types.is_string(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
if pa.types.is_string(pa_dtype):
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
Expand All @@ -262,7 +277,7 @@ def test_from_dtype(self, data, request):
)
super().test_from_dtype(data)

def test_from_sequence_pa_array(self, data, request):
def test_from_sequence_pa_array(self, data):
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
# data._data = pa.ChunkedArray
result = type(data)._from_sequence(data._data)
Expand All @@ -287,7 +302,7 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
reason="Nanosecond time parsing not supported.",
)
)
elif pa.types.is_duration(pa_dtype):
elif pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
Expand Down Expand Up @@ -347,6 +362,8 @@ def test_getitem_scalar(self, data):
exp_type = date
elif pa.types.is_time(pa_type):
exp_type = time
elif pa.types.is_decimal(pa_type):
exp_type = Decimal
else:
raise NotImplementedError(data.dtype)

Expand Down Expand Up @@ -426,7 +443,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
raises=NotImplementedError,
)
)
elif all_numeric_accumulations == "cumsum" and (pa.types.is_boolean(pa_type)):
elif all_numeric_accumulations == "cumsum" and (
pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type)
):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{all_numeric_accumulations} not implemented for {pa_type}",
Expand Down Expand Up @@ -510,6 +529,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
)
if all_numeric_reductions in {"skew", "kurt"}:
request.node.add_marker(xfail_mark)
elif (
all_numeric_reductions in {"var", "std", "median"}
and pa_version_under7p0
and pa.types.is_decimal(pa_dtype)
):
request.node.add_marker(xfail_mark)
elif all_numeric_reductions == "sem" and pa_version_under8p0:
request.node.add_marker(xfail_mark)

Expand Down Expand Up @@ -632,9 +657,22 @@ def test_in_numeric_groupby(self, data_for_grouping):


class TestBaseDtype(base.BaseDtypeTests):
def test_check_dtype(self, data, request):
    """Run the inherited ``test_check_dtype`` with a decimal-specific xfail.

    When ``data`` holds a pyarrow decimal type and the
    ``pa_version_under8p0`` flag is truthy, the test node is marked as an
    expected failure (``ValueError``) before delegating to the parent
    class implementation.
    """
    arrow_type = data.dtype.pyarrow_dtype
    decimal_on_flagged_version = (
        pa.types.is_decimal(arrow_type) and pa_version_under8p0
    )
    if decimal_on_flagged_version:
        add_marker = request.node.add_marker
        add_marker(
            pytest.mark.xfail(
                raises=ValueError,
                reason="decimal string repr affects numpy comparison",
            )
        )
    super().test_check_dtype(data)

def test_construct_from_string_own_name(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
if (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None
) or pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
Expand All @@ -655,7 +693,9 @@ def test_construct_from_string_own_name(self, dtype, request):

def test_is_dtype_from_name(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
if (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None
) or pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
Expand All @@ -675,7 +715,9 @@ def test_is_dtype_from_name(self, dtype, request):

def test_construct_from_string(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
if (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None
) or pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
Expand Down Expand Up @@ -710,6 +752,7 @@ def test_get_common_dtype(self, dtype, request):
)
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
or pa.types.is_binary(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
Expand Down Expand Up @@ -783,11 +826,13 @@ def test_EA_types(self, engine, data, request):
request.node.add_marker(
pytest.mark.xfail(raises=TypeError, reason="GH 47534")
)
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
elif (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None
) or pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"Parameterized types with tz={pa_dtype.tz} not supported.",
reason=f"Parameterized types {pa_dtype} not supported.",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#50689 fixes this for timestamptz; once both this PR and that one are merged, we can look into fixing it for decimal.

)
)
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"):
Expand Down Expand Up @@ -872,6 +917,13 @@ def test_argmin_argmax(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
request.node.add_marker(
pytest.mark.xfail(
reason=f"No pyarrow kernel for {pa_dtype}",
raises=pa.ArrowNotImplementedError,
)
)
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

@pytest.mark.parametrize(
Expand All @@ -890,6 +942,14 @@ def test_argmin_argmax(
def test_argreduce_series(
self, data_missing_for_sorting, op_name, skipna, expected, request
):
pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna:
request.node.add_marker(
pytest.mark.xfail(
reason=f"No pyarrow kernel for {pa_dtype}",
raises=pa.ArrowNotImplementedError,
)
)
super().test_argreduce_series(
data_missing_for_sorting, op_name, skipna, expected
)
Expand Down Expand Up @@ -989,6 +1049,21 @@ class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):

divmod_exc = NotImplementedError

@classmethod
def assert_equal(cls, left, right, **kwargs):
    """Assert that ``left`` and ``right`` are equal, tolerating decimal resizing.

    The pyarrow type of each operand is read from its dtype; when ``left``
    is a DataFrame, the first column of each operand is inspected instead.
    If either type is a decimal, both operands are cast to
    ``"float[pyarrow]"`` before being handed to ``tm.assert_equal``, to
    which ``**kwargs`` are forwarded unchanged.
    """
    frame_inputs = isinstance(left, pd.DataFrame)
    pa_types = [
        (obj.iloc[:, 0] if frame_inputs else obj).dtype.pyarrow_dtype
        for obj in (left, right)
    ]
    if any(pa.types.is_decimal(pa_type) for pa_type in pa_types):
        # The decimal precision of a result type can be resized depending on
        # the data, so only the float values are compared.
        left, right = (
            left.astype("float[pyarrow]"),
            right.astype("float[pyarrow]"),
        )
    tm.assert_equal(left, right, **kwargs)

def _patch_combine(self, obj, other, op):
# BaseOpsUtil._combine can upcast expected dtype
# (because it generates expected on python scalars)
Expand Down Expand Up @@ -1044,7 +1119,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
exc = NotImplementedError
elif arrow_temporal_supported:
exc = None
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
elif not (
pa.types.is_floating(pa_dtype)
or pa.types.is_integer(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
exc = pa.ArrowNotImplementedError
else:
exc = None
Expand All @@ -1057,7 +1136,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):

if (
opname == "__rpow__"
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
and (
pa.types.is_floating(pa_dtype)
or pa.types.is_integer(pa_dtype)
or pa.types.is_decimal(pa_dtype)
)
and not pa_version_under7p0
):
mark = pytest.mark.xfail(
Expand All @@ -1076,13 +1159,26 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
)
elif (
opname in {"__rtruediv__", "__rfloordiv__"}
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
and (
pa.types.is_floating(pa_dtype)
or pa.types.is_integer(pa_dtype)
or pa.types.is_decimal(pa_dtype)
)
and not pa_version_under7p0
):
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="divide by 0",
)
elif (
opname == "__pow__"
and pa.types.is_decimal(pa_dtype)
and pa_version_under7p0
):
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="Invalid decimal function: power_checked",
)

return mark

Expand Down Expand Up @@ -1297,6 +1393,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")

with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")


@pytest.mark.parametrize(
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
Expand All @@ -1323,7 +1422,11 @@ def test_quantile(data, interpolation, quantile, request):
ser.quantile(q=quantile, interpolation=interpolation)
return

if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
if (
pa.types.is_integer(pa_dtype)
or pa.types.is_floating(pa_dtype)
or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0)
):
pass
elif pa.types.is_temporal(data._data.type):
pass
Expand Down Expand Up @@ -1364,7 +1467,11 @@ def test_quantile(data, interpolation, quantile, request):
else:
# Just check the values
expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
if (
pa.types.is_integer(pa_dtype)
or pa.types.is_floating(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
expected = expected.astype("float64[pyarrow]")
result = result.astype("float64[pyarrow]")
tm.assert_series_equal(result, expected)
Expand All @@ -1385,6 +1492,13 @@ def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
reason=f"mode not supported by pyarrow for {pa_dtype}",
)
)
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"mode not supported by pyarrow for {pa_dtype}",
)
)
elif (
pa.types.is_boolean(pa_dtype)
and "multi_mode" in request.node.nodeid
Expand Down