Skip to content

Commit b0d29d6

Browse files
committed
test groupby.indices for multiple groupby and mix of types
Creates tests for GH26859
1 parent 142ca08 commit b0d29d6

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

pandas/tests/groupby/test_groupby.py

+115
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from pandas.errors import PerformanceWarning
99

10+
from pandas.core.dtypes.common import is_categorical_dtype, is_datetime64_any_dtype
11+
1012
import pandas as pd
1113
from pandas import (
1214
DataFrame,
@@ -361,6 +363,119 @@ def f3(x):
361363
df2.groupby("a").apply(f3)
362364

363365

366+
def test_groupby_indices_error():
367+
# GH 26860
368+
# Test if DataFrame Groupby builds gb.indices
369+
dt = pd.to_datetime(["2018-01-01", "2018-02-01", "2018-03-01"])
370+
df = DataFrame(
371+
{
372+
"a": Series(list("abc")),
373+
"b": Series(dt, dtype="category"),
374+
"c": pd.Categorical.from_codes([-1, 0, 1], categories=[0, 1]),
375+
}
376+
)
377+
378+
df.groupby(["a", "b"]).indices
379+
380+
381+
@pytest.mark.parametrize(
382+
"gb_cols",
383+
[
384+
"int_series",
385+
"int_series_cat",
386+
"float_series",
387+
"float_series_cat",
388+
"dt_series",
389+
"dt_series_cat",
390+
"period_series",
391+
"period_series_cat",
392+
[
393+
"int_series",
394+
"int_series_cat",
395+
"float_series",
396+
"float_series_cat",
397+
"dt_series",
398+
"dt_series_cat",
399+
"period_series",
400+
"period_series_cat",
401+
],
402+
],
403+
)
404+
def test_groupby_indices_output(gb_cols):
405+
# GH 26860
406+
# Test if DataFrame Groupby builds gb.indices correctly.
407+
if isinstance(gb_cols, str):
408+
gb_cols = [gb_cols]
409+
410+
cols = [
411+
"int_series",
412+
"int_series_cat",
413+
"float_series",
414+
"float_series_cat",
415+
"dt_series",
416+
"dt_series_cat",
417+
"dttz_series",
418+
"dttz_series_cat",
419+
"period_series",
420+
"period_series_cat",
421+
]
422+
423+
int_series = Series([1, 2, 3])
424+
dt_series = pd.to_datetime(["2018Q1", "2018Q2", "2018Q3"])
425+
dttz_series = dt_series.tz_localize("Europe/Berlin")
426+
df = DataFrame(
427+
data={
428+
"int_series": int_series,
429+
"int_series_cat": int_series.astype("category"),
430+
"float_series": int_series.astype("float"),
431+
"float_series_cat": int_series.astype("float").astype("category"),
432+
"dt_series": dt_series,
433+
"dt_series_cat": dt_series.astype("category"),
434+
"dttz_series": dttz_series,
435+
"dttz_series_cat": dttz_series.astype("category"),
436+
"period_series": dt_series.to_period("Q"),
437+
"period_series_cat": dt_series.to_period("Q").astype("category"),
438+
},
439+
columns=cols,
440+
)
441+
442+
def dt_to_ts(elems):
443+
return [Timestamp(el) for el in elems]
444+
445+
def ts_to_dt(elems):
446+
return [el.to_datetime64() for el in elems]
447+
448+
num_gb_cols = len(gb_cols)
449+
450+
if num_gb_cols == 1:
451+
s = df[gb_cols[0]]
452+
col_vals = list(s.unique())
453+
454+
if is_datetime64_any_dtype(s):
455+
col_vals = dt_to_ts(col_vals)
456+
457+
target = {key: np.array([i]) for i, key in enumerate(col_vals)}
458+
else:
459+
col_vals = {col: list(df[col].unique()) for col in gb_cols}
460+
461+
for col in gb_cols:
462+
is_dt = is_datetime64_any_dtype(df[col])
463+
is_cat_dt = is_categorical_dtype(df[col]) and is_datetime64_any_dtype(
464+
df[col].cat.categories
465+
)
466+
if is_dt or is_cat_dt:
467+
col_vals[col] = dt_to_ts(col_vals[col])
468+
469+
it = zip(*(col_vals[col] for col in gb_cols))
470+
target = {key: np.array([i]) for i, key in enumerate(it)}
471+
472+
indices = df.groupby(gb_cols).indices
473+
474+
assert set(target.keys()) == set(indices.keys())
475+
for key in target.keys():
476+
assert pd.core.dtypes.missing.array_equivalent(target[key], indices[key])
477+
478+
364479
def test_attr_wrapper(ts):
365480
grouped = ts.groupby(lambda x: x.weekday())
366481

0 commit comments

Comments
 (0)