Skip to content

Commit 26e9084

Browse files
committed
test groupby.indices for multiple groupby and mix of types
Creates tests for GH26859
1 parent 5011a37 commit 26e9084

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

pandas/tests/groupby/test_groupby.py

+104
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from pandas.errors import PerformanceWarning
99

10+
from pandas.core.dtypes.common import is_categorical_dtype, is_datetime64_any_dtype
11+
1012
import pandas as pd
1113
from pandas import (
1214
DataFrame,
@@ -361,6 +363,108 @@ def f3(x):
361363
df2.groupby("a").apply(f3)
362364

363365

366+
def test_single_groupby_indices_output():
    # GH 26859
    # Group a three-row frame by one column at a time and check that
    # ``.indices`` maps every key to the position of its (only) row.
    # The columns cover int, float, datetime, tz-aware datetime and period
    # data, each both as a plain column and as a categorical.
    cols = [
        "int",
        "int_cat",
        "float",
        "float_cat",
        "dt",
        "dt_cat",
        "dttz",
        "dttz_cat",
        "period",
        "period_cat",
    ]

    int_ = Series([1, 2, 3])
    dt_ = pd.to_datetime(["2018Q1", "2018Q2", "2018Q3"])
    dttz_ = dt_.tz_localize("Europe/Berlin")
    df = DataFrame(
        data={
            "int": int_,
            "int_cat": int_.astype("category"),
            "float": int_.astype("float"),
            "float_cat": int_.astype("float").astype("category"),
            "dt": dt_,
            "dt_cat": dt_.astype("category"),
            "dttz": dttz_,
            "dttz_cat": dttz_.astype("category"),
            "period": dt_.to_period("Q"),
            "period_cat": dt_.to_period("Q").astype("category"),
        },
        columns=cols,
    )

    for col in df.columns:
        col_vals = list(df[col].unique())

        # Datetime values are wrapped in Timestamp before being used as the
        # expected group keys.
        if is_datetime64_any_dtype(df[col]):
            col_vals = [Timestamp(el) for el in col_vals]

        # All three values in a column are distinct, so the i-th unique
        # value is expected to form a group containing only row i.
        target = {key: np.array([i]) for i, key in enumerate(col_vals)}

        indices = df.groupby(col).indices

        assert set(target) == set(indices), "unexpected group keys for {!r}".format(
            col
        )
        for key, expected in target.items():
            # Public numpy comparison instead of the private
            # pd.core.dtypes.missing.array_equivalent helper.
            assert np.array_equal(
                expected, indices[key]
            ), "wrong indices for key {!r} when grouping by {!r}".format(key, col)
411+
412+
413+
def test_multiple_groupby_indices_output():
    # GH 26859
    # Group a three-row frame by ten key columns of mixed dtypes at once and
    # check that ``.indices`` maps each tuple of key values to the position
    # of its (only) row.
    cols = [
        "int",
        "int_cat",
        "float",
        "float_cat",
        "dt",
        "dt_cat",
        "dttz",
        "dttz_cat",
        "period",
        "period_cat",
        "value",
    ]

    int_ = Series([1, 2, 3])
    dt_ = pd.to_datetime(["2018Q1", "2018Q2", "2018Q3"])
    dttz_ = dt_.tz_localize("Europe/Berlin")
    df = DataFrame(
        data={
            "int": int_,
            "int_cat": int_.astype("category"),
            "float": int_.astype("float"),
            "float_cat": int_.astype("float").astype("category"),
            "dt": dt_,
            "dt_cat": dt_.astype("category"),
            "dttz": dttz_,
            "dttz_cat": dttz_.astype("category"),
            "period": dt_.to_period("Q"),
            "period_cat": dt_.to_period("Q").astype("category"),
            "value": Series([1, 2, 3]),
        },
        columns=cols,
    )

    # Every column except the trailing "value" column is a grouping key.
    groupby_cols = cols[:-1]
    col_vals = {col: list(df[col].unique()) for col in groupby_cols}

    for col in groupby_cols:
        is_dt = is_datetime64_any_dtype(df[col])
        # isinstance check is the documented replacement for the deprecated
        # is_categorical_dtype(); ``.cat`` is only touched for categoricals.
        is_cat_dt = isinstance(
            df[col].dtype, pd.CategoricalDtype
        ) and is_datetime64_any_dtype(df[col].cat.categories)
        # Datetime values (plain or categorical) are wrapped in Timestamp
        # before being used in the expected group keys.
        if is_dt or is_cat_dt:
            col_vals[col] = [Timestamp(el) for el in col_vals[col]]

    # Row i holds the i-th unique value of every column, so zipping the
    # per-column value lists yields the expected key tuple for row i.
    it = zip(*(col_vals[col] for col in groupby_cols))
    target = {key: np.array([i]) for i, key in enumerate(it)}

    indices = df.groupby(groupby_cols).indices

    assert set(target) == set(indices), "unexpected group keys for {!r}".format(
        groupby_cols
    )
    for key, expected in target.items():
        # Public numpy comparison instead of the private
        # pd.core.dtypes.missing.array_equivalent helper.
        assert np.array_equal(
            expected, indices[key]
        ), "wrong indices for key {!r}".format(key)
466+
467+
364468
def test_attr_wrapper(ts):
365469
grouped = ts.groupby(lambda x: x.weekday())
366470

0 commit comments

Comments
 (0)