From 01f689eac6913226a7068dd3728a630777baedc9 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 24 Jan 2023 14:32:55 -0800 Subject: [PATCH 01/17] TST: Test ArrowExtensionArray with decimal types --- pandas/_testing/__init__.py | 2 + pandas/core/arrays/arrow/array.py | 1 + pandas/core/arrays/arrow/dtype.py | 2 +- pandas/tests/extension/test_arrow.py | 94 +++++++++++++++++++++++----- 4 files changed, 83 insertions(+), 16 deletions(-) diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index eb25566e7983e..a60831a83e825 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -202,6 +202,7 @@ # pa.float16 doesn't seem supported # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86 FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()] + DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)] STRING_PYARROW_DTYPES = [pa.string()] BINARY_PYARROW_DTYPES = [pa.binary()] @@ -226,6 +227,7 @@ ALL_PYARROW_DTYPES = ( ALL_INT_PYARROW_DTYPES + FLOAT_PYARROW_DTYPES + + DECIMAL_PYARROW_DTYPES + STRING_PYARROW_DTYPES + BINARY_PYARROW_DTYPES + TIME_PYARROW_DTYPES diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 0e70b3795bc85..4ecfdade7f589 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1031,6 +1031,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs): pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type) or pa.types.is_duration(pa_type) + or pa.types.is_decimal(pa_type) ): # pyarrow only supports any/all for boolean dtype, we allow # for other dtypes, matching our non-pyarrow behavior diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index f5f87bea83b8f..950381bd8d8a7 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -146,7 +146,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype: try: pa_dtype = pa.type_for_alias(base_type) except ValueError as err: - has_parameters = re.search(r"\[.*\]", base_type) + has_parameters = re.search(r"[\[\(].*[\]\)]", base_type) if has_parameters: raise NotImplementedError( "Passing pyarrow type specific parameters " diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 2467471e3643e..14972ed1834af 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -16,6 +16,7 @@ time, timedelta, ) +from decimal import Decimal from io import ( BytesIO, StringIO, @@ -77,6 +78,14 @@ def data(dtype): data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99] elif pa.types.is_unsigned_integer(pa_dtype): data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99] + elif pa.types.is_decimal(pa_dtype): + data = ( + [Decimal("1"), Decimal("0.0")] * 4 + + [None] + + [Decimal("-2.0"), Decimal("-1.0")] * 44 + + [None] + + [Decimal("0.5"), Decimal("33.123")] + ) elif pa.types.is_date(pa_dtype): data = ( [date(2022, 1, 1), date(1999, 12, 31)] * 4 @@ -186,6 +195,10 @@ def data_for_grouping(dtype): A = b"a" B = b"b" C = b"c" + elif pa.types.is_decimal(pa_dtype): + A = Decimal("-1.1") + B = Decimal("0.0") + C = Decimal("1.1") else: raise NotImplementedError return pd.array([B, B, None, None, A, A, B, C], dtype=dtype) @@ -248,8 +261,10 @@ def test_astype_str(self, data, request): class TestConstructors(base.BaseConstructorsTests): def test_from_dtype(self, data, request): pa_dtype = data.dtype.pyarrow_dtype - if (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz) or pa.types.is_string( - pa_dtype + if ( + (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz) + or pa.types.is_string(pa_dtype) + or pa.types.is_decimal(pa_dtype) ): if pa.types.is_string(pa_dtype): reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')" @@ -262,7 +277,7 @@ def test_from_dtype(self, data, request): ) super().test_from_dtype(data) - def test_from_sequence_pa_array(self, data, request): + def test_from_sequence_pa_array(self, data): # https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784 # data._data = pa.ChunkedArray result = type(data)._from_sequence(data._data) @@ -294,7 +309,7 @@ def test_from_sequence_of_strings_pa_array(self, data, request): reason="Nanosecond time parsing not supported.", ) ) - elif pa.types.is_duration(pa_dtype): + elif pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype): request.node.add_marker( pytest.mark.xfail( raises=pa.ArrowNotImplementedError, @@ -361,6 +376,8 @@ def test_getitem_scalar(self, data): exp_type = date elif pa.types.is_time(pa_type): exp_type = time + elif pa.types.is_decimal(pa_type): + exp_type = Decimal else: raise NotImplementedError(data.dtype) @@ -434,7 +451,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques raises=NotImplementedError, ) ) - elif all_numeric_accumulations == "cumsum" and (pa.types.is_boolean(pa_type)): + elif all_numeric_accumulations == "cumsum" and ( + pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type) + ): request.node.add_marker( pytest.mark.xfail( reason=f"{all_numeric_accumulations} not implemented for {pa_type}", @@ -655,7 +674,9 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping, request): class TestBaseDtype(base.BaseDtypeTests): def test_construct_from_string_own_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype - if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: + if ( + pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None + ) or pa.types.is_decimal(pa_dtype): request.node.add_marker( pytest.mark.xfail( raises=NotImplementedError, @@ -676,7 +697,9 @@ def test_construct_from_string_own_name(self, dtype, request): def test_is_dtype_from_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype - if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: + if ( + pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None + ) or pa.types.is_decimal(pa_dtype): request.node.add_marker( pytest.mark.xfail( raises=NotImplementedError, @@ -696,7 +719,9 @@ def test_is_dtype_from_name(self, dtype, request): def test_construct_from_string(self, dtype, request): pa_dtype = dtype.pyarrow_dtype - if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: + if ( + pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None + ) or pa.types.is_decimal(pa_dtype): request.node.add_marker( pytest.mark.xfail( raises=NotImplementedError, @@ -732,6 +757,7 @@ def test_get_common_dtype(self, dtype, request): or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns") or pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) + or pa.types.is_decimal(pa_dtype) ): request.node.add_marker( pytest.mark.xfail( @@ -798,11 +824,13 @@ def test_EA_types(self, engine, data, request): request.node.add_marker( pytest.mark.xfail(raises=TypeError, reason="GH 47534") ) - elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: + elif ( + pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None + ) or pa.types.is_decimal(pa_dtype): request.node.add_marker( pytest.mark.xfail( raises=NotImplementedError, - reason=f"Parameterized types with tz={pa_dtype.tz} not supported.", + reason=f"Parameterized types {pa_dtype} not supported.", ) ) elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"): @@ -1020,6 +1048,18 @@ class TestBaseArithmeticOps(base.BaseArithmeticOpsTests): divmod_exc = NotImplementedError + def assert_equal(cls, left, right, **kwargs): + if isinstance(left, pd.DataFrame): + left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype + right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype + else: + left_pa_type = left.dtype.pyarrow_dtype + right_pa_type = right.dtype.pyarrow_dtype + if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type): + left = left.astype("float[pyarrow]") + right = right.astype("float[pyarrow]") + tm.assert_equal(left, right, **kwargs) + def _patch_combine(self, obj, other, op): # BaseOpsUtil._combine can upcast expected dtype # (because it generates expected on python scalars) @@ -1060,7 +1100,11 @@ def _get_scalar_exception(self, opname, pa_dtype): exc = NotImplementedError elif arrow_temporal_supported: exc = None - elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): + elif not ( + pa.types.is_floating(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): exc = pa.ArrowNotImplementedError else: exc = None @@ -1073,7 +1117,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): if ( opname == "__rpow__" - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) + and ( + pa.types.is_floating(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ) and not pa_version_under6p0 ): mark = pytest.mark.xfail( @@ -1092,7 +1140,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): ) elif ( opname in {"__rtruediv__", "__rfloordiv__"} - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) + and ( + pa.types.is_floating(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ) and not pa_version_under6p0 ): mark = pytest.mark.xfail( @@ -1199,6 +1251,7 @@ def test_add_series_with_extension_array(self, data, request): if not ( pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype) + or pa.types.is_decimal(pa_dtype) or (not pa_version_under8p0 and pa.types.is_duration(pa_dtype)) ): request.node.add_marker( @@ -1276,6 +1329,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): with pytest.raises(NotImplementedError, match="Passing pyarrow type"): ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]") + with pytest.raises(NotImplementedError, match="Passing pyarrow type"): + ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]") + @pytest.mark.parametrize( "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"] @@ -1302,7 +1358,11 @@ def test_quantile(data, interpolation, quantile, request): ser.quantile(q=quantile, interpolation=interpolation) return - if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): + if ( + pa.types.is_integer(pa_dtype) + or pa.types.is_floating(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): pass elif pa.types.is_temporal(data._data.type) and interpolation in ["lower", "higher"]: pass @@ -1321,7 +1381,11 @@ def test_quantile(data, interpolation, quantile, request): else: # Just check the values expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5]) - if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): + if ( + pa.types.is_integer(pa_dtype) + or pa.types.is_floating(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): expected = expected.astype("float64[pyarrow]") result = result.astype("float64[pyarrow]") tm.assert_series_equal(result, expected) From 5c5e28bfbf6dece3dd5206ace7f824aadb1dba8e Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 24 Jan 2023 16:07:04 -0800 Subject: [PATCH 02/17] Version compat --- pandas/tests/extension/test_arrow.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 14972ed1834af..59cc4650b381b 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -542,6 +542,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request): and pa_version_under6p0 ): request.node.add_marker(xfail_mark) + elif ( + all_numeric_reductions in {"var", "std", "median"} + and pa_version_under7p0 + and pa.types.is_decimal(pa_dtype) + ): + request.node.add_marker(xfail_mark) elif all_numeric_reductions == "sem" and pa_version_under8p0: request.node.add_marker(xfail_mark) elif ( @@ -1361,7 +1367,7 @@ def test_quantile(data, interpolation, quantile, request): if ( pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype) - or pa.types.is_decimal(pa_dtype) + or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0) ): pass elif pa.types.is_temporal(data._data.type) and interpolation in ["lower", "higher"]: @@ -1411,6 +1417,13 @@ def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request): reason=f"mode not supported by pyarrow for {pa_dtype}", ) ) + elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0: + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"mode not supported by pyarrow for {pa_dtype}", + ) + ) elif ( pa.types.is_boolean(pa_dtype) and "multi_mode" in request.node.nodeid From c2f108d65da0fe0a2730a6202a11a6a21223d28d Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 25 Jan 2023 13:32:13 -0800 Subject: [PATCH 03/17] Add other xfails based on min version --- pandas/tests/extension/test_arrow.py | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index bb5e935db6403..283fbda4f02a6 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -678,6 +678,17 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping, request): class TestBaseDtype(base.BaseDtypeTests): + def test_check_dtype(self, data, request): + pa_dtype = dtype.pyarrow_dtype + if pa.types.is_decimal(pa_dtype) and pa_version_under7p0: + request.node.add_marker( + pytest.mark.xfail( + raises=TypeError, + reason="decimal string repr affects numpy comparison", + ) + ) + super().test_construct_from_string_own_name(dtype) + def test_construct_from_string_own_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype if ( @@ -926,6 +937,13 @@ def test_argmin_argmax( reason=f"{pa_dtype} only has 2 unique possible values", ) ) + elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0: + request.node.add_marker( + pytest.mark.xfail( + reason=f"No pyarrow kernel for {pa_dtype}", + raises=pa.ArrowNotImplementedError, + ) + ) super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value) @pytest.mark.parametrize( @@ -944,6 +962,7 @@ def test_argmin_argmax( def test_argreduce_series( self, data_missing_for_sorting, op_name, skipna, expected, request ): + pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype if pa_version_under6p0 and skipna: request.node.add_marker( pytest.mark.xfail( @@ -951,6 +970,13 @@ def test_argreduce_series( reason="min_max not supported in pyarrow", ) ) + elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0: + request.node.add_marker( + pytest.mark.xfail( + reason=f"No pyarrow kernel for {pa_dtype}", + raises=pa.ArrowNotImplementedError, + ) + ) super().test_argreduce_series( data_missing_for_sorting, op_name, skipna, expected ) @@ -1050,6 +1076,7 @@ class TestBaseArithmeticOps(base.BaseArithmeticOpsTests): divmod_exc = NotImplementedError + @classmethod def assert_equal(cls, left, right, **kwargs): if isinstance(left, pd.DataFrame): left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype @@ -1153,6 +1180,15 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): raises=pa.ArrowInvalid, reason="divide by 0", ) + elif ( + opname == "__pow__" + and pa.types.is_decimal(pa_dtype) + and pa_version_under7p0 + ): + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason="Invalid decimal function: power_checked", + ) return mark From 60ccab11bbb71e95047e03f9d6915d4497031736 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 25 Jan 2023 16:30:55 -0800 Subject: [PATCH 04/17] fix test --- pandas/tests/extension/test_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 283fbda4f02a6..c803b2a84b145 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -679,7 +679,7 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping, request): class TestBaseDtype(base.BaseDtypeTests): def test_check_dtype(self, data, request): - pa_dtype = dtype.pyarrow_dtype + pa_dtype = data.dtype.pyarrow_dtype if pa.types.is_decimal(pa_dtype) and pa_version_under7p0: request.node.add_marker( pytest.mark.xfail( From d63b1ea0aded56ca51075c9c43abf1353442204c Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 25 Jan 2023 18:30:28 -0800 Subject: [PATCH 05/17] fix typo --- pandas/tests/extension/test_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 342b834f2f12e..db3c0e959595e 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -706,7 +706,7 @@ def test_check_dtype(self, data, request): reason="decimal string repr affects numpy comparison", ) ) - super().test_construct_from_string_own_name(dtype) + super().test_check_dtype(dtype) def test_construct_from_string_own_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype From f944e849b7d24aeb612211c012cfe0da93bf743b Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 26 Jan 2023 10:05:24 -0800 Subject: [PATCH 06/17] another typo --- pandas/tests/extension/test_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index db3c0e959595e..a59ffe33004f7 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -706,7 +706,7 @@ def test_check_dtype(self, data, request): reason="decimal string repr affects numpy comparison", ) ) - super().test_check_dtype(dtype) + super().test_check_dtype(data) def test_construct_from_string_own_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype From 7f42ad30ea2bb7f3ca69a8dea4bb7e3a91451cc0 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 27 Jan 2023 17:10:59 -0800 Subject: [PATCH 07/17] only for skipna --- pandas/tests/extension/test_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 803a8fee43a10..8d86be28fbaed 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -991,7 +991,7 @@ def test_argreduce_series( reason="min_max not supported in pyarrow", ) ) - elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0: + elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna: request.node.add_marker( pytest.mark.xfail( reason=f"No pyarrow kernel for {pa_dtype}", From a16ee3b765ce6f9dcac7b0318f04b70a33f2078f Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 1 Feb 2023 11:33:37 -0800 Subject: [PATCH 08/17] Add comment --- pandas/tests/extension/test_arrow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 130a6d05dfad1..34a71b47118ef 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1118,6 +1118,8 @@ def assert_equal(cls, left, right, **kwargs): left_pa_type = left.dtype.pyarrow_dtype right_pa_type = right.dtype.pyarrow_dtype if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type): + # decimal precision can resize in the result type depending on data + # just compare the float values left = left.astype("float[pyarrow]") right = right.astype("float[pyarrow]") tm.assert_equal(left, right, **kwargs) From 129ac08149f7cfae05670d1bcd4d53f9a3c58b94 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 2 Feb 2023 10:48:22 -0800 Subject: [PATCH 09/17] Fix --- pandas/tests/extension/test_arrow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 0fcf39184eda2..8ff5a1751d8d0 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1377,7 +1377,6 @@ def test_add_series_with_extension_array(self, data, request): pa.types.is_binary(pa_dtype) or pa.types.is_string(pa_dtype) or pa.types.is_boolean(pa_dtype) - or pa.types.is_decimal(pa_dtype) ): request.node.add_marker( pytest.mark.xfail( From c44f260a7a1bb556b504e669dd032c0675534e24 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 7 Feb 2023 12:49:50 -0800 Subject: [PATCH 10/17] undo comments --- pandas/tests/extension/test_arrow.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index aafa7b798a5ad..45c5a987f493e 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -529,17 +529,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request): ) if all_numeric_reductions in {"skew", "kurt"}: request.node.add_marker(xfail_mark) - # elif ( - # all_numeric_reductions in {"median", "var", "std", "prod", "max", "min"} - # and pa_version_under6p0 - # ): - # request.node.add_marker(xfail_mark) - # elif ( - # all_numeric_reductions in {"var", "std", "median"} - # and pa_version_under7p0 - # and pa.types.is_decimal(pa_dtype) - # ): - # request.node.add_marker(xfail_mark) + elif ( + all_numeric_reductions in {"var", "std", "median"} + and pa_version_under7p0 + and pa.types.is_decimal(pa_dtype) + ): + request.node.add_marker(xfail_mark) elif all_numeric_reductions == "sem" and pa_version_under8p0: request.node.add_marker(xfail_mark) From 406bb697fcc3fb29eabe435e2f46f95a084187de Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 8 Feb 2023 11:40:23 -0800 Subject: [PATCH 11/17] Bump version condition --- pandas/tests/extension/test_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 5e3322ade4e17..9a58dbe8f0b42 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -666,7 +666,7 @@ def test_in_numeric_groupby(self, data_for_grouping): class TestBaseDtype(base.BaseDtypeTests): def test_check_dtype(self, data, request): pa_dtype = data.dtype.pyarrow_dtype - if pa.types.is_decimal(pa_dtype) and pa_version_under7p0: + if pa.types.is_decimal(pa_dtype) and pa_version_under8p0: request.node.add_marker( pytest.mark.xfail( raises=ValueError, From dfe4bb0ca2a430f4c495e4d9049fdc709c1ac1c8 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 17 Feb 2023 14:31:31 -0800 Subject: [PATCH 12/17] Skip masked indexing engine for decimal --- pandas/core/indexes/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index bd631c0c0d948..acebe8a498f03 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -810,7 +810,11 @@ def _engine( target_values = self._get_engine_target() if isinstance(target_values, ExtensionArray): if isinstance(target_values, (BaseMaskedArray, ArrowExtensionArray)): - return _masked_engines[target_values.dtype.name](target_values) + try: + return _masked_engines[target_values.dtype.name](target_values) + except KeyError: + # Not supported yet e.g. decimal + pass elif self._engine_type is libindex.ObjectEngine: return libindex.ExtensionEngine(target_values) @@ -4948,6 +4952,8 @@ def _get_engine_target(self) -> ArrayLike: and not ( isinstance(self._values, ArrowExtensionArray) and is_numeric_dtype(self.dtype) + # Exclude decimal + and self.dtype.kind != "O" ) ): # TODO(ExtensionIndex): remove special-case, just use self._values From ff53f6d6c2e23232a348067636c849ce52d029bb Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 21 Feb 2023 11:05:59 -0800 Subject: [PATCH 13/17] Some merge stuff --- pandas/tests/extension/test_arrow.py | 51 +++++++--------------------- 1 file changed, 12 insertions(+), 39 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index b14ea35931c24..82e5d8682c8f4 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -636,16 +636,6 @@ def test_construct_from_string_own_name(self, dtype, request): reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", ) ) - elif pa.types.is_string(pa_dtype): - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - "Still support StringDtype('pyarrow') " - "over ArrowDtype(pa.string())" - ), - ) - ) if pa.types.is_string(pa_dtype): # We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) @@ -659,45 +649,28 @@ def test_construct_from_string_own_name(self, dtype, request): def test_is_dtype_from_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype - if pa.types.is_decimal(pa_dtype): - request.node.add_marker( - pytest.mark.xfail( - raises=NotImplementedError, - reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", - ) - ) - elif pa.types.is_string(pa_dtype): - request.node.add_marker( - pytest.mark.xfail( - reason=( - "Still support StringDtype('pyarrow') " - "over ArrowDtype(pa.string())" - ), + if pa.types.is_string(pa_dtype): + # We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) + assert not type(dtype).is_dtype(dtype.name) + else: + if pa.types.is_decimal(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", + ) ) - ) - super().test_is_dtype_from_name(dtype) + super().test_is_dtype_from_name(dtype) def test_construct_from_string(self, dtype, request): pa_dtype = dtype.pyarrow_dtype - if ( - pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None - ) or pa.types.is_decimal(pa_dtype): + if pa.types.is_decimal(pa_dtype): request.node.add_marker( pytest.mark.xfail( raises=NotImplementedError, reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", ) ) - elif pa.types.is_string(pa_dtype): - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - "Still support StringDtype('pyarrow') " - "over ArrowDtype(pa.string())" - ), - ) - ) super().test_construct_from_string(dtype) def test_construct_from_string_another_type_raises(self, dtype): From 611439170a1c1c405070c30ea442cffc6aa87bb9 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 21 Feb 2023 17:54:03 -0800 Subject: [PATCH 14/17] Remove imaginary test --- pandas/tests/extension/test_arrow.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 82e5d8682c8f4..12d2bf996b2a1 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -662,17 +662,6 @@ def test_is_dtype_from_name(self, dtype, request): ) super().test_is_dtype_from_name(dtype) - def test_construct_from_string(self, dtype, request): - pa_dtype = dtype.pyarrow_dtype - if pa.types.is_decimal(pa_dtype): - request.node.add_marker( - pytest.mark.xfail( - raises=NotImplementedError, - reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", - ) - ) - super().test_construct_from_string(dtype) - def test_construct_from_string_another_type_raises(self, dtype): msg = r"'another_type' must end with '\[pyarrow\]'" with pytest.raises(TypeError, match=msg): From 4f8fadc08634e32a7af4f98fb3c4dcca1df28794 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 21 Feb 2023 19:12:51 -0800 Subject: [PATCH 15/17] Fix condition --- pandas/tests/extension/test_arrow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 12d2bf996b2a1..5888a31e07806 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -301,9 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request): reason="Nanosecond time parsing not supported.", ) ) - elif ( - pa_version_under11p0 and pa.types.is_duration(pa_dtype) - ) or pa.types.is_decimal(pa_dtype): + elif pa_version_under11p0 and ( + pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype) + ): request.node.add_marker( pytest.mark.xfail( raises=pa.ArrowNotImplementedError, From deaa1db01dfc96bb2e226af466f8f68a5e2654f0 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 21 Feb 2023 19:16:25 -0800 Subject: [PATCH 16/17] Fix another test --- pandas/tests/indexes/test_common.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index 40440bd8e0ff8..d8e17e48a19b3 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -449,12 +449,7 @@ def test_hasnans_isnans(self, index_flat): @pytest.mark.parametrize("na_position", [None, "middle"]) def test_sort_values_invalid_na_position(index_with_missing, na_position): with pytest.raises(ValueError, match=f"invalid na_position: {na_position}"): - with tm.maybe_produces_warning( - PerformanceWarning, - getattr(index_with_missing.dtype, "storage", "") == "pyarrow", - check_stacklevel=False, - ): - index_with_missing.sort_values(na_position=na_position) + index_with_missing.sort_values(na_position=na_position) @pytest.mark.parametrize("na_position", ["first", "last"]) From 142c7719bd987c0c2649cd5c818ab61a7bc5d236 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 22 Feb 2023 09:24:41 -0800 Subject: [PATCH 17/17] Update condition --- pandas/tests/extension/test_arrow.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index a931e568a8843..8ccf63541658c 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1095,8 +1095,17 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): ), ) elif ( - opname in {"__rfloordiv__"} - and pa.types.is_integer(pa_dtype) + opname == "__rfloordiv__" + and (pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype)) + and not pa_version_under7p0 + ): + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason="divide by 0", + ) + elif ( + opname == "__rtruediv__" + and pa.types.is_decimal(pa_dtype) and not pa_version_under7p0 ): mark = pytest.mark.xfail(