|
10 | 10 | classes (if they are relevant for the extension interface for all dtypes), or
|
11 | 11 | be added to the array-specific tests in `pandas/tests/arrays/`.
|
12 | 12 | """
|
13 |
| - |
14 | 13 | from datetime import (
|
15 | 14 | date,
|
16 | 15 | datetime,
|
|
24 | 23 | from pandas.compat import (
|
25 | 24 | pa_version_under2p0,
|
26 | 25 | pa_version_under3p0,
|
| 26 | + pa_version_under4p0, |
27 | 27 | pa_version_under6p0,
|
28 | 28 | pa_version_under8p0,
|
| 29 | + pa_version_under9p0, |
29 | 30 | )
|
30 | 31 |
|
31 | 32 | import pandas as pd
|
@@ -1993,3 +1994,72 @@ def test_compare_array(self, data, comparison_op, na_value, request):
|
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
    """Check that a parametrized pyarrow type string is rejected.

    ``ArrowDtype.construct_from_string`` is expected to raise
    ``NotImplementedError`` whose message matches "Passing pyarrow type".
    """
    parametrized_dtype = "timestamp[s, tz=UTC][pyarrow]"
    with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
        ArrowDtype.construct_from_string(parametrized_dtype)
|
| 1997 | + |
| 1998 | + |
@pytest.mark.xfail(
    pa_version_under4p0,
    raises=NotImplementedError,
    reason="quantile only supported for pyarrow version >= 4.0",
)
@pytest.mark.parametrize(
    "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
)
@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]])
def test_quantile(data, interpolation, quantile, request):
    """Check ``Series.quantile`` on an array made of one repeated element.

    The first element of ``data`` is repeated three times, so every
    requested quantile should equal that element regardless of the
    ``interpolation`` method.  Dtypes that are neither integer nor
    floating are marked as expected failures raising
    ``pa.ArrowNotImplementedError``.
    """
    arrow_type = data.dtype.pyarrow_dtype
    is_integer = pa.types.is_integer(arrow_type)
    if not (is_integer or pa.types.is_floating(arrow_type)):
        unsupported = pytest.mark.xfail(
            raises=pa.ArrowNotImplementedError,
            reason=f"quantile not supported by pyarrow for {arrow_type}",
        )
        request.node.add_marker(unsupported)

    repeated = data.take([0, 0, 0])
    result = pd.Series(repeated).quantile(q=quantile, interpolation=interpolation)

    if quantile == 0.5:
        # Scalar q: the result is a single value.
        assert result == repeated[0]
        return

    # List-like q: only the values are compared, so both sides are cast
    # to a common float dtype before the comparison.
    result = result.astype("float64[pyarrow]")
    expected_values = repeated.take([0, 0]).astype("float64[pyarrow]")
    expected = pd.Series(expected_values, index=[0.5, 0.5])
    tm.assert_series_equal(result, expected)
| 2029 | + |
| 2030 | + |
@pytest.mark.xfail(
    pa_version_under6p0,
    raises=NotImplementedError,
    reason="mode only supported for pyarrow version >= 6.0",
)
@pytest.mark.parametrize("dropna", [True, False])
@pytest.mark.parametrize(
    "take_idx, exp_idx",
    [[[0, 0, 2, 2, 4, 4], [4, 0]], [[0, 0, 0, 2, 4, 4], [0]]],
    ids=["multi_mode", "single_mode"],
)
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
    """Check ``Series.mode`` against elements picked from the fixture.

    ``take_idx`` selects the values the Series is built from and
    ``exp_idx`` selects the values expected back from ``mode``.  Temporal
    dtypes are marked as expected failures, as is the boolean
    "multi_mode" case when ``pa_version_under9p0`` is true.
    """
    arrow_type = data_for_grouping.dtype.pyarrow_dtype

    # Pick at most one xfail marker; the branch order mirrors the
    # precedence of the checks (temporal first, then the boolean case).
    xfail_marker = None
    if pa.types.is_temporal(arrow_type):
        xfail_marker = pytest.mark.xfail(
            raises=pa.ArrowNotImplementedError,
            reason=f"mode not supported by pyarrow for {arrow_type}",
        )
    elif (
        pa.types.is_boolean(arrow_type)
        and "multi_mode" in request.node.nodeid
        and pa_version_under9p0
    ):
        xfail_marker = pytest.mark.xfail(
            reason="https://issues.apache.org/jira/browse/ARROW-17096",
        )
    if xfail_marker is not None:
        request.node.add_marker(xfail_marker)

    values = data_for_grouping.take(take_idx)
    result = pd.Series(values).mode(dropna=dropna)
    expected = pd.Series(data_for_grouping.take(exp_idx))
    tm.assert_series_equal(result, expected)
0 commit comments