|
24 | 24 | from pandas.compat import (
|
25 | 25 | pa_version_under2p0,
|
26 | 26 | pa_version_under3p0,
|
| 27 | + pa_version_under6p0, |
27 | 28 | pa_version_under8p0,
|
28 | 29 | )
|
29 | 30 |
|
@@ -303,6 +304,95 @@ def test_loc_iloc_frame_single_dtype(self, request, using_array_manager, data):
|
303 | 304 | super().test_loc_iloc_frame_single_dtype(data)
|
304 | 305 |
|
305 | 306 |
|
| 307 | +class TestBaseNumericReduce(base.BaseNumericReduceTests): |
| 308 | + def check_reduce(self, ser, op_name, skipna): |
| 309 | + pa_dtype = ser.dtype.pyarrow_dtype |
| 310 | + result = getattr(ser, op_name)(skipna=skipna) |
| 311 | + if pa.types.is_boolean(pa_dtype): |
| 312 | + # Can't convert if ser contains NA |
| 313 | + pytest.skip( |
| 314 | + "pandas boolean data with NA does not fully support all reductions" |
| 315 | + ) |
| 316 | + elif pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): |
| 317 | + ser = ser.astype("Float64") |
| 318 | + expected = getattr(ser, op_name)(skipna=skipna) |
| 319 | + tm.assert_almost_equal(result, expected) |
| 320 | + |
| 321 | + @pytest.mark.parametrize("skipna", [True, False]) |
| 322 | + def test_reduce_series(self, data, all_numeric_reductions, skipna, request): |
| 323 | + pa_dtype = data.dtype.pyarrow_dtype |
| 324 | + xfail_mark = pytest.mark.xfail( |
| 325 | + raises=TypeError, |
| 326 | + reason=( |
| 327 | + f"{all_numeric_reductions} is not implemented in " |
| 328 | + f"pyarrow={pa.__version__} for {pa_dtype}" |
| 329 | + ), |
| 330 | + ) |
| 331 | + if all_numeric_reductions in {"skew", "kurt"}: |
| 332 | + request.node.add_marker(xfail_mark) |
| 333 | + elif ( |
| 334 | + all_numeric_reductions in {"median", "var", "std", "prod", "max", "min"} |
| 335 | + and pa_version_under6p0 |
| 336 | + ): |
| 337 | + request.node.add_marker(xfail_mark) |
| 338 | + elif all_numeric_reductions in {"sum", "mean"} and pa_version_under2p0: |
| 339 | + request.node.add_marker(xfail_mark) |
| 340 | + elif ( |
| 341 | + all_numeric_reductions in {"sum", "mean"} |
| 342 | + and skipna is False |
| 343 | + and pa_version_under6p0 |
| 344 | + and (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)) |
| 345 | + ): |
| 346 | + request.node.add_marker( |
| 347 | + pytest.mark.xfail( |
| 348 | + raises=AssertionError, |
| 349 | + reason=( |
| 350 | + f"{all_numeric_reductions} with skip_nulls={skipna} did not " |
| 351 | + f"return NA for {pa_dtype} with pyarrow={pa.__version__}" |
| 352 | + ), |
| 353 | + ) |
| 354 | + ) |
| 355 | + elif not ( |
| 356 | + pa.types.is_integer(pa_dtype) |
| 357 | + or pa.types.is_floating(pa_dtype) |
| 358 | + or pa.types.is_boolean(pa_dtype) |
| 359 | + ) and not ( |
| 360 | + all_numeric_reductions in {"min", "max"} |
| 361 | + and (pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype)) |
| 362 | + ): |
| 363 | + request.node.add_marker(xfail_mark) |
| 364 | + elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in { |
| 365 | + "std", |
| 366 | + "var", |
| 367 | + "median", |
| 368 | + }: |
| 369 | + request.node.add_marker(xfail_mark) |
| 370 | + super().test_reduce_series(data, all_numeric_reductions, skipna) |
| 371 | + |
| 372 | + |
| 373 | +class TestBaseBooleanReduce(base.BaseBooleanReduceTests): |
| 374 | + @pytest.mark.parametrize("skipna", [True, False]) |
| 375 | + def test_reduce_series( |
| 376 | + self, data, all_boolean_reductions, skipna, na_value, request |
| 377 | + ): |
| 378 | + pa_dtype = data.dtype.pyarrow_dtype |
| 379 | + xfail_mark = pytest.mark.xfail( |
| 380 | + raises=TypeError, |
| 381 | + reason=( |
| 382 | + f"{all_boolean_reductions} is not implemented in " |
| 383 | + f"pyarrow={pa.__version__} for {pa_dtype}" |
| 384 | + ), |
| 385 | + ) |
| 386 | + if not pa.types.is_boolean(pa_dtype): |
| 387 | + request.node.add_marker(xfail_mark) |
| 388 | + elif pa_version_under3p0: |
| 389 | + request.node.add_marker(xfail_mark) |
| 390 | + op_name = all_boolean_reductions |
| 391 | + s = pd.Series(data) |
| 392 | + result = getattr(s, op_name)(skipna=skipna) |
| 393 | + assert result is (op_name == "any") |
| 394 | + |
| 395 | + |
306 | 396 | class TestBaseGroupby(base.BaseGroupbyTests):
|
307 | 397 | def test_groupby_agg_extension(self, data_for_grouping, request):
|
308 | 398 | tz = getattr(data_for_grouping.dtype.pyarrow_dtype, "tz", None)
|
|
0 commit comments