TST: Test ArrowExtensionArray with decimal types #50964
Merged: mroeschke merged 39 commits into pandas-dev:main from mroeschke:enh/arrow/decimal_testing on Feb 22, 2023.
The diff below shows the changes from 27 of the 39 commits.
Commits (39, all by mroeschke):
- 01f689e  TST: Test ArrowExtensionArray with decimal types
- e90cc24  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 5c5e28b  Version compat
- 68e769b  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- f799c74  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- c2f108d  Add other xfails based on min version
- b67a13e  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 60ccab1  fix test
- e9a1b4e  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- d63b1ea  fix typo
- c2a7610  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- f944e84  another typo
- 7758f61  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 5492d11  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 7f42ad3  only for skipna
- 806abe9  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- a16ee3b  Add comment
- fe5800a  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 44f2eb8  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 129ac08  Fix
- 13128ac  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- c44f260  undo comments
- f86666d  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- c6f72d3  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 406bb69  Bump version condition
- fb6992b  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 47fc864  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 6a2923c  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 28006b2  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 24e52e2  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- dfe4bb0  Skip masked indexing engine for decimal
- 798f9a3  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- ff53f6d  Some merge stuff
- a4143df  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 6114391  Remove imaginary test
- 4f8fadc  Fix condition
- deaa1db  Fix another test
- 3ce8c9d  Merge remote-tracking branch 'upstream/main' into enh/arrow/decimal_t…
- 142c771  Update condition
@@ -16,6 +16,7 @@
     time,
     timedelta,
 )
+from decimal import Decimal
 from io import (
     BytesIO,
     StringIO,
@@ -77,6 +78,14 @@ def data(dtype):
         data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99]
     elif pa.types.is_unsigned_integer(pa_dtype):
         data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99]
+    elif pa.types.is_decimal(pa_dtype):
+        data = (
+            [Decimal("1"), Decimal("0.0")] * 4
+            + [None]
+            + [Decimal("-2.0"), Decimal("-1.0")] * 44
+            + [None]
+            + [Decimal("0.5"), Decimal("33.123")]
+        )
     elif pa.types.is_date(pa_dtype):
         data = (
             [date(2022, 1, 1), date(1999, 12, 31)] * 4
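Note: the block above feeds decimal values into the generic extension-array test machinery. A minimal sketch of what such data looks like once wrapped in an ArrowExtensionArray (decimal128(7, 3) is an assumed, illustrative parametrization; the real dtype comes from the parametrized dtype fixture):

    from decimal import Decimal

    import pyarrow as pa

    import pandas as pd

    # Assumed illustrative dtype; the test suite parametrizes over many pyarrow types.
    dtype = pd.ArrowDtype(pa.decimal128(7, 3))

    arr = pd.array([Decimal("1"), Decimal("0.0"), None, Decimal("33.123")], dtype=dtype)
    print(arr.dtype)        # should print: decimal128(7, 3)[pyarrow]
    print(arr[2] is pd.NA)  # None is stored as a missing value, so this should print True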
@@ -186,6 +195,10 @@ def data_for_grouping(dtype):
         A = b"a"
         B = b"b"
         C = b"c"
+    elif pa.types.is_decimal(pa_dtype):
+        A = Decimal("-1.1")
+        B = Decimal("0.0")
+        C = Decimal("1.1")
     else:
         raise NotImplementedError
     return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
@@ -248,8 +261,10 @@ def test_astype_str(self, data, request):
 class TestConstructors(base.BaseConstructorsTests):
     def test_from_dtype(self, data, request):
         pa_dtype = data.dtype.pyarrow_dtype
-        if (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz) or pa.types.is_string(
-            pa_dtype
+        if (
+            (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz)
+            or pa.types.is_string(pa_dtype)
+            or pa.types.is_decimal(pa_dtype)
         ):
             if pa.types.is_string(pa_dtype):
                 reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
@@ -262,7 +277,7 @@ def test_from_dtype(self, data, request):
             )
         super().test_from_dtype(data)

-    def test_from_sequence_pa_array(self, data, request):
+    def test_from_sequence_pa_array(self, data):
         # https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
         # data._data = pa.ChunkedArray
         result = type(data)._from_sequence(data._data)
@@ -287,7 +302,7 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
                     reason="Nanosecond time parsing not supported.",
                 )
             )
-        elif pa.types.is_duration(pa_dtype):
+        elif pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype):
             request.node.add_marker(
                 pytest.mark.xfail(
                     raises=pa.ArrowNotImplementedError,
@@ -347,6 +362,8 @@ def test_getitem_scalar(self, data):
             exp_type = date
         elif pa.types.is_time(pa_type):
             exp_type = time
+        elif pa.types.is_decimal(pa_type):
+            exp_type = Decimal
         else:
             raise NotImplementedError(data.dtype)
@@ -426,7 +443,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
                     raises=NotImplementedError,
                 )
             )
-        elif all_numeric_accumulations == "cumsum" and (pa.types.is_boolean(pa_type)):
+        elif all_numeric_accumulations == "cumsum" and (
+            pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type)
+        ):
             request.node.add_marker(
                 pytest.mark.xfail(
                     reason=f"{all_numeric_accumulations} not implemented for {pa_type}",
@@ -510,6 +529,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
         )
         if all_numeric_reductions in {"skew", "kurt"}:
             request.node.add_marker(xfail_mark)
+        elif (
+            all_numeric_reductions in {"var", "std", "median"}
+            and pa_version_under7p0
+            and pa.types.is_decimal(pa_dtype)
+        ):
+            request.node.add_marker(xfail_mark)
         elif all_numeric_reductions == "sem" and pa_version_under8p0:
             request.node.add_marker(xfail_mark)
@@ -632,9 +657,22 @@ def test_in_numeric_groupby(self, data_for_grouping):


 class TestBaseDtype(base.BaseDtypeTests):
+    def test_check_dtype(self, data, request):
+        pa_dtype = data.dtype.pyarrow_dtype
+        if pa.types.is_decimal(pa_dtype) and pa_version_under8p0:
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=ValueError,
+                    reason="decimal string repr affects numpy comparison",
+                )
+            )
+        super().test_check_dtype(data)
+
     def test_construct_from_string_own_name(self, dtype, request):
         pa_dtype = dtype.pyarrow_dtype
-        if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
+        if (
+            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None
+        ) or pa.types.is_decimal(pa_dtype):
             request.node.add_marker(
                 pytest.mark.xfail(
                     raises=NotImplementedError,
@@ -655,7 +693,9 @@ def test_construct_from_string_own_name(self, dtype, request):

     def test_is_dtype_from_name(self, dtype, request):
         pa_dtype = dtype.pyarrow_dtype
-        if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
+        if (
+            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None
+        ) or pa.types.is_decimal(pa_dtype):
             request.node.add_marker(
                 pytest.mark.xfail(
                     raises=NotImplementedError,
@@ -675,7 +715,9 @@ def test_is_dtype_from_name(self, dtype, request):

     def test_construct_from_string(self, dtype, request):
         pa_dtype = dtype.pyarrow_dtype
-        if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
+        if (
+            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None
+        ) or pa.types.is_decimal(pa_dtype):
             request.node.add_marker(
                 pytest.mark.xfail(
                     raises=NotImplementedError,
@@ -710,6 +752,7 @@ def test_get_common_dtype(self, dtype, request):
             )
             or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
             or pa.types.is_binary(pa_dtype)
+            or pa.types.is_decimal(pa_dtype)
         ):
             request.node.add_marker(
                 pytest.mark.xfail(
@@ -783,11 +826,13 @@ def test_EA_types(self, engine, data, request):
             request.node.add_marker(
                 pytest.mark.xfail(raises=TypeError, reason="GH 47534")
             )
-        elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
+        elif (
+            pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None
+        ) or pa.types.is_decimal(pa_dtype):
             request.node.add_marker(
                 pytest.mark.xfail(
                     raises=NotImplementedError,
-                    reason=f"Parameterized types with tz={pa_dtype.tz} not supported.",
+                    reason=f"Parameterized types {pa_dtype} not supported.",
                 )
             )
         elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"):

Review comment on the changed reason string: #50689 fixes this for timestamptz; after both this and that are merged we can see about fixing this for decimal.
@@ -872,6 +917,13 @@ def test_argmin_argmax(
                     reason=f"{pa_dtype} only has 2 unique possible values",
                 )
             )
+        elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"No pyarrow kernel for {pa_dtype}",
+                    raises=pa.ArrowNotImplementedError,
+                )
+            )
         super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

     @pytest.mark.parametrize(
@@ -890,6 +942,14 @@ def test_argmin_argmax(
     def test_argreduce_series(
         self, data_missing_for_sorting, op_name, skipna, expected, request
     ):
+        pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
+        if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna:
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"No pyarrow kernel for {pa_dtype}",
+                    raises=pa.ArrowNotImplementedError,
+                )
+            )
         super().test_argreduce_series(
             data_missing_for_sorting, op_name, skipna, expected
         )
@@ -989,6 +1049,21 @@ class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):

     divmod_exc = NotImplementedError

+    @classmethod
+    def assert_equal(cls, left, right, **kwargs):
+        if isinstance(left, pd.DataFrame):
+            left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype
+            right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype
+        else:
+            left_pa_type = left.dtype.pyarrow_dtype
+            right_pa_type = right.dtype.pyarrow_dtype
+        if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type):
+            # decimal precision can resize in the result type depending on data
+            # just compare the float values
+            left = left.astype("float[pyarrow]")
+            right = right.astype("float[pyarrow]")
+        tm.assert_equal(left, right, **kwargs)
+
     def _patch_combine(self, obj, other, op):
         # BaseOpsUtil._combine can upcast expected dtype
         # (because it generates expected on python scalars)
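The float cast in assert_equal above exists because Arrow decimal arithmetic can widen the precision of the result type, so a result that is numerically correct may still carry a different decimal type than the naively built expected value. A rough sketch of that widening, assuming a pyarrow version whose compute kernels support decimal arithmetic (the exact output precision is an assumption based on Arrow's promotion rules):

    from decimal import Decimal

    import pyarrow as pa
    import pyarrow.compute as pc

    a = pa.array([Decimal("1.23"), Decimal("4.56")], type=pa.decimal128(3, 2))

    # The sum of two decimal128(3, 2) arrays typically comes back with a wider
    # precision (for example decimal128(4, 2)) even though the values are as
    # expected, which is why the tests compare float values instead of dtypes.
    print(pc.add(a, a).type)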
@@ -1044,7 +1119,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
             exc = NotImplementedError
         elif arrow_temporal_supported:
             exc = None
-        elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
+        elif not (
+            pa.types.is_floating(pa_dtype)
+            or pa.types.is_integer(pa_dtype)
+            or pa.types.is_decimal(pa_dtype)
+        ):
             exc = pa.ArrowNotImplementedError
         else:
             exc = None
@@ -1057,7 +1136,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):

         if (
             opname == "__rpow__"
-            and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
+            and (
+                pa.types.is_floating(pa_dtype)
+                or pa.types.is_integer(pa_dtype)
+                or pa.types.is_decimal(pa_dtype)
+            )
             and not pa_version_under7p0
         ):
             mark = pytest.mark.xfail(
@@ -1076,13 +1159,26 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
             )
         elif (
             opname in {"__rtruediv__", "__rfloordiv__"}
-            and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
+            and (
+                pa.types.is_floating(pa_dtype)
+                or pa.types.is_integer(pa_dtype)
+                or pa.types.is_decimal(pa_dtype)
+            )
             and not pa_version_under7p0
         ):
             mark = pytest.mark.xfail(
                 raises=pa.ArrowInvalid,
                 reason="divide by 0",
             )
+        elif (
+            opname == "__pow__"
+            and pa.types.is_decimal(pa_dtype)
+            and pa_version_under7p0
+        ):
+            mark = pytest.mark.xfail(
+                raises=pa.ArrowInvalid,
+                reason="Invalid decimal function: power_checked",
+            )

         return mark
@@ -1297,6 +1393,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
     with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
         ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")

+    with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
+        ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")
+

 @pytest.mark.parametrize(
     "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
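As the added assertion shows, a parameterized alias like "decimal(7, 2)[pyarrow]" cannot be parsed by ArrowDtype.construct_from_string, while non-parameterized aliases can; the dtype has to be built from the pyarrow type object instead. A small sketch of that distinction (the decimal-string behaviour is exactly what the test above asserts):

    import pyarrow as pa

    import pandas as pd

    # Non-parameterized aliases round-trip through strings.
    print(pd.ArrowDtype.construct_from_string("int64[pyarrow]"))

    # A parameterized decimal alias raises, per the test above.
    try:
        pd.ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")
    except NotImplementedError as err:
        print(f"NotImplementedError: {err}")

    # Building the dtype from the pyarrow type object works.
    print(pd.ArrowDtype(pa.decimal128(7, 2)))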
@@ -1323,7 +1422,11 @@ def test_quantile(data, interpolation, quantile, request):
             ser.quantile(q=quantile, interpolation=interpolation)
         return

-    if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
+    if (
+        pa.types.is_integer(pa_dtype)
+        or pa.types.is_floating(pa_dtype)
+        or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0)
+    ):
         pass
     elif pa.types.is_temporal(data._data.type):
         pass
@@ -1364,7 +1467,11 @@ def test_quantile(data, interpolation, quantile, request):
     else:
         # Just check the values
         expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
-        if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
+        if (
+            pa.types.is_integer(pa_dtype)
+            or pa.types.is_floating(pa_dtype)
+            or pa.types.is_decimal(pa_dtype)
+        ):
             expected = expected.astype("float64[pyarrow]")
             result = result.astype("float64[pyarrow]")
     tm.assert_series_equal(result, expected)
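For reference, a sketch of the call this test exercises for decimal input. The version gates above suggest the decimal path is only expected to work on pyarrow >= 7, and the result dtype may not match the input decimal dtype exactly, which is why the test compares values after casting both sides to float64[pyarrow]:

    from decimal import Decimal

    import pyarrow as pa

    import pandas as pd

    ser = pd.Series(
        [Decimal("0.5"), Decimal("1.5"), None, Decimal("33.123")],
        dtype=pd.ArrowDtype(pa.decimal128(7, 3)),
    )

    # Scalar and list-like quantiles on a decimal-backed Series; cast to
    # float64[pyarrow] when comparing, mirroring the test above.
    print(ser.quantile(0.5, interpolation="linear"))
    print(ser.quantile([0.5, 0.5]).astype("float64[pyarrow]"))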
@@ -1385,6 +1492,13 @@ def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
                 reason=f"mode not supported by pyarrow for {pa_dtype}",
             )
         )
+    elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
+        request.node.add_marker(
+            pytest.mark.xfail(
+                raises=pa.ArrowNotImplementedError,
+                reason=f"mode not supported by pyarrow for {pa_dtype}",
+            )
+        )
     elif (
         pa.types.is_boolean(pa_dtype)
         and "multi_mode" in request.node.nodeid
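For completeness, the behaviour test_mode expects on a recent pyarrow, sketched below; on pyarrow < 7 the xfail added above expects ArrowNotImplementedError instead:

    from decimal import Decimal

    import pyarrow as pa

    import pandas as pd

    ser = pd.Series(
        [Decimal("1.1"), Decimal("1.1"), Decimal("2.2"), None],
        dtype=pd.ArrowDtype(pa.decimal128(7, 3)),
    )

    # With a sufficiently new pyarrow this should return the most frequent
    # value (1.1); older versions lack the decimal kernel and raise.
    print(ser.mode(dropna=True))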
Review comment: Looks like if you pass Decimal("nan") to pyarrow it raises. Is that going to be relevant for us?
Reply: Looks like we coerce to NA, which I guess is fuzzy with the NA vs nan debate.
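To make the exchange above concrete, a small sketch of the behaviour being discussed. The exact exception type and the coercion to NA are taken from the comments rather than verified here, so both calls are wrapped defensively:

    from decimal import Decimal

    import pyarrow as pa

    import pandas as pd

    # Per the discussion, handing a NaN Decimal straight to pyarrow raises.
    try:
        pa.array([Decimal("nan")], type=pa.decimal128(7, 3))
    except Exception as err:  # exact exception type depends on the pyarrow version
        print(f"pyarrow raised {type(err).__name__}: {err}")

    # pandas is said to coerce the value to NA when building the extension
    # array; this may differ by version, hence the guard.
    try:
        arr = pd.array(
            [Decimal("1.5"), Decimal("nan")], dtype=pd.ArrowDtype(pa.decimal128(7, 3))
        )
        print(arr)
    except Exception as err:
        print(f"pandas raised {type(err).__name__}: {err}")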