|
5 | 5 | import numpy as np
|
6 | 6 | import pytest
|
7 | 7 |
|
| 8 | +from pandas.core.dtypes.common import is_float_dtype |
| 9 | + |
8 | 10 | import pandas as pd
|
9 | 11 | from pandas import DataFrame, Index, NaT, Series, Timedelta, Timestamp, bdate_range
|
10 | 12 | import pandas._testing as tm
|
@@ -312,3 +314,69 @@ def test_cython_agg_nullable_int(op_name):
|
312 | 314 | # so for now just checking the values by casting to float
|
313 | 315 | result = result.astype("float64")
|
314 | 316 | tm.assert_series_equal(result, expected)
|
| 317 | + |
| 318 | + |
| 319 | +@pytest.mark.parametrize("with_na", [True, False]) |
| 320 | +@pytest.mark.parametrize( |
| 321 | + "op_name, action", |
| 322 | + [ |
| 323 | + # ("count", "always_int"), |
| 324 | + ("sum", "large_int"), |
| 325 | + # ("std", "always_float"), |
| 326 | + ("var", "always_float"), |
| 327 | + # ("sem", "always_float"), |
| 328 | + ("mean", "always_float"), |
| 329 | + ("median", "always_float"), |
| 330 | + ("prod", "large_int"), |
| 331 | + ("min", "preserve"), |
| 332 | + ("max", "preserve"), |
| 333 | + ("first", "preserve"), |
| 334 | + ("last", "preserve"), |
| 335 | + ], |
| 336 | +) |
| 337 | +@pytest.mark.parametrize( |
| 338 | + "data", |
| 339 | + [ |
| 340 | + pd.array([1, 2, 3, 4], dtype="Int64"), |
| 341 | + pd.array([1, 2, 3, 4], dtype="Int8"), |
| 342 | + pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float32"), |
| 343 | + pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float64"), |
| 344 | + pd.array([True, True, False, False], dtype="boolean"), |
| 345 | + ], |
| 346 | +) |
| 347 | +def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na): |
| 348 | + if with_na: |
| 349 | + data[3] = pd.NA |
| 350 | + |
| 351 | + df = DataFrame({"key": ["a", "a", "b", "b"], "col": data}) |
| 352 | + grouped = df.groupby("key") |
| 353 | + |
| 354 | + if action == "always_int": |
| 355 | + # always Int64 |
| 356 | + expected_dtype = pd.Int64Dtype() |
| 357 | + elif action == "large_int": |
| 358 | + # for any int/bool use Int64, for float preserve dtype |
| 359 | + if is_float_dtype(data.dtype): |
| 360 | + expected_dtype = data.dtype |
| 361 | + else: |
| 362 | + expected_dtype = pd.Int64Dtype() |
| 363 | + elif action == "always_float": |
| 364 | + # for any int/bool use Float64, for float preserve dtype |
| 365 | + if is_float_dtype(data.dtype): |
| 366 | + expected_dtype = data.dtype |
| 367 | + else: |
| 368 | + expected_dtype = pd.Float64Dtype() |
| 369 | + elif action == "preserve": |
| 370 | + expected_dtype = data.dtype |
| 371 | + |
| 372 | + result = getattr(grouped, op_name)() |
| 373 | + assert result["col"].dtype == expected_dtype |
| 374 | + |
| 375 | + result = grouped.aggregate(op_name) |
| 376 | + assert result["col"].dtype == expected_dtype |
| 377 | + |
| 378 | + result = getattr(grouped["col"], op_name)() |
| 379 | + assert result.dtype == expected_dtype |
| 380 | + |
| 381 | + result = grouped["col"].aggregate(op_name) |
| 382 | + assert result.dtype == expected_dtype |
0 commit comments