From 187023a59f49440f39c3e79b81dc149ad5705648 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 11 Jan 2023 17:21:06 -0800 Subject: [PATCH 1/5] TST: simplify pyarrow tests, make mode work with temporal dtypes --- pandas/core/arrays/arrow/array.py | 9 + pandas/tests/extension/test_arrow.py | 280 ++++++++++++--------------- 2 files changed, 132 insertions(+), 157 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index de85ed67e7e8c..3f3a6f1d5f267 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1209,6 +1209,15 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra """ if pa_version_under6p0: raise NotImplementedError("mode only supported for pyarrow version >= 6.0") + + pa_type = self._data.type + if pa.types.is_temporal(pa_type): + nbits = pa_type.bit_width + dtype = f"int{nbits}[pyarrow]" + obj = self.astype(dtype, copy=False) + result = obj._mode(dropna=dropna) + return result.astype(self.dtype, copy=False) + modes = pc.mode(self._data, pc.count_distinct(self._data).as_py()) values = modes.field(0) counts = modes.field(1) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 78c49ae066288..b389ed0b59f82 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -333,15 +333,39 @@ def test_from_sequence_of_strings_pa_array(self, data, request): class TestGetitemTests(base.BaseGetitemTests): - @pytest.mark.xfail( - reason=( - "data.dtype.type return pyarrow.DataType " - "but this (intentionally) returns " - "Python scalars or pd.NA" - ) - ) def test_getitem_scalar(self, data): - super().test_getitem_scalar(data) + # In the base class we expect data.dtype.type; but this (intentionally) + # returns Python scalars or pd.NA + pa_type = data._data.type + if pa.types.is_integer(pa_type): + exp_type = int + elif pa.types.is_floating(pa_type): + exp_type = float + elif pa.types.is_string(pa_type): + exp_type = str + elif pa.types.is_binary(pa_type): + exp_type = bytes + elif pa.types.is_boolean(pa_type): + exp_type = bool + elif pa.types.is_duration(pa_type): + exp_type = timedelta + elif pa.types.is_timestamp(pa_type): + if pa_type.unit == "ns": + exp_type = pd.Timestamp + else: + exp_type = datetime + elif pa.types.is_date(pa_type): + exp_type = date + elif pa.types.is_time(pa_type): + exp_type = time + else: + raise NotImplementedError(data.dtype) + + result = data[0] + assert isinstance(result, exp_type), type(result) + + result = pd.Series(data)[0] + assert isinstance(result, exp_type), type(result) class TestBaseAccumulateTests(base.BaseAccumulateTests): @@ -1054,70 +1078,83 @@ def _patch_combine(self, obj, other, op): expected = pd.Series(pd_array) return expected - def test_arith_series_with_scalar( - self, data, all_arithmetic_operators, request, monkeypatch - ): - pa_dtype = data.dtype.pyarrow_dtype - - arrow_temporal_supported = not pa_version_under8p0 and ( - all_arithmetic_operators in ("__add__", "__radd__") + def _is_temporal_supported(self, opname, pa_dtype): + return not pa_version_under8p0 and ( + opname in ("__add__", "__radd__") and pa.types.is_duration(pa_dtype) - or all_arithmetic_operators in ("__sub__", "__rsub__") + or opname in ("__sub__", "__rsub__") and pa.types.is_temporal(pa_dtype) ) - if all_arithmetic_operators == "__rmod__" and ( - pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) - ): - pytest.skip("Skip testing Python string formatting") - elif all_arithmetic_operators in { + + def _get_scalar_exception(self, opname, pa_dtype): + arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype) + if opname in { "__mod__", "__rmod__", }: - self.series_scalar_exc = NotImplementedError + exc = NotImplementedError elif arrow_temporal_supported: - self.series_scalar_exc = None - elif not ( - pa.types.is_floating(pa_dtype) - or pa.types.is_integer(pa_dtype) - or arrow_temporal_supported - ): - self.series_scalar_exc = pa.ArrowNotImplementedError + exc = None + elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): + exc = pa.ArrowNotImplementedError else: - self.series_scalar_exc = None + exc = None + return exc + + def _get_arith_xfail_marker(self, opname, pa_dtype): + mark = None + + arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype) + if ( - all_arithmetic_operators == "__rpow__" + opname == "__rpow__" and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) and not pa_version_under6p0 ): - request.node.add_marker( - pytest.mark.xfail( - reason=( - f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " - f"for {pa_dtype}" - ) + mark = pytest.mark.xfail( + reason=( + f"GH#29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " + f"for {pa_dtype}" ) ) elif arrow_temporal_supported: - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - f"{all_arithmetic_operators} not supported between" - f"pd.NA and {pa_dtype} Python scalar" - ), - ) + mark = pytest.mark.xfail( + raises=TypeError, + reason=( + f"{opname} not supported between" + f"pd.NA and {pa_dtype} Python scalar" + ), ) elif ( - all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} + opname in {"__rtruediv__", "__rfloordiv__"} and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) and not pa_version_under6p0 ): - request.node.add_marker( - pytest.mark.xfail( - raises=pa.ArrowInvalid, - reason="divide by 0", - ) + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason="divide by 0", ) + + return mark + + def test_arith_series_with_scalar( + self, data, all_arithmetic_operators, request, monkeypatch + ): + pa_dtype = data.dtype.pyarrow_dtype + + if all_arithmetic_operators == "__rmod__" and ( + pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) + ): + pytest.skip("Skip testing Python string formatting") + + self.series_scalar_exc = self._get_scalar_exception( + all_arithmetic_operators, pa_dtype + ) + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.node.add_marker(mark) + if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype): # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does # not upcast @@ -1129,61 +1166,19 @@ def test_arith_frame_with_scalar( ): pa_dtype = data.dtype.pyarrow_dtype - arrow_temporal_supported = not pa_version_under8p0 and ( - all_arithmetic_operators in ("__add__", "__radd__") - and pa.types.is_duration(pa_dtype) - or all_arithmetic_operators in ("__sub__", "__rsub__") - and pa.types.is_temporal(pa_dtype) - ) if all_arithmetic_operators == "__rmod__" and ( pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) ): pytest.skip("Skip testing Python string formatting") - elif all_arithmetic_operators in { - "__mod__", - "__rmod__", - }: - self.frame_scalar_exc = NotImplementedError - elif arrow_temporal_supported: - self.frame_scalar_exc = None - elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): - self.frame_scalar_exc = pa.ArrowNotImplementedError - else: - self.frame_scalar_exc = None - if ( - all_arithmetic_operators == "__rpow__" - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - reason=( - f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " - f"for {pa_dtype}" - ) - ) - ) - elif arrow_temporal_supported: - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - f"{all_arithmetic_operators} not supported between" - f"pd.NA and {pa_dtype} Python scalar" - ), - ) - ) - elif ( - all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - raises=pa.ArrowInvalid, - reason="divide by 0", - ) - ) + + self.frame_scalar_exc = self._get_scalar_exception( + all_arithmetic_operators, pa_dtype + ) + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.node.add_marker(mark) + if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype): # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does # not upcast @@ -1195,37 +1190,11 @@ def test_arith_series_with_array( ): pa_dtype = data.dtype.pyarrow_dtype - arrow_temporal_supported = not pa_version_under8p0 and ( - all_arithmetic_operators in ("__add__", "__radd__") - and pa.types.is_duration(pa_dtype) - or all_arithmetic_operators in ("__sub__", "__rsub__") - and pa.types.is_temporal(pa_dtype) + self.series_array_exc = self._get_scalar_exception( + all_arithmetic_operators, pa_dtype ) - if all_arithmetic_operators in { - "__mod__", - "__rmod__", - }: - self.series_array_exc = NotImplementedError - elif arrow_temporal_supported: - self.series_array_exc = None - elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): - self.series_array_exc = pa.ArrowNotImplementedError - else: - self.series_array_exc = None + if ( - all_arithmetic_operators == "__rpow__" - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - reason=( - f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " - f"for {pa_dtype}" - ) - ) - ) - elif ( all_arithmetic_operators in ( "__sub__", @@ -1243,32 +1212,17 @@ def test_arith_series_with_array( ), ) ) - elif arrow_temporal_supported: - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - f"{all_arithmetic_operators} not supported between" - f"pd.NA and {pa_dtype} Python scalar" - ), - ) - ) - elif ( - all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - raises=pa.ArrowInvalid, - reason="divide by 0", - ) - ) + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.node.add_marker(mark) + op_name = all_arithmetic_operators ser = pd.Series(data) # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray # since ser.iloc[0] is a python scalar other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype)) + if pa.types.is_floating(pa_dtype) or ( pa.types.is_integer(pa_dtype) and all_arithmetic_operators != "__truediv__" ): @@ -1364,6 +1318,22 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): @pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]]) def test_quantile(data, interpolation, quantile, request): pa_dtype = data.dtype.pyarrow_dtype + + data = data.take([0, 0, 0]) + ser = pd.Series(data) + + if ( + pa.types.is_string(pa_dtype) + or pa.types.is_binary(pa_dtype) + or pa.types.is_boolean(pa_dtype) + ): + # For string, bytes, and bool, we don't *expect* to have quantile work + # Note this matches the non-pyarrow behavior + msg = r"Function 'quantile' has no kernel matching input types \(.*\)" + with pytest.raises(pa.ArrowNotImplementedError, match=msg): + ser.quantile(q=quantile, interpolation=interpolation) + return + if not (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)): request.node.add_marker( pytest.mark.xfail( @@ -1398,11 +1368,7 @@ def test_quantile(data, interpolation, quantile, request): ) def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request): pa_dtype = data_for_grouping.dtype.pyarrow_dtype - if ( - pa.types.is_temporal(pa_dtype) - or pa.types.is_string(pa_dtype) - or pa.types.is_binary(pa_dtype) - ): + if pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype): request.node.add_marker( pytest.mark.xfail( raises=pa.ArrowNotImplementedError, From fa233a740526b3aa587af43b0e457e763a6a342d Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 11 Jan 2023 19:14:15 -0800 Subject: [PATCH 2/5] mypy, min_version fixups --- pandas/core/arrays/arrow/array.py | 5 +++-- pandas/tests/extension/test_arrow.py | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 3f3a6f1d5f267..84fa0dd3cfaca 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1214,9 +1214,10 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra if pa.types.is_temporal(pa_type): nbits = pa_type.bit_width dtype = f"int{nbits}[pyarrow]" - obj = self.astype(dtype, copy=False) + obj = cast(ArrowExtensionArrayT, self.astype(dtype, copy=False)) result = obj._mode(dropna=dropna) - return result.astype(self.dtype, copy=False) + out = result.astype(self.dtype, copy=False) + return cast(ArrowExtensionArrayT, out) modes = pc.mode(self._data, pc.count_distinct(self._data).as_py()) values = modes.field(0) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index b389ed0b59f82..61c86759251e2 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1329,7 +1329,10 @@ def test_quantile(data, interpolation, quantile, request): ): # For string, bytes, and bool, we don't *expect* to have quantile work # Note this matches the non-pyarrow behavior - msg = r"Function 'quantile' has no kernel matching input types \(.*\)" + if pa_version_under7p0: + msg = r"Function quantile has no kernel matching input types \(.*\)" + else: + msg = r"Function 'quantile' has no kernel matching input types \(.*\)" with pytest.raises(pa.ArrowNotImplementedError, match=msg): ser.quantile(q=quantile, interpolation=interpolation) return From 72bed6f587a587bac5042ca345a1809995d319f9 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 14 Jan 2023 13:08:51 -0800 Subject: [PATCH 3/5] use pa cast --- pandas/core/arrays/arrow/array.py | 46 +++++++++++++++++++------------ 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index f6b5e687aa225..3027f2866c52e 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -657,12 +657,11 @@ def factorize( pa_type = self._data.type if pa.types.is_duration(pa_type): # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323 - arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]")) - indices, uniques = arr.factorize(use_na_sentinel=use_na_sentinel) - uniques = uniques.astype(self.dtype) - return indices, uniques + data = self._data.cast(pa.int64()) + else: + data = self._data - encoded = self._data.dictionary_encode(null_encoding=null_encoding) + encoded = data.dictionary_encode(null_encoding=null_encoding) if encoded.length() == 0: indices = np.array([], dtype=np.intp) uniques = type(self)(pa.chunked_array([], type=encoded.type.value_type)) @@ -674,6 +673,9 @@ def factorize( np.intp, copy=False ) uniques = type(self)(encoded.chunk(0).dictionary) + + if pa.types.is_duration(pa_type): + uniques = uniques.astype(self.dtype) return indices, uniques def reshape(self, *args, **kwargs): @@ -858,13 +860,20 @@ def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: ------- ArrowExtensionArray """ - if pa.types.is_duration(self._data.type): + pa_type = self._data.type + + if pa.types.is_duration(pa_type): # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323 - arr = cast(ArrowExtensionArrayT, self.astype("int64[pyarrow]")) - result = arr.unique() - return cast(ArrowExtensionArrayT, result.astype(self.dtype)) + data = self._data.cast(pa.int64()) + else: + data = self._data - return type(self)(pc.unique(self._data)) + pa_result = pc.unique(data) + + if pa.types.is_duration(pa_type): + pa_result = pa_result.cast(pa_type) + + return type(self)(pa_result) def value_counts(self, dropna: bool = True) -> Series: """ @@ -883,27 +892,30 @@ def value_counts(self, dropna: bool = True) -> Series: -------- Series.value_counts """ - if pa.types.is_duration(self._data.type): + pa_type = self._data.type + if pa.types.is_duration(pa_type): # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323 - arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]")) - result = arr.value_counts() - result.index = result.index.astype(self.dtype) - return result + data = self._data.cast(pa.int64()) + else: + data = self._data from pandas import ( Index, Series, ) - vc = self._data.value_counts() + vc = data.value_counts() values = vc.field(0) counts = vc.field(1) - if dropna and self._data.null_count > 0: + if dropna and data.null_count > 0: mask = values.is_valid() values = values.filter(mask) counts = counts.filter(mask) + if pa.types.is_duration(pa_type): + values = values.cast(pa_type) + # No missing values so we can adhere to the interface and return a numpy array. counts = np.array(counts) From 9cf7fe1680b2db7e3377b43ec1e8252685d424b1 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 16 Jan 2023 07:52:58 -0800 Subject: [PATCH 4/5] mypy fixup --- pandas/core/arrays/arrow/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 3027f2866c52e..09e24b9b1958e 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -675,7 +675,7 @@ def factorize( uniques = type(self)(encoded.chunk(0).dictionary) if pa.types.is_duration(pa_type): - uniques = uniques.astype(self.dtype) + uniques = cast(ArrowExtensionArray, uniques.astype(self.dtype)) return indices, uniques def reshape(self, *args, **kwargs): From 1a515a1b5f9fbf528d9f10d7456054e4ce6fa46e Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 17 Jan 2023 15:30:04 -0800 Subject: [PATCH 5/5] use cast instead of astype --- pandas/core/arrays/arrow/array.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 09e24b9b1958e..d7faf00113609 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1247,18 +1247,25 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra pa_type = self._data.type if pa.types.is_temporal(pa_type): nbits = pa_type.bit_width - dtype = f"int{nbits}[pyarrow]" - obj = cast(ArrowExtensionArrayT, self.astype(dtype, copy=False)) - result = obj._mode(dropna=dropna) - out = result.astype(self.dtype, copy=False) - return cast(ArrowExtensionArrayT, out) + if nbits == 32: + data = self._data.cast(pa.int32()) + elif nbits == 64: + data = self._data.cast(pa.int64()) + else: + raise NotImplementedError(pa_type) + else: + data = self._data - modes = pc.mode(self._data, pc.count_distinct(self._data).as_py()) + modes = pc.mode(data, pc.count_distinct(data).as_py()) values = modes.field(0) counts = modes.field(1) # counts sorted descending i.e counts[0] = max mask = pc.equal(counts, counts[0]) most_common = values.filter(mask) + + if pa.types.is_temporal(pa_type): + most_common = most_common.cast(pa_type) + return type(self)(most_common) def _maybe_convert_setitem_value(self, value):