TST: Test ArrowExtensionArray with decimal types #50964
@@ -16,6 +16,7 @@ | |
time, | ||
timedelta, | ||
) | ||
from decimal import Decimal | ||
from io import ( | ||
BytesIO, | ||
StringIO, | ||
|
@@ -79,6 +80,14 @@ def data(dtype): | |
data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99] | ||
elif pa.types.is_unsigned_integer(pa_dtype): | ||
data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99] | ||
elif pa.types.is_decimal(pa_dtype): | ||
data = ( | ||
[Decimal("1"), Decimal("0.0")] * 4 | ||
+ [None] | ||
+ [Decimal("-2.0"), Decimal("-1.0")] * 44 | ||
+ [None] | ||
+ [Decimal("0.5"), Decimal("33.123")] | ||
) | ||
Review comment: looks like if you pass Decimal("nan") to pyarrow it raises. Is that going to be relevant for us?
Reply: Looks like we coerce to NA. Which I guess is fuzzy with the NA vs nan debate.
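A minimal sketch of the behaviour described in this thread (not verified here; the exact pyarrow exception class and the coercion to NA are taken from the comments above, and pd.ArrowDtype assumes a recent pandas):

    from decimal import Decimal
    import pandas as pd
    import pyarrow as pa

    # pyarrow itself rejects Decimal("nan") when building a decimal array
    try:
        pa.array([Decimal("nan")], type=pa.decimal128(7, 3))
    except Exception as err:  # exact exception type not pinned down here
        print("pyarrow raised:", err)

    # per the reply above, the pandas constructor reportedly coerces it to pd.NA instead
    arr = pd.array([Decimal("1.5"), Decimal("nan")], dtype=pd.ArrowDtype(pa.decimal128(7, 3)))
    print(arr)  # expected to show 1.500 followed by <NA>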
||
elif pa.types.is_date(pa_dtype): | ||
data = ( | ||
[date(2022, 1, 1), date(1999, 12, 31)] * 4 | ||
|
@@ -188,6 +197,10 @@ def data_for_grouping(dtype): | |
A = b"a" | ||
B = b"b" | ||
C = b"c" | ||
elif pa.types.is_decimal(pa_dtype): | ||
A = Decimal("-1.1") | ||
B = Decimal("0.0") | ||
C = Decimal("1.1") | ||
else: | ||
raise NotImplementedError | ||
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype) | ||
|
@@ -250,17 +263,20 @@ def test_astype_str(self, data, request): | |
class TestConstructors(base.BaseConstructorsTests): | ||
def test_from_dtype(self, data, request): | ||
pa_dtype = data.dtype.pyarrow_dtype | ||
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype): | ||
if pa.types.is_string(pa_dtype): | ||
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')" | ||
else: | ||
reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}" | ||
|
||
if pa.types.is_string(pa_dtype): | ||
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')" | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
reason=reason, | ||
) | ||
) | ||
super().test_from_dtype(data) | ||
|
||
def test_from_sequence_pa_array(self, data, request): | ||
def test_from_sequence_pa_array(self, data): | ||
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784 | ||
# data._data = pa.ChunkedArray | ||
result = type(data)._from_sequence(data._data) | ||
|
@@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request): | |
reason="Nanosecond time parsing not supported.", | ||
) | ||
) | ||
elif pa_version_under11p0 and pa.types.is_duration(pa_dtype): | ||
elif pa_version_under11p0 and ( | ||
pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype) | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=pa.ArrowNotImplementedError, | ||
|
@@ -384,7 +402,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques | |
# renders the exception messages even when not showing them | ||
pytest.skip(f"{all_numeric_accumulations} not implemented for pyarrow < 9") | ||
|
||
elif all_numeric_accumulations == "cumsum" and pa.types.is_boolean(pa_type): | ||
elif all_numeric_accumulations == "cumsum" and ( | ||
pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type) | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
reason=f"{all_numeric_accumulations} not implemented for {pa_type}", | ||
|
@@ -468,6 +488,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request): | |
) | ||
if all_numeric_reductions in {"skew", "kurt"}: | ||
request.node.add_marker(xfail_mark) | ||
elif ( | ||
all_numeric_reductions in {"var", "std", "median"} | ||
and pa_version_under7p0 | ||
and pa.types.is_decimal(pa_dtype) | ||
): | ||
request.node.add_marker(xfail_mark) | ||
elif all_numeric_reductions == "sem" and pa_version_under8p0: | ||
request.node.add_marker(xfail_mark) | ||
|
||
|
@@ -590,8 +616,26 @@ def test_in_numeric_groupby(self, data_for_grouping): | |
|
||
|
||
class TestBaseDtype(base.BaseDtypeTests): | ||
def test_check_dtype(self, data, request): | ||
pa_dtype = data.dtype.pyarrow_dtype | ||
if pa.types.is_decimal(pa_dtype) and pa_version_under8p0: | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=ValueError, | ||
reason="decimal string repr affects numpy comparison", | ||
) | ||
) | ||
super().test_check_dtype(data) | ||
|
||
def test_construct_from_string_own_name(self, dtype, request): | ||
pa_dtype = dtype.pyarrow_dtype | ||
if pa.types.is_decimal(pa_dtype): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=NotImplementedError, | ||
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", | ||
) | ||
) | ||
|
||
if pa.types.is_string(pa_dtype): | ||
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) | ||
|
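To illustrate why the construct-from-string cases above are xfailed for decimal (a sketch, assuming a recent pyarrow; the exact error message may differ): pyarrow.type_for_alias only understands simple, non-parameterized type names, so a parameterized decimal dtype string cannot be reconstructed from its alias.

    import pyarrow as pa

    pa.type_for_alias("int64")                # fine: int64
    try:
        pa.type_for_alias("decimal128(7, 3)")
    except ValueError as err:                 # parameterized aliases are not recognised
        print(err)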
@@ -609,6 +653,13 @@ def test_is_dtype_from_name(self, dtype, request): | |
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) | ||
assert not type(dtype).is_dtype(dtype.name) | ||
else: | ||
if pa.types.is_decimal(pa_dtype): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=NotImplementedError, | ||
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", | ||
) | ||
) | ||
super().test_is_dtype_from_name(dtype) | ||
|
||
def test_construct_from_string_another_type_raises(self, dtype): | ||
|
@@ -627,6 +678,7 @@ def test_get_common_dtype(self, dtype, request): | |
) | ||
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns") | ||
or pa.types.is_binary(pa_dtype) | ||
or pa.types.is_decimal(pa_dtype) | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
|
@@ -700,6 +752,13 @@ def test_EA_types(self, engine, data, request): | |
request.node.add_marker( | ||
pytest.mark.xfail(raises=TypeError, reason="GH 47534") | ||
) | ||
elif pa.types.is_decimal(pa_dtype): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=NotImplementedError, | ||
reason=f"Parameterized types {pa_dtype} not supported.", | ||
Review comment: #50689 fixes this for timestamptz; after both this and that are merged we can see about fixing this for decimal.
||
) | ||
) | ||
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
|
@@ -782,6 +841,13 @@ def test_argmin_argmax( | |
reason=f"{pa_dtype} only has 2 unique possible values", | ||
) | ||
) | ||
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0: | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
reason=f"No pyarrow kernel for {pa_dtype}", | ||
raises=pa.ArrowNotImplementedError, | ||
) | ||
) | ||
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value) | ||
|
||
@pytest.mark.parametrize( | ||
|
@@ -800,6 +866,14 @@ def test_argmin_argmax( | |
def test_argreduce_series( | ||
self, data_missing_for_sorting, op_name, skipna, expected, request | ||
): | ||
pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype | ||
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna: | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
reason=f"No pyarrow kernel for {pa_dtype}", | ||
raises=pa.ArrowNotImplementedError, | ||
) | ||
) | ||
super().test_argreduce_series( | ||
data_missing_for_sorting, op_name, skipna, expected | ||
) | ||
|
@@ -888,6 +962,21 @@ def test_basic_equals(self, data): | |
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests): | ||
divmod_exc = NotImplementedError | ||
|
||
@classmethod | ||
def assert_equal(cls, left, right, **kwargs): | ||
if isinstance(left, pd.DataFrame): | ||
left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype | ||
right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype | ||
else: | ||
left_pa_type = left.dtype.pyarrow_dtype | ||
right_pa_type = right.dtype.pyarrow_dtype | ||
if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type): | ||
# decimal precision can resize in the result type depending on data | ||
# just compare the float values | ||
left = left.astype("float[pyarrow]") | ||
right = right.astype("float[pyarrow]") | ||
tm.assert_equal(left, right, **kwargs) | ||
|
||
def get_op_from_name(self, op_name): | ||
short_opname = op_name.strip("_") | ||
if short_opname == "rtruediv": | ||
|
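A small sketch of the precision resizing that the assert_equal override above works around (assuming a pyarrow version with decimal arithmetic kernels; the exact result type follows Arrow's decimal promotion rules):

    from decimal import Decimal
    import pyarrow as pa
    import pyarrow.compute as pc

    a = pa.array([Decimal("1.25")], type=pa.decimal128(3, 2))
    result = pc.multiply(a, a)
    print(a.type)       # decimal128(3, 2)
    print(result.type)  # wider, e.g. decimal128(7, 4): precision and scale grow with the operation

    # comparing exact dtypes would therefore fail even when the values match,
    # hence casting both sides to float[pyarrow] before tm.assert_equal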
@@ -967,7 +1056,11 @@ def _get_scalar_exception(self, opname, pa_dtype): | |
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) | ||
): | ||
exc = None | ||
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): | ||
elif not ( | ||
pa.types.is_floating(pa_dtype) | ||
or pa.types.is_integer(pa_dtype) | ||
or pa.types.is_decimal(pa_dtype) | ||
): | ||
exc = pa.ArrowNotImplementedError | ||
else: | ||
exc = None | ||
|
@@ -980,7 +1073,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): | |
|
||
if ( | ||
opname == "__rpow__" | ||
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) | ||
and ( | ||
pa.types.is_floating(pa_dtype) | ||
or pa.types.is_integer(pa_dtype) | ||
or pa.types.is_decimal(pa_dtype) | ||
) | ||
and not pa_version_under7p0 | ||
): | ||
mark = pytest.mark.xfail( | ||
|
@@ -998,14 +1095,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): | |
), | ||
) | ||
elif ( | ||
opname in {"__rfloordiv__"} | ||
and pa.types.is_integer(pa_dtype) | ||
opname == "__rfloordiv__" | ||
and (pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype)) | ||
and not pa_version_under7p0 | ||
): | ||
mark = pytest.mark.xfail( | ||
raises=pa.ArrowInvalid, | ||
reason="divide by 0", | ||
) | ||
elif ( | ||
opname == "__rtruediv__" | ||
and pa.types.is_decimal(pa_dtype) | ||
and not pa_version_under7p0 | ||
): | ||
mark = pytest.mark.xfail( | ||
raises=pa.ArrowInvalid, | ||
reason="divide by 0", | ||
) | ||
elif ( | ||
opname == "__pow__" | ||
and pa.types.is_decimal(pa_dtype) | ||
and pa_version_under7p0 | ||
): | ||
mark = pytest.mark.xfail( | ||
raises=pa.ArrowInvalid, | ||
reason="Invalid decimal function: power_checked", | ||
) | ||
|
||
return mark | ||
|
||
|
@@ -1226,6 +1341,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): | |
expected = ArrowDtype(pa.timestamp("s", "UTC")) | ||
assert dtype == expected | ||
|
||
with pytest.raises(NotImplementedError, match="Passing pyarrow type"): | ||
ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"] | ||
|
@@ -1252,7 +1370,11 @@ def test_quantile(data, interpolation, quantile, request): | |
ser.quantile(q=quantile, interpolation=interpolation) | ||
return | ||
|
||
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): | ||
if ( | ||
pa.types.is_integer(pa_dtype) | ||
or pa.types.is_floating(pa_dtype) | ||
or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0) | ||
): | ||
pass | ||
elif pa.types.is_temporal(data._data.type): | ||
pass | ||
|
@@ -1293,7 +1415,11 @@ def test_quantile(data, interpolation, quantile, request): | |
else: | ||
# Just check the values | ||
expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5]) | ||
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): | ||
if ( | ||
pa.types.is_integer(pa_dtype) | ||
or pa.types.is_floating(pa_dtype) | ||
or pa.types.is_decimal(pa_dtype) | ||
): | ||
expected = expected.astype("float64[pyarrow]") | ||
result = result.astype("float64[pyarrow]") | ||
tm.assert_series_equal(result, expected) | ||
|
Review comment: can we be more specific than this? e.g. time32 would also go through here.
Reply: is_numeric_dtype(self.dtype) should block time32. I could import pyarrow and check pa.types.is_decimal if you prefer.
Review comment: makes sense, this is fine.
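A quick illustration of the guard discussed here (a sketch assuming a recent pandas/pyarrow; the time32 behaviour is as the reply describes, not re-verified):

    import pandas as pd
    import pyarrow as pa
    from pandas.api.types import is_numeric_dtype

    print(is_numeric_dtype(pd.ArrowDtype(pa.decimal128(7, 3))))  # True: decimals count as numeric
    print(is_numeric_dtype(pd.ArrowDtype(pa.time32("s"))))       # False: time32 is blocked by the guard

    # the more explicit alternative mentioned in the reply:
    print(pa.types.is_decimal(pa.decimal128(7, 3)))              # True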