From 26e90846fab555d02deade0dbc172eb4fab53835 Mon Sep 17 00:00:00 2001 From: Alex Papanicolaou Date: Thu, 3 Dec 2020 11:51:10 -0800 Subject: [PATCH] test groupby.indices for multiple groupby and mix of types Creates tests for GH26859 --- pandas/tests/groupby/test_groupby.py | 104 +++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 7c179a79513fa..070976ef5a8b1 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -7,6 +7,8 @@ from pandas.errors import PerformanceWarning +from pandas.core.dtypes.common import is_categorical_dtype, is_datetime64_any_dtype + import pandas as pd from pandas import ( DataFrame, @@ -361,6 +363,108 @@ def f3(x): df2.groupby("a").apply(f3) +def test_single_groupby_indices_output(): + cols = [ + "int", + "int_cat", + "float", + "float_cat", + "dt", + "dt_cat", + "dttz", + "dttz_cat", + "period", + "period_cat", + ] + + int_ = Series([1, 2, 3]) + dt_ = pd.to_datetime(["2018Q1", "2018Q2", "2018Q3"]) + dttz_ = dt_.tz_localize("Europe/Berlin") + df = DataFrame( + data={ + "int": int_, + "int_cat": int_.astype("category"), + "float": int_.astype("float"), + "float_cat": int_.astype("float").astype("category"), + "dt": dt_, + "dt_cat": dt_.astype("category"), + "dttz": dttz_, + "dttz_cat": dttz_.astype("category"), + "period": dt_.to_period("Q"), + "period_cat": dt_.to_period("Q").astype("category"), + }, + columns=cols, + ) + for col in df.columns: + col_vals = list(df[col].unique()) + + if is_datetime64_any_dtype(df[col]): + col_vals = [Timestamp(el) for el in col_vals] + + target = {key: np.array([i]) for i, key in enumerate(col_vals)} + + indices = df.groupby(col).indices + + assert set(target.keys()) == set(indices.keys()) + for key in target.keys(): + assert pd.core.dtypes.missing.array_equivalent(target[key], indices[key]) + + +def test_multiple_groupby_indices_output(): + cols = [ + "int", + "int_cat", + "float", + "float_cat", + "dt", + "dt_cat", + "dttz", + "dttz_cat", + "period", + "period_cat", + "value", + ] + + int_ = Series([1, 2, 3]) + dt_ = pd.to_datetime(["2018Q1", "2018Q2", "2018Q3"]) + dttz_ = dt_.tz_localize("Europe/Berlin") + df = DataFrame( + data={ + "int": int_, + "int_cat": int_.astype("category"), + "float": int_.astype("float"), + "float_cat": int_.astype("float").astype("category"), + "dt": dt_, + "dt_cat": dt_.astype("category"), + "dttz": dttz_, + "dttz_cat": dttz_.astype("category"), + "period": dt_.to_period("Q"), + "period_cat": dt_.to_period("Q").astype("category"), + "value": Series([1, 2, 3]), + }, + columns=cols, + ) + groupby_cols = cols[:-1] + col_vals = {col: list(df[col].unique()) for col in groupby_cols} + + for col in groupby_cols: + is_dt = is_datetime64_any_dtype(df[col]) + is_cat_dt = is_categorical_dtype(df[col]) and is_datetime64_any_dtype( + df[col].cat.categories + ) + if is_dt or is_cat_dt: + col_vals[col] = [Timestamp(el) for el in col_vals[col]] + + it = zip(*(col_vals[col] for col in groupby_cols)) + target = {key: np.array([i]) for i, key in enumerate(it)} + + indices = df.groupby(groupby_cols).indices + + assert set(target.keys()) == set(indices.keys()) + for key in target.keys(): + assert pd.core.dtypes.missing.array_equivalent(target[key], indices[key]) + + def test_attr_wrapper(ts): grouped = ts.groupby(lambda x: x.weekday())