Auto Backport PR #50964 on branch 2.0.x (TST: Test ArrowExtensionArray with decimal types) #51562

Merged
2 changes: 2 additions & 0 deletions pandas/_testing/__init__.py
@@ -215,6 +215,7 @@
FLOAT_PYARROW_DTYPES_STR_REPR = [
str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
]
DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
STRING_PYARROW_DTYPES = [pa.string()]
BINARY_PYARROW_DTYPES = [pa.binary()]

@@ -239,6 +240,7 @@
ALL_PYARROW_DTYPES = (
ALL_INT_PYARROW_DTYPES
+ FLOAT_PYARROW_DTYPES
+ DECIMAL_PYARROW_DTYPES
+ STRING_PYARROW_DTYPES
+ BINARY_PYARROW_DTYPES
+ TIME_PYARROW_DTYPES
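
Illustrative sketch (not part of the diff): the new DECIMAL_PYARROW_DTYPES constant routes pa.decimal128(7, 3) through the shared test fixtures. Assuming pandas 2.0 with pyarrow installed, a decimal-backed array is constructed like this:

from decimal import Decimal

import pandas as pd
import pyarrow as pa

# decimal128(7, 3): up to 7 significant digits, 3 of them after the decimal point
dtype = pd.ArrowDtype(pa.decimal128(7, 3))
ser = pd.Series([Decimal("1.5"), None, Decimal("-2.25")], dtype=dtype)
print(ser.dtype)  # decimal128(7, 3)[pyarrow]
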
1 change: 1 addition & 0 deletions pandas/core/arrays/arrow/array.py
@@ -1098,6 +1098,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
pa.types.is_integer(pa_type)
or pa.types.is_floating(pa_type)
or pa.types.is_duration(pa_type)
or pa.types.is_decimal(pa_type)
):
# pyarrow only supports any/all for boolean dtype, we allow
# for other dtypes, matching our non-pyarrow behavior
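
A hedged sketch of what the added is_decimal branch enables: any()/all() on a decimal-backed column. pyarrow itself only implements these kernels for booleans, so pandas compares against zero first; the commented results are the expected outcome, not output reproduced from the diff.

from decimal import Decimal

import pandas as pd
import pyarrow as pa

ser = pd.Series(
    [Decimal("0"), Decimal("1.5"), None],
    dtype=pd.ArrowDtype(pa.decimal128(7, 3)),
)
print(ser.any())  # True: at least one value is non-zero
print(ser.all())  # False: Decimal("0") counts as falsy
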
2 changes: 1 addition & 1 deletion pandas/core/arrays/arrow/dtype.py
@@ -201,7 +201,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
try:
pa_dtype = pa.type_for_alias(base_type)
except ValueError as err:
has_parameters = re.search(r"\[.*\]", base_type)
has_parameters = re.search(r"[\[\(].*[\]\)]", base_type)
if has_parameters:
# Fallback to try common temporal types
try:
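
Hedged illustration of the widened pattern: decimal aliases carry their parameters in parentheses rather than square brackets, so the old regex never flagged them as parameterized and the clearer fallback path was skipped.

import re

old = re.compile(r"\[.*\]")
new = re.compile(r"[\[\(].*[\]\)]")

for base_type in ["timestamp[s, tz=UTC]", "decimal128(7, 3)"]:
    # the old pattern misses the parenthesized form; the new one catches both
    print(base_type, bool(old.search(base_type)), bool(new.search(base_type)))
# timestamp[s, tz=UTC] True True
# decimal128(7, 3) False True
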
8 changes: 7 additions & 1 deletion pandas/core/indexes/base.py
@@ -810,7 +810,11 @@ def _engine(
target_values = self._get_engine_target()
if isinstance(target_values, ExtensionArray):
if isinstance(target_values, (BaseMaskedArray, ArrowExtensionArray)):
return _masked_engines[target_values.dtype.name](target_values)
try:
return _masked_engines[target_values.dtype.name](target_values)
except KeyError:
# Not supported yet e.g. decimal
pass
elif self._engine_type is libindex.ObjectEngine:
return libindex.ExtensionEngine(target_values)

@@ -4948,6 +4952,8 @@ def _get_engine_target(self) -> ArrayLike:
and not (
isinstance(self._values, ArrowExtensionArray)
and is_numeric_dtype(self.dtype)
# Exclude decimal
and self.dtype.kind != "O"
)
):
# TODO(ExtensionIndex): remove special-case, just use self._values
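
A minimal sketch of the case both hunks guard against: a decimal-backed Index reports dtype.kind == "O" and has no dedicated masked engine, so lookups are expected to fall back to the generic object path instead of raising (assumes pandas 2.0 with pyarrow installed).

from decimal import Decimal

import pandas as pd
import pyarrow as pa

idx = pd.Index(
    [Decimal("1.5"), Decimal("2.5")], dtype=pd.ArrowDtype(pa.decimal128(7, 3))
)
print(idx.dtype.kind)               # "O" -- the kind checked in _get_engine_target
print(idx.get_loc(Decimal("2.5")))  # 1, resolved via the object fallback
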
148 changes: 137 additions & 11 deletions pandas/tests/extension/test_arrow.py
@@ -16,6 +16,7 @@
time,
timedelta,
)
from decimal import Decimal
from io import (
BytesIO,
StringIO,
@@ -79,6 +80,14 @@ def data(dtype):
data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99]
elif pa.types.is_unsigned_integer(pa_dtype):
data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99]
elif pa.types.is_decimal(pa_dtype):
data = (
[Decimal("1"), Decimal("0.0")] * 4
+ [None]
+ [Decimal("-2.0"), Decimal("-1.0")] * 44
+ [None]
+ [Decimal("0.5"), Decimal("33.123")]
)
elif pa.types.is_date(pa_dtype):
data = (
[date(2022, 1, 1), date(1999, 12, 31)] * 4
@@ -188,6 +197,10 @@ def data_for_grouping(dtype):
A = b"a"
B = b"b"
C = b"c"
elif pa.types.is_decimal(pa_dtype):
A = Decimal("-1.1")
B = Decimal("0.0")
C = Decimal("1.1")
else:
raise NotImplementedError
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
@@ -250,17 +263,20 @@ def test_astype_str(self, data, request):
class TestConstructors(base.BaseConstructorsTests):
def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
if pa.types.is_string(pa_dtype):
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
else:
reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"

if pa.types.is_string(pa_dtype):
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
request.node.add_marker(
pytest.mark.xfail(
reason=reason,
)
)
super().test_from_dtype(data)

def test_from_sequence_pa_array(self, data, request):
def test_from_sequence_pa_array(self, data):
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
# data._data = pa.ChunkedArray
result = type(data)._from_sequence(data._data)
@@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
reason="Nanosecond time parsing not supported.",
)
)
elif pa_version_under11p0 and pa.types.is_duration(pa_dtype):
elif pa_version_under11p0 and (
pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
@@ -392,7 +410,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
raises=NotImplementedError,
)
)
elif all_numeric_accumulations == "cumsum" and (pa.types.is_boolean(pa_type)):
elif all_numeric_accumulations == "cumsum" and (
pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type)
):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{all_numeric_accumulations} not implemented for {pa_type}",
@@ -476,6 +496,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
)
if all_numeric_reductions in {"skew", "kurt"}:
request.node.add_marker(xfail_mark)
elif (
all_numeric_reductions in {"var", "std", "median"}
and pa_version_under7p0
and pa.types.is_decimal(pa_dtype)
):
request.node.add_marker(xfail_mark)
elif all_numeric_reductions == "sem" and pa_version_under8p0:
request.node.add_marker(xfail_mark)

@@ -598,8 +624,26 @@ def test_in_numeric_groupby(self, data_for_grouping):


class TestBaseDtype(base.BaseDtypeTests):
def test_check_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype) and pa_version_under8p0:
request.node.add_marker(
pytest.mark.xfail(
raises=ValueError,
reason="decimal string repr affects numpy comparison",
)
)
super().test_check_dtype(data)

def test_construct_from_string_own_name(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)

if pa.types.is_string(pa_dtype):
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
@@ -617,6 +661,13 @@ def test_is_dtype_from_name(self, dtype, request):
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
assert not type(dtype).is_dtype(dtype.name)
else:
if pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)
super().test_is_dtype_from_name(dtype)

def test_construct_from_string_another_type_raises(self, dtype):
@@ -635,6 +686,7 @@ def test_get_common_dtype(self, dtype, request):
)
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
or pa.types.is_binary(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
@@ -708,6 +760,13 @@ def test_EA_types(self, engine, data, request):
request.node.add_marker(
pytest.mark.xfail(raises=TypeError, reason="GH 47534")
)
elif pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"Parameterized types {pa_dtype} not supported.",
)
)
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"):
request.node.add_marker(
pytest.mark.xfail(
@@ -790,6 +849,13 @@ def test_argmin_argmax(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
request.node.add_marker(
pytest.mark.xfail(
reason=f"No pyarrow kernel for {pa_dtype}",
raises=pa.ArrowNotImplementedError,
)
)
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

@pytest.mark.parametrize(
@@ -808,6 +874,14 @@
def test_argreduce_series(
self, data_missing_for_sorting, op_name, skipna, expected, request
):
pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna:
request.node.add_marker(
pytest.mark.xfail(
reason=f"No pyarrow kernel for {pa_dtype}",
raises=pa.ArrowNotImplementedError,
)
)
super().test_argreduce_series(
data_missing_for_sorting, op_name, skipna, expected
)
@@ -906,6 +980,21 @@ def test_basic_equals(self, data):
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
divmod_exc = NotImplementedError

@classmethod
def assert_equal(cls, left, right, **kwargs):
if isinstance(left, pd.DataFrame):
left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype
right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype
else:
left_pa_type = left.dtype.pyarrow_dtype
right_pa_type = right.dtype.pyarrow_dtype
if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type):
# decimal precision can resize in the result type depending on data
# just compare the float values
left = left.astype("float[pyarrow]")
right = right.astype("float[pyarrow]")
tm.assert_equal(left, right, **kwargs)

def get_op_from_name(self, op_name):
short_opname = op_name.strip("_")
if short_opname == "rtruediv":
@@ -975,7 +1064,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
exc = None
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
elif not (
pa.types.is_floating(pa_dtype)
or pa.types.is_integer(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
exc = pa.ArrowNotImplementedError
else:
exc = None
@@ -988,7 +1081,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):

if (
opname == "__rpow__"
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
and (
pa.types.is_floating(pa_dtype)
or pa.types.is_integer(pa_dtype)
or pa.types.is_decimal(pa_dtype)
)
and not pa_version_under7p0
):
mark = pytest.mark.xfail(
@@ -1006,14 +1103,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
),
)
elif (
opname in {"__rfloordiv__"}
and pa.types.is_integer(pa_dtype)
opname == "__rfloordiv__"
and (pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype))
and not pa_version_under7p0
):
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="divide by 0",
)
elif (
opname == "__rtruediv__"
and pa.types.is_decimal(pa_dtype)
and not pa_version_under7p0
):
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="divide by 0",
)
elif (
opname == "__pow__"
and pa.types.is_decimal(pa_dtype)
and pa_version_under7p0
):
mark = pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="Invalid decimal function: power_checked",
)

return mark

@@ -1231,6 +1346,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
expected = ArrowDtype(pa.timestamp("s", "UTC"))
assert dtype == expected

with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")


@pytest.mark.parametrize(
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
@@ -1257,7 +1375,11 @@ def test_quantile(data, interpolation, quantile, request):
ser.quantile(q=quantile, interpolation=interpolation)
return

if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
if (
pa.types.is_integer(pa_dtype)
or pa.types.is_floating(pa_dtype)
or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0)
):
pass
elif pa.types.is_temporal(data._data.type):
pass
@@ -1298,7 +1420,11 @@ def test_quantile(data, interpolation, quantile, request):
else:
# Just check the values
expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
if (
pa.types.is_integer(pa_dtype)
or pa.types.is_floating(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
expected = expected.astype("float64[pyarrow]")
result = result.astype("float64[pyarrow]")
tm.assert_series_equal(result, expected)
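
Hedged illustration of the assert_equal override in TestBaseArithmeticOps above: arithmetic on decimal columns can widen the result precision, so an exact dtype comparison would fail even when the values agree, hence the cast to float before comparing.

from decimal import Decimal

import pandas as pd
import pyarrow as pa

ser = pd.Series([Decimal("1.5")], dtype=pd.ArrowDtype(pa.decimal128(7, 3)))
print((ser + ser).dtype)  # e.g. decimal128(8, 3)[pyarrow], wider than the inputs
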
7 changes: 1 addition & 6 deletions pandas/tests/indexes/test_common.py
@@ -449,12 +449,7 @@ def test_hasnans_isnans(self, index_flat):
@pytest.mark.parametrize("na_position", [None, "middle"])
def test_sort_values_invalid_na_position(index_with_missing, na_position):
with pytest.raises(ValueError, match=f"invalid na_position: {na_position}"):
with tm.maybe_produces_warning(
PerformanceWarning,
getattr(index_with_missing.dtype, "storage", "") == "pyarrow",
check_stacklevel=False,
):
index_with_missing.sort_values(na_position=na_position)
index_with_missing.sort_values(na_position=na_position)


@pytest.mark.parametrize("na_position", ["first", "last"])