Skip to content

Commit de728ad

Browse files
rhshadrachWillAyd
authored andcommitted
Fix extension tests
1 parent d625522 commit de728ad

File tree

4 files changed

+31
-32
lines changed

4 files changed

+31
-32
lines changed

pandas/tests/extension/base/accumulate.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
1818
def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
1919
try:
2020
alt = ser.astype("float64")
21-
except TypeError:
22-
# e.g. Period can't be cast to float64
21+
except (TypeError, ValueError):
22+
# e.g. Period can't be cast to float64 (TypeError)
23+
# String can't be cast to float64 (ValueError)
2324
alt = ser.astype(object)
2425

2526
result = getattr(ser, op_name)(skipna=skipna)

pandas/tests/extension/test_arrow.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
393393
# attribute "pyarrow_dtype"
394394
pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr]
395395

396-
if (
397-
pa.types.is_string(pa_type)
398-
or pa.types.is_binary(pa_type)
399-
or pa.types.is_decimal(pa_type)
400-
):
396+
if pa.types.is_binary(pa_type) or pa.types.is_decimal(pa_type):
401397
if op_name in ["cumsum", "cumprod", "cummax", "cummin"]:
402398
return False
399+
elif pa.types.is_string(pa_type):
400+
if op_name == "cumprod":
401+
return False
403402
elif pa.types.is_boolean(pa_type):
404403
if op_name in ["cumprod", "cummax", "cummin"]:
405404
return False
@@ -414,6 +413,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
414413
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
415414
pa_type = data.dtype.pyarrow_dtype
416415
op_name = all_numeric_accumulations
416+
417+
if pa.types.is_string(pa_type) and op_name in ["cumsum", "cummin", "cummax"]:
418+
# https://github.com/pandas-dev/pandas/pull/60633
419+
# Doesn't fit test structure, tested in series/test_cumulative.py instead.
420+
return
421+
417422
ser = pd.Series(data)
418423

419424
if not self._supports_accumulation(ser, op_name):

pandas/tests/extension/test_string.py

+7
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,13 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
192192
and op_name in ("any", "all")
193193
)
194194

195+
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
196+
return ser.dtype.storage == "pyarrow" and op_name in [
197+
"cummin",
198+
"cummax",
199+
"cumsum",
200+
]
201+
195202
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
196203
dtype = cast(StringDtype, tm.get_dtype(obj))
197204
if op_name in ["__add__", "__radd__"]:

pandas/tests/series/test_cumulative.py

+11-25
Original file line numberDiff line numberDiff line change
@@ -230,31 +230,19 @@ def test_cumprod_timedelta(self):
230230
with pytest.raises(TypeError, match="cumprod not supported for Timedelta"):
231231
ser.cumprod()
232232

233-
@pytest.mark.parametrize(
234-
"data, skipna, expected_data",
235-
[
236-
([], True, []),
237-
([], False, []),
238-
(["x", "z", "y"], True, ["x", "xz", "xzy"]),
239-
(["x", "z", "y"], False, ["x", "xz", "xzy"]),
240-
(["x", pd.NA, "y"], True, ["x", "x", "xy"]),
241-
(["x", pd.NA, "y"], False, ["x", pd.NA, pd.NA]),
242-
([pd.NA, pd.NA, pd.NA], True, ["", "", ""]),
243-
([pd.NA, pd.NA, pd.NA], False, [pd.NA, pd.NA, pd.NA]),
244-
],
245-
)
246-
def test_cumsum_pyarrow_strings(
247-
self, pyarrow_string_dtype, data, skipna, expected_data
248-
):
249-
# https://github.com/pandas-dev/pandas/pull/60633
250-
ser = pd.Series(data, dtype=pyarrow_string_dtype)
251-
expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
252-
result = ser.cumsum(skipna=skipna)
253-
tm.assert_series_equal(result, expected)
254-
255233
@pytest.mark.parametrize(
256234
"data, op, skipna, expected_data",
257235
[
236+
([], "cumsum", True, []),
237+
([], "cumsum", False, []),
238+
(["x", "z", "y"], "cumsum", True, ["x", "xz", "xzy"]),
239+
(["x", "z", "y"], "cumsum", False, ["x", "xz", "xzy"]),
240+
(["x", pd.NA, "y"], "cumsum", True, ["x", "x", "xy"]),
241+
(["x", pd.NA, "y"], "cumsum", False, ["x", pd.NA, pd.NA]),
242+
([pd.NA, "x", "y"], "cumsum", True, ["", "x", "xy"]),
243+
([pd.NA, "x", "y"], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
244+
([pd.NA, pd.NA, pd.NA], "cumsum", True, ["", "", ""]),
245+
([pd.NA, pd.NA, pd.NA], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
258246
([], "cummin", True, []),
259247
([], "cummin", False, []),
260248
(["y", "z", "x"], "cummin", True, ["y", "y", "x"]),
@@ -277,13 +265,11 @@ def test_cumsum_pyarrow_strings(
277265
([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]),
278266
],
279267
)
280-
def test_cummin_cummax_pyarrow_strings(
268+
def test_cum_methods_pyarrow_strings(
281269
self, pyarrow_string_dtype, data, op, skipna, expected_data
282270
):
283271
# https://github.com/pandas-dev/pandas/pull/60633
284272
ser = pd.Series(data, dtype=pyarrow_string_dtype)
285-
if expected_data is None:
286-
expected_data = ser.dtype.na_value
287273
method = getattr(ser, op)
288274
expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
289275
result = method(skipna=skipna)

0 commit comments

Comments
 (0)