Skip to content

Commit f717a7e

Browse files
committed
tests for dataframe.groupby with 2 Categoricals
1 parent a347e76 commit f717a7e

File tree

1 file changed

+104
-31
lines changed

1 file changed

+104
-31
lines changed

pandas/tests/groupby/test_categorical.py

+104-31
Original file line numberDiff line numberDiff line change
@@ -1259,11 +1259,100 @@ def test_get_nonexistent_category():
12591259
)
12601260

12611261

1262+
def test_dataframe_groupby_on_2_categoricals_when_observed_is_true(
    reduction_func: str,
):
    """Unobserved category pairs must be absent from the result index.

    Groups the shared two-categorical frame with ``observed=True``, applies
    ``reduction_func`` and asserts that none of the category combinations
    the helper reports as unobserved appears in the index of the result.
    """
    # ngroup is skipped: per the skip message, it does not put the
    # categories on the index, so the membership check does not apply.
    if reduction_func == "ngroup":
        pytest.skip("ngroup does not return the Categories on the index")

    result, missing_combinations = _dataframe_groupby_on_2_categoricals(
        reduction_func, observed=True
    )

    for combination in missing_combinations:
        assert combination not in result.index
1273+
1274+
1275+
def _dataframe_groupby_on_2_categoricals(reduction_func:str, observed:bool):
1276+
1277+
df = pd.DataFrame({
1278+
"cat_1": pd.Categorical(list("AABB"), categories=list("ABC")),
1279+
"cat_2": pd.Categorical(list("1111"), categories=list("12")),
1280+
"value": [.1, .1, .1, .1]
1281+
})
1282+
unobserved_cats = [
1283+
('A', '2'),
1284+
('B', '2'),
1285+
('C', '1'),
1286+
('C', '2')
1287+
]
1288+
1289+
df_grp = df.groupby(['cat_1', 'cat_2'], observed=observed)
1290+
1291+
args = {
1292+
'nth' : [0],
1293+
'corrwith' : [df]
1294+
}.get(reduction_func, [])
1295+
res = getattr(df_grp, reduction_func)(*args)
1296+
1297+
return res, unobserved_cats
1298+
1299+
1300+
_results_for_groupbys_with_missing_categories = dict([
1301+
("all", np.NaN),
1302+
("any", np.NaN),
1303+
("count", 0),
1304+
("corrwith", np.NaN),
1305+
("first", np.NaN),
1306+
("idxmax", np.NaN),
1307+
("idxmin", np.NaN),
1308+
("last", np.NaN),
1309+
("mad", np.NaN),
1310+
("max", np.NaN),
1311+
("mean", np.NaN),
1312+
("median", np.NaN),
1313+
("min", np.NaN),
1314+
("nth", np.NaN),
1315+
("nunique", 0),
1316+
("prod", np.NaN),
1317+
("quantile", np.NaN),
1318+
("sem", np.NaN),
1319+
("size", 0),
1320+
("skew", np.NaN),
1321+
("std", np.NaN),
1322+
("sum", np.NaN),
1323+
("var", np.NaN),
1324+
])
1325+
1326+
1327+
# NOTE(review): ``None`` is parametrized alongside ``False`` although the
# parameter is annotated ``bool``; presumably it covers the default value of
# ``observed`` -- confirm.
@pytest.mark.parametrize("observed", [False, None])
def test_dataframe_groupby_on_2_categoricals_when_observed_is_false(
    reduction_func: str, observed: bool, request
):
    """Check the values reported for unobserved category combinations.

    Groups the shared two-categorical frame with the given ``observed``
    value, applies ``reduction_func`` and asserts that every row selected by
    the unobserved category combinations holds the value recorded for that
    reduction in ``_results_for_groupbys_with_missing_categories``.
    """
    # ngroup is skipped: per the skip message, it does not put the
    # categories on the index, so the ``.loc`` lookup below does not apply.
    if reduction_func == "ngroup":
        pytest.skip("ngroup does not return the Categories on the index")

    if reduction_func == "count":
        mark = pytest.mark.xfail(
            reason=(
                "DataFrameGroupBy.count returns np.NaN for missing "
                "categories, when it should return 0"
            )
        )
        request.node.add_marker(mark)

    res, unobserved_cats = _dataframe_groupby_on_2_categoricals(
        reduction_func, observed
    )

    expected = _results_for_groupbys_with_missing_categories[reduction_func]
    unobserved_rows = res.loc[unobserved_cats]

    # Test for NaN by value rather than by identity, so the check does not
    # depend on the table holding the exact ``np.nan`` object.
    if np.isnan(expected):
        assert unobserved_rows.isnull().all().all()
    else:
        assert (unobserved_rows == expected).all().all()
1349+
1350+
1351+
12621352
def test_series_groupby_on_2_categoricals_unobserved(
12631353
reduction_func: str, observed: bool, request
12641354
):
12651355
# GH 17605
1266-
12671356
if reduction_func == "ngroup":
12681357
pytest.skip("ngroup is not truly a reduction")
12691358

@@ -1289,36 +1378,18 @@ def test_series_groupby_on_2_categoricals_unobserved(
12891378
assert len(result) == expected_length
12901379

12911380

1292-
@pytest.mark.parametrize(
1293-
"func, zero_or_nan",
1294-
[
1295-
("all", np.NaN),
1296-
("any", np.NaN),
1297-
("count", 0),
1298-
("first", np.NaN),
1299-
("idxmax", np.NaN),
1300-
("idxmin", np.NaN),
1301-
("last", np.NaN),
1302-
("mad", np.NaN),
1303-
("max", np.NaN),
1304-
("mean", np.NaN),
1305-
("median", np.NaN),
1306-
("min", np.NaN),
1307-
("nth", np.NaN),
1308-
("nunique", 0),
1309-
("prod", np.NaN),
1310-
("quantile", np.NaN),
1311-
("sem", np.NaN),
1312-
("size", 0),
1313-
("skew", np.NaN),
1314-
("std", np.NaN),
1315-
("sum", np.NaN),
1316-
("var", np.NaN),
1317-
],
1318-
)
1319-
def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans(func, zero_or_nan):
1381+
def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans(
1382+
reduction_func:str, request):
13201383
# GH 17605
13211384
# Tests whether the unobserved categories in the result contain 0 or NaN
1385+
1386+
if reduction_func == "ngroup":
1387+
pytest.skip("ngroup is not truly a reduction")
1388+
1389+
if reduction_func == "corrwith": # GH 32293
1390+
mark = pytest.mark.xfail(reason="TODO: implemented SeriesGroupBy.corrwith")
1391+
request.node.add_marker(mark)
1392+
13221393
df = pd.DataFrame(
13231394
{
13241395
"cat_1": pd.Categorical(list("AABB"), categories=list("ABC")),
@@ -1327,11 +1398,13 @@ def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans(func, zero_o
13271398
}
13281399
)
13291400
unobserved = [tuple("AC"), tuple("BC"), tuple("CA"), tuple("CB"), tuple("CC")]
1330-
args = {"nth": [0]}.get(func, [])
1401+
args = {"nth": [0]}.get(reduction_func, [])
13311402

13321403
series_groupby = df.groupby(["cat_1", "cat_2"], observed=False)["value"]
1333-
agg = getattr(series_groupby, func)
1404+
agg = getattr(series_groupby, reduction_func)
13341405
result = agg(*args)
1406+
1407+
zero_or_nan = _results_for_groupbys_with_missing_categories[reduction_func]
13351408

13361409
for idx in unobserved:
13371410
val = result.loc[idx]

0 commit comments

Comments
 (0)