
TST: Test ArrowExtensionArray with decimal types #50964


Merged: 39 commits, Feb 22, 2023

Commits
01f689e
TST: Test ArrowExtensionArray with decimal types
mroeschke Jan 24, 2023
e90cc24
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 25, 2023
5c5e28b
Version compat
mroeschke Jan 25, 2023
68e769b
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 25, 2023
f799c74
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 25, 2023
c2f108d
Add other xfails based on min version
mroeschke Jan 25, 2023
b67a13e
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 26, 2023
60ccab1
fix test
mroeschke Jan 26, 2023
e9a1b4e
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 26, 2023
d63b1ea
fix typo
mroeschke Jan 26, 2023
c2a7610
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 26, 2023
f944e84
another typo
mroeschke Jan 26, 2023
7758f61
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 26, 2023
5492d11
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Jan 28, 2023
7f42ad3
only for skipna
mroeschke Jan 28, 2023
806abe9
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 1, 2023
a16ee3b
Add comment
mroeschke Feb 1, 2023
fe5800a
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 2, 2023
44f2eb8
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 2, 2023
129ac08
Fix
mroeschke Feb 2, 2023
13128ac
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 7, 2023
c44f260
undo comments
mroeschke Feb 7, 2023
f86666d
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 7, 2023
c6f72d3
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 8, 2023
406bb69
Bump version condition
mroeschke Feb 8, 2023
fb6992b
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 9, 2023
47fc864
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 10, 2023
6a2923c
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 17, 2023
28006b2
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 17, 2023
24e52e2
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 17, 2023
dfe4bb0
Skip masked indexing engine for decimal
mroeschke Feb 17, 2023
798f9a3
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 21, 2023
ff53f6d
Some merge stuff
mroeschke Feb 21, 2023
a4143df
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 22, 2023
6114391
Remove imaginary test
mroeschke Feb 22, 2023
4f8fadc
Fix condition
mroeschke Feb 22, 2023
deaa1db
Fix another test
mroeschke Feb 22, 2023
3ce8c9d
Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
mroeschke Feb 22, 2023
142c771
Update condition
mroeschke Feb 22, 2023
2 changes: 2 additions & 0 deletions pandas/_testing/__init__.py
@@ -215,6 +215,7 @@
FLOAT_PYARROW_DTYPES_STR_REPR = [
str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
]
DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
STRING_PYARROW_DTYPES = [pa.string()]
BINARY_PYARROW_DTYPES = [pa.binary()]

@@ -239,6 +240,7 @@
ALL_PYARROW_DTYPES = (
ALL_INT_PYARROW_DTYPES
+ FLOAT_PYARROW_DTYPES
+ DECIMAL_PYARROW_DTYPES
+ STRING_PYARROW_DTYPES
+ BINARY_PYARROW_DTYPES
+ TIME_PYARROW_DTYPES
1 change: 1 addition & 0 deletions pandas/core/arrays/arrow/array.py
@@ -1098,6 +1098,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
pa.types.is_integer(pa_type)
or pa.types.is_floating(pa_type)
or pa.types.is_duration(pa_type)
or pa.types.is_decimal(pa_type)
):
# pyarrow only supports any/all for boolean dtype, we allow
# for other dtypes, matching our non-pyarrow behavior
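A usage sketch of what the added decimal branch enables (hypothetical, assuming pandas with this change and a recent pyarrow; the printed values are expectations, not captured output):

from decimal import Decimal

import pandas as pd
import pyarrow as pa

# any()/all() on a decimal-backed ArrowExtensionArray now takes the same
# "non-zero is truthy" path as the other numeric pyarrow dtypes instead of
# raising.
ser = pd.Series(
    [Decimal("1.5"), Decimal("0.0"), None],
    dtype=pd.ArrowDtype(pa.decimal128(7, 3)),
)
print(ser.any())  # expected True: a non-zero value is present
print(ser.all())  # expected False: the 0.0 entry is falsey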
2 changes: 1 addition & 1 deletion pandas/core/arrays/arrow/dtype.py
@@ -201,7 +201,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
try:
pa_dtype = pa.type_for_alias(base_type)
except ValueError as err:
has_parameters = re.search(r"\[.*\]", base_type)
has_parameters = re.search(r"[\[\(].*[\]\)]", base_type)
if has_parameters:
# Fallback to try common temporal types
try:
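A standalone illustration of the widened pattern (example strings chosen here, not pandas code): parameterized pyarrow type strings can carry their parameters in brackets or parentheses, and both forms are now detected before the temporal-alias fallback runs.

import re

has_parameters = re.compile(r"[\[\(].*[\]\)]")
for base_type in ("timestamp[s, tz=UTC]", "decimal128(7, 3)", "int64"):
    print(base_type, bool(has_parameters.search(base_type)))
# timestamp[s, tz=UTC] True
# decimal128(7, 3) True
# int64 False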
8 changes: 7 additions & 1 deletion pandas/core/indexes/base.py
@@ -810,7 +810,11 @@ def _engine(
target_values = self._get_engine_target()
if isinstance(target_values, ExtensionArray):
if isinstance(target_values, (BaseMaskedArray, ArrowExtensionArray)):
return _masked_engines[target_values.dtype.name](target_values)
try:
return _masked_engines[target_values.dtype.name](target_values)
except KeyError:
# Not supported yet e.g. decimal
pass
elif self._engine_type is libindex.ObjectEngine:
return libindex.ExtensionEngine(target_values)

@@ -4948,6 +4952,8 @@ def _get_engine_target(self) -> ArrayLike:
and not (
isinstance(self._values, ArrowExtensionArray)
and is_numeric_dtype(self.dtype)
# Exclude decimal
and self.dtype.kind != "O"
Review comments:
Member: can we be more specific than this? e.g. time32 would also go through here
Member Author: is_numeric_dtype(self.dtype) should block time32. I could import pyarrow and check pa.types.is_decimal if you prefer.
Member: makes sense, this is fine

)
):
# TODO(ExtensionIndex): remove special-case, just use self._values
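A minimal sketch of the more explicit check floated in the review thread above (hypothetical; the merged change keeps the self.dtype.kind != "O" comparison):

import pandas as pd
import pyarrow as pa

# ArrowDtype exposes the wrapped pyarrow type as .pyarrow_dtype, so the
# decimal case could be detected directly rather than via the object kind.
dtype = pd.ArrowDtype(pa.decimal128(7, 3))
print(pa.types.is_decimal(dtype.pyarrow_dtype))  # True
print(dtype.kind)  # "O", which is what the merged kind-based check relies on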
148 changes: 137 additions & 11 deletions pandas/tests/extension/test_arrow.py
@@ -16,6 +16,7 @@
time,
timedelta,
)
from decimal import Decimal
from io import (
BytesIO,
StringIO,
@@ -79,6 +80,14 @@ def data(dtype):
data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99]
elif pa.types.is_unsigned_integer(pa_dtype):
data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99]
elif pa.types.is_decimal(pa_dtype):
data = (
[Decimal("1"), Decimal("0.0")] * 4
+ [None]
+ [Decimal("-2.0"), Decimal("-1.0")] * 44
+ [None]
+ [Decimal("0.5"), Decimal("33.123")]
)
Review comments:
Member: looks like if you pass Decimal("nan") to pyarrow it raises. is that going to be relevant for us?
Member Author: Looks like we coerce to NA

In [6]: pd.Series([decimal.Decimal("nan")], dtype=pd.ArrowDtype(pa.decimal128(4, 1)))
Out[6]:
0    <NA>
dtype: decimal128(4, 1)[pyarrow]

Which I guess is fuzzy with the NA vs nan debate

elif pa.types.is_date(pa_dtype):
data = (
[date(2022, 1, 1), date(1999, 12, 31)] * 4
@@ -188,6 +197,10 @@ def data_for_grouping(dtype):
A = b"a"
B = b"b"
C = b"c"
elif pa.types.is_decimal(pa_dtype):
A = Decimal("-1.1")
B = Decimal("0.0")
C = Decimal("1.1")
else:
raise NotImplementedError
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
@@ -250,17 +263,20 @@ def test_astype_str(self, data, request):
class TestConstructors(base.BaseConstructorsTests):
def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
if pa.types.is_string(pa_dtype):
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
else:
reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"

if pa.types.is_string(pa_dtype):
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
request.node.add_marker(
pytest.mark.xfail(
reason=reason,
)
)
super().test_from_dtype(data)

def test_from_sequence_pa_array(self, data, request):
def test_from_sequence_pa_array(self, data):
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
# data._data = pa.ChunkedArray
result = type(data)._from_sequence(data._data)
@@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
reason="Nanosecond time parsing not supported.",
)
)
elif pa_version_under11p0 and pa.types.is_duration(pa_dtype):
elif pa_version_under11p0 and (
pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
@@ -384,7 +402,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
# renders the exception messages even when not showing them
pytest.skip(f"{all_numeric_accumulations} not implemented for pyarrow < 9")

elif all_numeric_accumulations == "cumsum" and pa.types.is_boolean(pa_type):
elif all_numeric_accumulations == "cumsum" and (
pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type)
):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{all_numeric_accumulations} not implemented for {pa_type}",
@@ -468,6 +488,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
)
if all_numeric_reductions in {"skew", "kurt"}:
request.node.add_marker(xfail_mark)
elif (
all_numeric_reductions in {"var", "std", "median"}
and pa_version_under7p0
and pa.types.is_decimal(pa_dtype)
):
request.node.add_marker(xfail_mark)
elif all_numeric_reductions == "sem" and pa_version_under8p0:
request.node.add_marker(xfail_mark)

@@ -590,8 +616,26 @@ def test_in_numeric_groupby(self, data_for_grouping):


class TestBaseDtype(base.BaseDtypeTests):
def test_check_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype) and pa_version_under8p0:
request.node.add_marker(
pytest.mark.xfail(
raises=ValueError,
reason="decimal string repr affects numpy comparison",
)
)
super().test_check_dtype(data)

def test_construct_from_string_own_name(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)

if pa.types.is_string(pa_dtype):
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
@@ -609,6 +653,13 @@ def test_is_dtype_from_name(self, dtype, request):
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
assert not type(dtype).is_dtype(dtype.name)
else:
if pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)
super().test_is_dtype_from_name(dtype)

def test_construct_from_string_another_type_raises(self, dtype):
@@ -627,6 +678,7 @@ def test_get_common_dtype(self, dtype, request):
)
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
or pa.types.is_binary(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
@@ -700,6 +752,13 @@ def test_EA_types(self, engine, data, request):
request.node.add_marker(
pytest.mark.xfail(raises=TypeError, reason="GH 47534")
)
elif pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"Parameterized types {pa_dtype} not supported.",
Review comment:
Member: #50689 fixes this for timestamptz, after both this and that are merged we can see about fixing this for decimal.

)
)
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"):
request.node.add_marker(
pytest.mark.xfail(
@@ -782,6 +841,13 @@ def test_argmin_argmax(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
request.node.add_marker(
pytest.mark.xfail(
reason=f"No pyarrow kernel for {pa_dtype}",
raises=pa.ArrowNotImplementedError,
)
)
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

@pytest.mark.parametrize(
@@ -800,6 +866,14 @@ def test_argmin_argmax(
def test_argreduce_series(
self, data_missing_for_sorting, op_name, skipna, expected, request
):
pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna:
request.node.add_marker(
pytest.mark.xfail(
reason=f"No pyarrow kernel for {pa_dtype}",
raises=pa.ArrowNotImplementedError,
)
)
super().test_argreduce_series(
data_missing_for_sorting, op_name, skipna, expected
)
@@ -888,6 +962,21 @@ def test_basic_equals(self, data):
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
divmod_exc = NotImplementedError

@classmethod
def assert_equal(cls, left, right, **kwargs):
if isinstance(left, pd.DataFrame):
left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype
right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype
else:
left_pa_type = left.dtype.pyarrow_dtype
right_pa_type = right.dtype.pyarrow_dtype
if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type):
# decimal precision can resize in the result type depending on data
# just compare the float values
left = left.astype("float[pyarrow]")
right = right.astype("float[pyarrow]")
tm.assert_equal(left, right, **kwargs)

def get_op_from_name(self, op_name):
short_opname = op_name.strip("_")
if short_opname == "rtruediv":
@@ -967,7 +1056,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
exc = None
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
elif not (
pa.types.is_floating(pa_dtype)
or pa.types.is_integer(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
exc = pa.ArrowNotImplementedError
else:
exc = None
@@ -980,7 +1073,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):

if (
opname == "__rpow__"
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
and (
pa.types.is_floating(pa_dtype)
or pa.types.is_integer(pa_dtype)
or pa.types.is_decimal(pa_dtype)
)
and not pa_version_under7p0
):
mark = pytest.mark.xfail(
@@ -998,14 +1095,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
),
)
elif (
opname in {"__rfloordiv__"}
and pa.types.is_integer(pa_dtype)
opname == "__rfloordiv__"
and (pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype))
and not pa_version_under7p0
):
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="divide by 0",
)
elif (
opname == "__rtruediv__"
and pa.types.is_decimal(pa_dtype)
and not pa_version_under7p0
):
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="divide by 0",
)
elif (
opname == "__pow__"
and pa.types.is_decimal(pa_dtype)
and pa_version_under7p0
):
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="Invalid decimal function: power_checked",
)

return mark

@@ -1226,6 +1341,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
expected = ArrowDtype(pa.timestamp("s", "UTC"))
assert dtype == expected

with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")


@pytest.mark.parametrize(
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
@@ -1252,7 +1370,11 @@ def test_quantile(data, interpolation, quantile, request):
ser.quantile(q=quantile, interpolation=interpolation)
return

if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
if (
pa.types.is_integer(pa_dtype)
or pa.types.is_floating(pa_dtype)
or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0)
):
pass
elif pa.types.is_temporal(data._data.type):
pass
@@ -1293,7 +1415,11 @@ def test_quantile(data, interpolation, quantile, request):
else:
# Just check the values
expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
if (
pa.types.is_integer(pa_dtype)
or pa.types.is_floating(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
expected = expected.astype("float64[pyarrow]")
result = result.astype("float64[pyarrow]")
tm.assert_series_equal(result, expected)
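For reference, a short sketch of the construct_from_string behaviour exercised in the test above (assumes pandas with this change; the comments describe expected results):

import pandas as pd
import pyarrow as pa

# A parameterized decimal dtype is built from the pyarrow type directly; the
# string alias form is rejected because pyarrow.type_for_alias cannot infer it.
dtype = pd.ArrowDtype(pa.decimal128(7, 2))
print(dtype)  # decimal128(7, 2)[pyarrow]

try:
    pd.ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")
except NotImplementedError as err:
    print("rejected:", err)  # message mentions "Passing pyarrow type"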
7 changes: 1 addition & 6 deletions pandas/tests/indexes/test_common.py
@@ -449,12 +449,7 @@ def test_hasnans_isnans(self, index_flat):
@pytest.mark.parametrize("na_position", [None, "middle"])
def test_sort_values_invalid_na_position(index_with_missing, na_position):
with pytest.raises(ValueError, match=f"invalid na_position: {na_position}"):
with tm.maybe_produces_warning(
PerformanceWarning,
getattr(index_with_missing.dtype, "storage", "") == "pyarrow",
check_stacklevel=False,
):
index_with_missing.sort_values(na_position=na_position)
index_with_missing.sort_values(na_position=na_position)


@pytest.mark.parametrize("na_position", ["first", "last"])