Skip to content

Commit c3eca7e

Browse files
[ArrayManager] Remaining GroupBy tests (fix count, pass on libreduction for now) (#40050)
1 parent cff293b commit c3eca7e

File tree

11 files changed

+78
-9
lines changed

11 files changed

+78
-9
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ jobs:
157157
pytest pandas/tests/reductions/ --array-manager
158158
pytest pandas/tests/generic/test_generic.py --array-manager
159159
pytest pandas/tests/arithmetic/ --array-manager
160-
pytest pandas/tests/groupby/aggregate/ --array-manager
160+
pytest pandas/tests/groupby/ --array-manager
161161
pytest pandas/tests/reshape/merge --array-manager
162162
163163
# indexing subset (temporary since other tests don't pass yet)

pandas/core/groupby/generic.py

+6
Original file line numberDiff line numberDiff line change
@@ -1815,6 +1815,8 @@ def count(self) -> DataFrame:
18151815
ids, _, ngroups = self.grouper.group_info
18161816
mask = ids != -1
18171817

1818+
using_array_manager = isinstance(data, ArrayManager)
1819+
18181820
def hfunc(bvalues: ArrayLike) -> ArrayLike:
18191821
# TODO(2DEA): reshape would not be necessary with 2D EAs
18201822
if bvalues.ndim == 1:
@@ -1824,6 +1826,10 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
18241826
masked = mask & ~isna(bvalues)
18251827

18261828
counted = lib.count_level_2d(masked, labels=ids, max_bin=ngroups, axis=1)
1829+
if using_array_manager:
1830+
# count_level_2d returns a (1, N) array for a single column
1831+
# -> extract 1D array
1832+
counted = counted[0, :]
18271833
return counted
18281834

18291835
new_mgr = data.grouped_reduce(hfunc)

pandas/core/groupby/ops.py

+5
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
MultiIndex,
8585
ensure_index,
8686
)
87+
from pandas.core.internals import ArrayManager
8788
from pandas.core.series import Series
8889
from pandas.core.sorting import (
8990
compress_group_index,
@@ -214,6 +215,10 @@ def apply(self, f: F, data: FrameOrSeries, axis: int = 0):
214215
# TODO: can we have a workaround for EAs backed by ndarray?
215216
pass
216217

218+
elif isinstance(sdata._mgr, ArrayManager):
219+
# TODO(ArrayManager) don't use fast_apply / libreduction.apply_frame_axis0
220+
# for now -> relies on BlockManager internals
221+
pass
217222
elif (
218223
com.get_callable_name(f) not in base.plotting_methods
219224
and isinstance(splitter, FrameSplitter)

pandas/core/internals/array_manager.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -270,15 +270,30 @@ def grouped_reduce(self: T, func: Callable, ignore_failures: bool = False) -> T:
270270
-------
271271
ArrayManager
272272
"""
273-
# TODO ignore_failures
274-
result_arrays = [func(arr) for arr in self.arrays]
273+
result_arrays: List[np.ndarray] = []
274+
result_indices: List[int] = []
275+
276+
for i, arr in enumerate(self.arrays):
277+
try:
278+
res = func(arr)
279+
except (TypeError, NotImplementedError):
280+
if not ignore_failures:
281+
raise
282+
continue
283+
result_arrays.append(res)
284+
result_indices.append(i)
275285

276286
if len(result_arrays) == 0:
277287
index = Index([None]) # placeholder
278288
else:
279289
index = Index(range(result_arrays[0].shape[0]))
280290

281-
return type(self)(result_arrays, [index, self.items])
291+
if ignore_failures:
292+
columns = self.items[np.array(result_indices, dtype="int64")]
293+
else:
294+
columns = self.items
295+
296+
return type(self)(result_arrays, [index, columns])
282297

283298
def operate_blockwise(self, other: ArrayManager, array_op) -> ArrayManager:
284299
"""

pandas/tests/groupby/test_allowlist.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import numpy as np
99
import pytest
1010

11+
import pandas.util._test_decorators as td
12+
1113
from pandas import (
1214
DataFrame,
1315
Index,
@@ -355,7 +357,8 @@ def test_groupby_function_rename(mframe):
355357
"cummax",
356358
"cummin",
357359
"cumprod",
358-
"describe",
360+
# TODO(ArrayManager) quantile
361+
pytest.param("describe", marks=td.skip_array_manager_not_yet_implemented),
359362
"rank",
360363
"quantile",
361364
"diff",

pandas/tests/groupby/test_apply.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import numpy as np
88
import pytest
99

10+
import pandas.util._test_decorators as td
11+
1012
import pandas as pd
1113
from pandas import (
1214
DataFrame,
@@ -84,6 +86,7 @@ def test_apply_trivial_fail():
8486
tm.assert_frame_equal(result, expected)
8587

8688

89+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) fast_apply not used
8790
def test_fast_apply():
8891
# make sure that fast apply is correctly called
8992
# rather than raising any kind of error
@@ -213,6 +216,7 @@ def test_group_apply_once_per_group2(capsys):
213216
assert result == expected
214217

215218

219+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) fast_apply not used
216220
@pytest.mark.xfail(reason="GH-34998")
217221
def test_apply_fast_slow_identical():
218222
# GH 31613
@@ -233,6 +237,7 @@ def fast(group):
233237
tm.assert_frame_equal(fast_df, slow_df)
234238

235239

240+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) fast_apply not used
236241
@pytest.mark.parametrize(
237242
"func",
238243
[
@@ -313,6 +318,7 @@ def test_groupby_as_index_apply(df):
313318
tm.assert_index_equal(res, ind)
314319

315320

321+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
316322
def test_apply_concat_preserve_names(three_group):
317323
grouped = three_group.groupby(["A", "B"])
318324

@@ -1003,9 +1009,10 @@ def test_apply_function_with_indexing_return_column():
10031009
tm.assert_frame_equal(result, expected)
10041010

10051011

1006-
@pytest.mark.xfail(reason="GH-34998")
1007-
def test_apply_with_timezones_aware():
1012+
def test_apply_with_timezones_aware(using_array_manager, request):
10081013
# GH: 27212
1014+
if not using_array_manager:
1015+
request.node.add_marker(pytest.mark.xfail(reason="GH-34998"))
10091016

10101017
dates = ["2001-01-01"] * 2 + ["2001-01-02"] * 2 + ["2001-01-03"] * 2
10111018
index_no_tz = pd.DatetimeIndex(dates)

pandas/tests/groupby/test_categorical.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import numpy as np
44
import pytest
55

6+
import pandas.util._test_decorators as td
7+
68
import pandas as pd
79
from pandas import (
810
Categorical,
@@ -81,6 +83,7 @@ def get_stats(group):
8183
assert result.index.names[0] == "C"
8284

8385

86+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
8487
def test_basic():
8588

8689
cats = Categorical(
@@ -276,7 +279,9 @@ def test_apply(ordered):
276279
tm.assert_series_equal(result, expected)
277280

278281

279-
def test_observed(observed):
282+
# TODO(ArrayManager) incorrect dtype for mean()
283+
@td.skip_array_manager_not_yet_implemented
284+
def test_observed(observed, using_array_manager):
280285
# multiple groupers, don't re-expand the output space
281286
# of the grouper
282287
# gh-14942 (implement)
@@ -535,6 +540,7 @@ def test_dataframe_categorical_ordered_observed_sort(ordered, observed, sort):
535540
assert False, msg
536541

537542

543+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
538544
def test_datetime():
539545
# GH9049: ensure backward compatibility
540546
levels = pd.date_range("2014-01-01", periods=4)
@@ -600,6 +606,7 @@ def test_categorical_index():
600606
tm.assert_frame_equal(result, expected)
601607

602608

609+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
603610
def test_describe_categorical_columns():
604611
# GH 11558
605612
cats = CategoricalIndex(
@@ -614,6 +621,7 @@ def test_describe_categorical_columns():
614621
tm.assert_categorical_equal(result.stack().columns.values, cats.values)
615622

616623

624+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
617625
def test_unstack_categorical():
618626
# GH11558 (example is taken from the original issue)
619627
df = DataFrame(

pandas/tests/groupby/test_function.py

+8
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def test_mad(self, gb, gni):
367367
result = gni.mad()
368368
tm.assert_frame_equal(result, expected)
369369

370+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
370371
def test_describe(self, df, gb, gni):
371372
# describe
372373
expected_index = Index([1, 3], name="A")
@@ -923,11 +924,13 @@ def test_is_monotonic_decreasing(in_vals, out_vals):
923924
# --------------------------------
924925

925926

927+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
926928
def test_apply_describe_bug(mframe):
927929
grouped = mframe.groupby(level="first")
928930
grouped.describe() # it works!
929931

930932

933+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
931934
def test_series_describe_multikey():
932935
ts = tm.makeTimeSeries()
933936
grouped = ts.groupby([lambda x: x.year, lambda x: x.month])
@@ -937,6 +940,7 @@ def test_series_describe_multikey():
937940
tm.assert_series_equal(result["min"], grouped.min(), check_names=False)
938941

939942

943+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
940944
def test_series_describe_single():
941945
ts = tm.makeTimeSeries()
942946
grouped = ts.groupby(lambda x: x.month)
@@ -951,6 +955,7 @@ def test_series_index_name(df):
951955
assert result.index.name == "A"
952956

953957

958+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
954959
def test_frame_describe_multikey(tsframe):
955960
grouped = tsframe.groupby([lambda x: x.year, lambda x: x.month])
956961
result = grouped.describe()
@@ -973,6 +978,7 @@ def test_frame_describe_multikey(tsframe):
973978
tm.assert_frame_equal(result, expected)
974979

975980

981+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
976982
def test_frame_describe_tupleindex():
977983

978984
# GH 14848 - regression from 0.19.0 to 0.19.1
@@ -992,6 +998,7 @@ def test_frame_describe_tupleindex():
992998
df2.groupby("key").describe()
993999

9941000

1001+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
9951002
def test_frame_describe_unstacked_format():
9961003
# GH 4792
9971004
prices = {
@@ -1018,6 +1025,7 @@ def test_frame_describe_unstacked_format():
10181025
tm.assert_frame_equal(result, expected)
10191026

10201027

1028+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
10211029
@pytest.mark.filterwarnings(
10221030
"ignore:"
10231031
"indexing past lexsort depth may impact performance:"

pandas/tests/groupby/test_groupby.py

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from pandas.compat import IS64
99
from pandas.errors import PerformanceWarning
10+
import pandas.util._test_decorators as td
1011

1112
import pandas as pd
1213
from pandas import (
@@ -210,6 +211,7 @@ def f(grp):
210211
tm.assert_series_equal(result, e)
211212

212213

214+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
213215
def test_pass_args_kwargs(ts, tsframe):
214216
def f(x, q=None, axis=0):
215217
return np.percentile(x, q, axis=axis)
@@ -364,6 +366,7 @@ def f3(x):
364366
df2.groupby("a").apply(f3)
365367

366368

369+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile
367370
def test_attr_wrapper(ts):
368371
grouped = ts.groupby(lambda x: x.weekday())
369372

pandas/tests/groupby/test_quantile.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import numpy as np
22
import pytest
33

4+
import pandas.util._test_decorators as td
5+
46
import pandas as pd
57
from pandas import (
68
DataFrame,
79
Index,
810
)
911
import pandas._testing as tm
1012

13+
# TODO(ArrayManager) quantile
14+
pytestmark = td.skip_array_manager_not_yet_implemented
15+
1116

1217
@pytest.mark.parametrize(
1318
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]

pandas/tests/groupby/transform/test_transform.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
import pytest
66

7+
import pandas.util._test_decorators as td
8+
79
from pandas.core.dtypes.common import (
810
ensure_platform_int,
911
is_timedelta64_dtype,
@@ -161,8 +163,13 @@ def test_transform_broadcast(tsframe, ts):
161163
assert_fp_equal(res.xs(idx), agged[idx])
162164

163165

164-
def test_transform_axis_1(request, transformation_func):
166+
def test_transform_axis_1(request, transformation_func, using_array_manager):
165167
# GH 36308
168+
if using_array_manager and transformation_func == "pct_change":
169+
# TODO(ArrayManager) column-wise shift
170+
request.node.add_marker(
171+
pytest.mark.xfail(reason="ArrayManager: shift axis=1 not yet implemented")
172+
)
166173
warn = None
167174
if transformation_func == "tshift":
168175
warn = FutureWarning
@@ -183,6 +190,8 @@ def test_transform_axis_1(request, transformation_func):
183190
tm.assert_equal(result, expected)
184191

185192

193+
# TODO(ArrayManager) groupby().transform returns DataFrame backed by BlockManager
194+
@td.skip_array_manager_not_yet_implemented
186195
def test_transform_axis_ts(tsframe):
187196

188197
# make sure that we are setting the axes

0 commit comments

Comments
 (0)