diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index 00e949b1dd318..69ca809e4f498 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -215,6 +215,7 @@ FLOAT_PYARROW_DTYPES_STR_REPR = [ str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES ] + DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)] STRING_PYARROW_DTYPES = [pa.string()] BINARY_PYARROW_DTYPES = [pa.binary()] @@ -239,6 +240,7 @@ ALL_PYARROW_DTYPES = ( ALL_INT_PYARROW_DTYPES + FLOAT_PYARROW_DTYPES + + DECIMAL_PYARROW_DTYPES + STRING_PYARROW_DTYPES + BINARY_PYARROW_DTYPES + TIME_PYARROW_DTYPES diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index ab4f87d4cf374..5c09a7a28856f 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1098,6 +1098,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs): pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type) or pa.types.is_duration(pa_type) + or pa.types.is_decimal(pa_type) ): # pyarrow only supports any/all for boolean dtype, we allow # for other dtypes, matching our non-pyarrow behavior diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index fdb9ac877831b..e6c1b2107e7f6 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -201,7 +201,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype: try: pa_dtype = pa.type_for_alias(base_type) except ValueError as err: - has_parameters = re.search(r"\[.*\]", base_type) + has_parameters = re.search(r"[\[\(].*[\]\)]", base_type) if has_parameters: # Fallback to try common temporal types try: diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index bd631c0c0d948..acebe8a498f03 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -810,7 +810,11 @@ def _engine( target_values = self._get_engine_target() if isinstance(target_values, ExtensionArray): if isinstance(target_values, (BaseMaskedArray, ArrowExtensionArray)): - return _masked_engines[target_values.dtype.name](target_values) + try: + return _masked_engines[target_values.dtype.name](target_values) + except KeyError: + # Not supported yet e.g. decimal + pass elif self._engine_type is libindex.ObjectEngine: return libindex.ExtensionEngine(target_values) @@ -4948,6 +4952,8 @@ def _get_engine_target(self) -> ArrayLike: and not ( isinstance(self._values, ArrowExtensionArray) and is_numeric_dtype(self.dtype) + # Exclude decimal + and self.dtype.kind != "O" ) ): # TODO(ExtensionIndex): remove special-case, just use self._values diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 75e5744f498b4..8ccf63541658c 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -16,6 +16,7 @@ time, timedelta, ) +from decimal import Decimal from io import ( BytesIO, StringIO, @@ -79,6 +80,14 @@ def data(dtype): data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99] elif pa.types.is_unsigned_integer(pa_dtype): data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99] + elif pa.types.is_decimal(pa_dtype): + data = ( + [Decimal("1"), Decimal("0.0")] * 4 + + [None] + + [Decimal("-2.0"), Decimal("-1.0")] * 44 + + [None] + + [Decimal("0.5"), Decimal("33.123")] + ) elif pa.types.is_date(pa_dtype): data = ( [date(2022, 1, 1), date(1999, 12, 31)] * 4 @@ -188,6 +197,10 @@ def data_for_grouping(dtype): A = b"a" B = b"b" C = b"c" + elif pa.types.is_decimal(pa_dtype): + A = Decimal("-1.1") + B = Decimal("0.0") + C = Decimal("1.1") else: raise NotImplementedError return pd.array([B, B, None, None, A, A, B, C], dtype=dtype) @@ -250,9 +263,12 @@ def test_astype_str(self, data, request): class TestConstructors(base.BaseConstructorsTests): def test_from_dtype(self, data, request): pa_dtype = data.dtype.pyarrow_dtype + if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype): + if pa.types.is_string(pa_dtype): + reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')" + else: + reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}" - if pa.types.is_string(pa_dtype): - reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')" request.node.add_marker( pytest.mark.xfail( reason=reason, @@ -260,7 +276,7 @@ def test_from_dtype(self, data, request): ) super().test_from_dtype(data) - def test_from_sequence_pa_array(self, data, request): + def test_from_sequence_pa_array(self, data): # https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784 # data._data = pa.ChunkedArray result = type(data)._from_sequence(data._data) @@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request): reason="Nanosecond time parsing not supported.", ) ) - elif pa_version_under11p0 and pa.types.is_duration(pa_dtype): + elif pa_version_under11p0 and ( + pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype) + ): request.node.add_marker( pytest.mark.xfail( raises=pa.ArrowNotImplementedError, @@ -384,7 +402,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques # renders the exception messages even when not showing them pytest.skip(f"{all_numeric_accumulations} not implemented for pyarrow < 9") - elif all_numeric_accumulations == "cumsum" and pa.types.is_boolean(pa_type): + elif all_numeric_accumulations == "cumsum" and ( + pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type) + ): request.node.add_marker( pytest.mark.xfail( reason=f"{all_numeric_accumulations} not implemented for {pa_type}", @@ -468,6 +488,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request): ) if all_numeric_reductions in {"skew", "kurt"}: request.node.add_marker(xfail_mark) + elif ( + all_numeric_reductions in {"var", "std", "median"} + and pa_version_under7p0 + and pa.types.is_decimal(pa_dtype) + ): + request.node.add_marker(xfail_mark) elif all_numeric_reductions == "sem" and pa_version_under8p0: request.node.add_marker(xfail_mark) @@ -590,8 +616,26 @@ def test_in_numeric_groupby(self, data_for_grouping): class TestBaseDtype(base.BaseDtypeTests): + def test_check_dtype(self, data, request): + pa_dtype = data.dtype.pyarrow_dtype + if pa.types.is_decimal(pa_dtype) and pa_version_under8p0: + request.node.add_marker( + pytest.mark.xfail( + raises=ValueError, + reason="decimal string repr affects numpy comparison", + ) + ) + super().test_check_dtype(data) + def test_construct_from_string_own_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype + if pa.types.is_decimal(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", + ) + ) if pa.types.is_string(pa_dtype): # We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) @@ -609,6 +653,13 @@ def test_is_dtype_from_name(self, dtype, request): # We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) assert not type(dtype).is_dtype(dtype.name) else: + if pa.types.is_decimal(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", + ) + ) super().test_is_dtype_from_name(dtype) def test_construct_from_string_another_type_raises(self, dtype): @@ -627,6 +678,7 @@ def test_get_common_dtype(self, dtype, request): ) or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns") or pa.types.is_binary(pa_dtype) + or pa.types.is_decimal(pa_dtype) ): request.node.add_marker( pytest.mark.xfail( @@ -700,6 +752,13 @@ def test_EA_types(self, engine, data, request): request.node.add_marker( pytest.mark.xfail(raises=TypeError, reason="GH 47534") ) + elif pa.types.is_decimal(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + raises=NotImplementedError, + reason=f"Parameterized types {pa_dtype} not supported.", + ) + ) elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"): request.node.add_marker( pytest.mark.xfail( @@ -782,6 +841,13 @@ def test_argmin_argmax( reason=f"{pa_dtype} only has 2 unique possible values", ) ) + elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0: + request.node.add_marker( + pytest.mark.xfail( + reason=f"No pyarrow kernel for {pa_dtype}", + raises=pa.ArrowNotImplementedError, + ) + ) super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value) @pytest.mark.parametrize( @@ -800,6 +866,14 @@ def test_argmin_argmax( def test_argreduce_series( self, data_missing_for_sorting, op_name, skipna, expected, request ): + pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype + if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna: + request.node.add_marker( + pytest.mark.xfail( + reason=f"No pyarrow kernel for {pa_dtype}", + raises=pa.ArrowNotImplementedError, + ) + ) super().test_argreduce_series( data_missing_for_sorting, op_name, skipna, expected ) @@ -888,6 +962,21 @@ def test_basic_equals(self, data): class TestBaseArithmeticOps(base.BaseArithmeticOpsTests): divmod_exc = NotImplementedError + @classmethod + def assert_equal(cls, left, right, **kwargs): + if isinstance(left, pd.DataFrame): + left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype + right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype + else: + left_pa_type = left.dtype.pyarrow_dtype + right_pa_type = right.dtype.pyarrow_dtype + if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type): + # decimal precision can resize in the result type depending on data + # just compare the float values + left = left.astype("float[pyarrow]") + right = right.astype("float[pyarrow]") + tm.assert_equal(left, right, **kwargs) + def get_op_from_name(self, op_name): short_opname = op_name.strip("_") if short_opname == "rtruediv": @@ -967,7 +1056,11 @@ def _get_scalar_exception(self, opname, pa_dtype): pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) ): exc = None - elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): + elif not ( + pa.types.is_floating(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): exc = pa.ArrowNotImplementedError else: exc = None @@ -980,7 +1073,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): if ( opname == "__rpow__" - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) + and ( + pa.types.is_floating(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ) and not pa_version_under7p0 ): mark = pytest.mark.xfail( @@ -998,14 +1095,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): ), ) elif ( - opname in {"__rfloordiv__"} - and pa.types.is_integer(pa_dtype) + opname == "__rfloordiv__" + and (pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype)) and not pa_version_under7p0 ): mark = pytest.mark.xfail( raises=pa.ArrowInvalid, reason="divide by 0", ) + elif ( + opname == "__rtruediv__" + and pa.types.is_decimal(pa_dtype) + and not pa_version_under7p0 + ): + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason="divide by 0", + ) + elif ( + opname == "__pow__" + and pa.types.is_decimal(pa_dtype) + and pa_version_under7p0 + ): + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason="Invalid decimal function: power_checked", + ) return mark @@ -1226,6 +1341,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): expected = ArrowDtype(pa.timestamp("s", "UTC")) assert dtype == expected + with pytest.raises(NotImplementedError, match="Passing pyarrow type"): + ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]") + @pytest.mark.parametrize( "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"] @@ -1252,7 +1370,11 @@ def test_quantile(data, interpolation, quantile, request): ser.quantile(q=quantile, interpolation=interpolation) return - if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): + if ( + pa.types.is_integer(pa_dtype) + or pa.types.is_floating(pa_dtype) + or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0) + ): pass elif pa.types.is_temporal(data._data.type): pass @@ -1293,7 +1415,11 @@ def test_quantile(data, interpolation, quantile, request): else: # Just check the values expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5]) - if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): + if ( + pa.types.is_integer(pa_dtype) + or pa.types.is_floating(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): expected = expected.astype("float64[pyarrow]") result = result.astype("float64[pyarrow]") tm.assert_series_equal(result, expected) diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index 40440bd8e0ff8..d8e17e48a19b3 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -449,12 +449,7 @@ def test_hasnans_isnans(self, index_flat): @pytest.mark.parametrize("na_position", [None, "middle"]) def test_sort_values_invalid_na_position(index_with_missing, na_position): with pytest.raises(ValueError, match=f"invalid na_position: {na_position}"): - with tm.maybe_produces_warning( - PerformanceWarning, - getattr(index_with_missing.dtype, "storage", "") == "pyarrow", - check_stacklevel=False, - ): - index_with_missing.sort_values(na_position=na_position) + index_with_missing.sort_values(na_position=na_position) @pytest.mark.parametrize("na_position", ["first", "last"])