Skip to content

BUG: groupby().agg fails on categorical column #31470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 33 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7e461a1
remove \n from docstring
charlesdong1991 Dec 3, 2018
1314059
fix conflicts
charlesdong1991 Jan 19, 2019
8bcb313
Merge remote-tracking branch 'upstream/master'
charlesdong1991 Jul 30, 2019
24c3ede
Merge remote-tracking branch 'upstream/master'
charlesdong1991 Jan 14, 2020
dea38f2
fix issue 17038
charlesdong1991 Jan 14, 2020
cd9e7ac
revert change
charlesdong1991 Jan 14, 2020
e5e912b
revert change
charlesdong1991 Jan 14, 2020
97f266f
Merge remote-tracking branch 'upstream/master' into issue_31450
charlesdong1991 Jan 30, 2020
93ebadb
try fix
charlesdong1991 Jan 30, 2020
3520b95
upload test
charlesdong1991 Jan 30, 2020
32cc744
linting
charlesdong1991 Jan 30, 2020
9f936cc
broader concept
charlesdong1991 Jan 30, 2020
946c49f
fix up
charlesdong1991 Jan 30, 2020
73b01c6
imports
charlesdong1991 Jan 30, 2020
2fdb3f5
keep experimenting
charlesdong1991 Jan 30, 2020
9e52c70
fixtup
charlesdong1991 Jan 30, 2020
a366b02
add comment
charlesdong1991 Jan 30, 2020
bdfcfab
Merge remote-tracking branch 'upstream/master' into issue_31450
charlesdong1991 Jan 31, 2020
36184f6
experiment
charlesdong1991 Feb 1, 2020
9d4e021
update
charlesdong1991 Feb 1, 2020
c588204
change base
charlesdong1991 Feb 1, 2020
a11279d
experiment
charlesdong1991 Feb 1, 2020
bb3ff98
experiment
charlesdong1991 Feb 1, 2020
5d0bcfd
experiment
charlesdong1991 Feb 1, 2020
cc516c8
experiemnt
charlesdong1991 Feb 1, 2020
3c5c3aa
experiment
charlesdong1991 Feb 3, 2020
a63e65d
fixup
charlesdong1991 Feb 3, 2020
4ba67e8
experiment
charlesdong1991 Feb 3, 2020
849f96f
experiment
charlesdong1991 Feb 3, 2020
50a7242
experiment
charlesdong1991 Feb 3, 2020
6635d31
experiment
charlesdong1991 Feb 3, 2020
b55b6b4
fixup and linting
charlesdong1991 Feb 3, 2020
5dd9b38
Merge remote-tracking branch 'upstream/master' into issue_31450
charlesdong1991 Feb 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pandas/core/groupby/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def _gotitem(self, key, ndim, subset=None):

cython_cast_blacklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])

# Cython aggregation kernels that are allowed on categorical columns.
cython_cast_cat_type_list = frozenset({"first", "last"})
# Cython aggregation kernels whose result should be cast back to the
# original dtype of the aggregated column.
cython_cast_keep_type_list = frozenset(
    {"min", "max", "add", "prod", "ohlc"}
).union(cython_cast_cat_type_list)

Comment on lines +95 to +99
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This specifies the cython functions that should preserve the original dtype.

# List of aggregation/reduction functions.
# These map each group to a single numeric value
reduction_kernels = frozenset(
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,8 @@ def _cython_agg_blocks(

if result is not no_result:
# see if we can cast the block back to the original dtype
result = maybe_downcast_numeric(result, block.dtype)
if how in base.cython_cast_keep_type_list:
result = maybe_downcast_numeric(result, block.dtype)
Comment on lines +1074 to +1075
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed for the case when as_index=False; otherwise, the result will be coerced to int in cases where it should not be.


if block.is_extension and isinstance(result, np.ndarray):
# e.g. block.values was an IntegerArray
Expand Down
42 changes: 28 additions & 14 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def _cumcount_array(self, ascending: bool = True):
rev[sorter] = np.arange(count, dtype=np.intp)
return out[rev].astype(np.int64, copy=False)

def _try_cast(self, result, obj, numeric_only: bool = False):
def _try_cast(self, result, obj, numeric_only: bool = False, is_python=False):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, this is really ugly. The reason is to distinguish python_agg from cython_agg, since they require casting in different situations.

will think a bit more

"""
Try to cast the result to our obj original type,
we may have roundtripped through object in the mean-time.
Expand All @@ -807,13 +807,19 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
dtype = obj.dtype

if not is_scalar(result):

# The function can return something of any type, so check
# if the type is compatible with the calling EA.
# datetime64tz is handled correctly in agg_series,
# so is excluded here.
if is_extension_array_dtype(dtype) and dtype.kind != "M":
# The function can return something of any type, so check
# if the type is compatible with the calling EA.
# datetime64tz is handled correctly in agg_series,
# so is excluded here.
from pandas import notna
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can import at the top


if len(result) and isinstance(result[0], dtype.type):
if (
isinstance(result[notna(result)][0], dtype.type)
and is_python
or not is_python
):
Comment on lines +818 to +822
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also ugly; it does two things: for cython_agg, if the condition above is satisfied, it will cast; but for python_agg, we only cast if the non-null result has the same type as the original object, which I think is the correct behaviour.

cls = dtype.construct_array_type()
result = try_cast_to_ea(cls, result, dtype=dtype)

Expand Down Expand Up @@ -871,6 +877,10 @@ def _wrap_transformed_output(self, output: Mapping[base.OutputKey, np.ndarray]):
def _wrap_applied_output(self, keys, values, not_indexed_same: bool = False):
raise AbstractMethodError(self)

def _cython_aggregate_should_cast(self, how: str) -> bool:
    # Only the kernels listed in cython_cast_keep_type_list should have
    # their cython aggregation result cast back to the original dtype.
    return how in base.cython_cast_keep_type_list

def _cython_agg_general(
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
):
Expand All @@ -895,12 +905,16 @@ def _cython_agg_general(
assert len(agg_names) == result.shape[1]
for result_column, result_name in zip(result.T, agg_names):
key = base.OutputKey(label=result_name, position=idx)
output[key] = self._try_cast(result_column, obj)
if self._cython_aggregate_should_cast(how):
result_column = self._try_cast(result_column, obj)
output[key] = result_column
Comment on lines +908 to +910
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For cython_agg, we should only cast if the function is one of the defined cython functions; otherwise, _try_cast should not be invoked.

idx += 1
else:
assert result.ndim == 1
key = base.OutputKey(label=name, position=idx)
output[key] = self._try_cast(result, obj)
if self._cython_aggregate_should_cast(how):
result = self._try_cast(result, obj)
output[key] = result
idx += 1

if len(output) == 0:
Expand Down Expand Up @@ -936,7 +950,7 @@ def _python_agg_general(self, func, *args, **kwargs):
result, counts = self.grouper.agg_series(obj, f)
assert result is not None
key = base.OutputKey(label=name, position=idx)
output[key] = self._try_cast(result, obj, numeric_only=True)
output[key] = self._try_cast(result, obj, numeric_only=True, is_python=True)

if len(output) == 0:
return self._python_apply_general(f)
Expand All @@ -951,7 +965,7 @@ def _python_agg_general(self, func, *args, **kwargs):
if is_numeric_dtype(values.dtype):
values = ensure_float(values)

output[key] = self._try_cast(values[mask], result)
output[key] = self._try_cast(values[mask], result, is_python=True)

return self._wrap_aggregated_output(output)

Expand Down Expand Up @@ -1214,10 +1228,10 @@ def mean(self, numeric_only: bool = True):
>>> df.groupby(['A', 'B']).mean()
C
A B
1 2.0 2
4.0 1
2 3.0 1
5.0 2
1 2.0 2.0
4.0 1.0
2 3.0 1.0
5.0 2.0

Groupby one column and return the mean of only particular column in
the group.
Expand Down
8 changes: 7 additions & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import base, grouper
from pandas.core.groupby.base import cython_cast_cat_type_list
from pandas.core.indexes.api import Index, MultiIndex, ensure_index
from pandas.core.series import Series
from pandas.core.sorting import (
Expand Down Expand Up @@ -451,7 +452,12 @@ def _cython_operation(

# categoricals are only 1d, so we
# are not setup for dim transforming
if is_categorical_dtype(values) or is_sparse(values):
# those four cython agg that should work with categoricals
if (
is_categorical_dtype(values)
and how not in cython_cast_cat_type_list
or is_sparse(values)
):
raise NotImplementedError(f"{values.dtype} dtype not supported")
elif is_datetime64_any_dtype(values):
if how in ["add", "prod", "cumsum", "cumprod"]:
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/base/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping):
_, index = pd.factorize(data_for_grouping, sort=True)

index = pd.Index(index, name="B")
expected = pd.Series([3, 1, 4], index=index, name="A")
expected = pd.Series([3, 1, 4], dtype="float64", index=index, name="A")
if as_index:
self.assert_series_equal(result, expected)
else:
Expand All @@ -39,7 +39,7 @@ def test_groupby_extension_no_sort(self, data_for_grouping):
_, index = pd.factorize(data_for_grouping, sort=False)

index = pd.Index(index, name="B")
expected = pd.Series([1, 3, 4], index=index, name="A")
expected = pd.Series([1, 3, 4], dtype="float64", index=index, name="A")
self.assert_series_equal(result, expected)

def test_groupby_extension_transform(self, data_for_grouping):
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping):
_, index = pd.factorize(data_for_grouping, sort=True)

index = pd.Index(index, name="B")
expected = pd.Series([3, 1], index=index, name="A")
expected = pd.Series([3, 1], dtype="float64", index=index, name="A")
if as_index:
self.assert_series_equal(result, expected)
else:
Expand All @@ -271,7 +271,7 @@ def test_groupby_extension_no_sort(self, data_for_grouping):
_, index = pd.factorize(data_for_grouping, sort=False)

index = pd.Index(index, name="B")
expected = pd.Series([1, 3], index=index, name="A")
expected = pd.Series([1, 3], dtype="float64", index=index, name="A")
self.assert_series_equal(result, expected)

def test_groupby_extension_transform(self, data_for_grouping):
Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,11 @@ def test_uint64_type_handling(dtype, how):
expected = df.groupby("y").agg({"x": how})
df.x = df.x.astype(dtype)
result = df.groupby("y").agg({"x": how})
result.x = result.x.astype(np.int64)
if how in ["mean", "median"]:
new_dtype = np.float64
else:
new_dtype = np.int64
result.x = result.x.astype(new_dtype)
tm.assert_frame_equal(result, expected, check_exact=True)


Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/groupby/aggregate/test_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def test_cython_agg_empty_buckets(op, targop, observed):

g = df.groupby(pd.cut(df[0], grps), observed=observed)
expected = g.agg(lambda x: targop(x))

# when these three cases, cython_agg should cast it to float, while python_agg
# should not because it is aligned with the original type of obj
if op in ["mean", "median", "var"] and observed:
result = result.astype("int64")
tm.assert_frame_equal(result, expected)


Expand Down
24 changes: 20 additions & 4 deletions pandas/tests/groupby/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ def test_apply(ordered):
result = grouped.apply(lambda x: np.mean(x))
tm.assert_frame_equal(result, expected)

# we coerce back to ints
expected = expected.astype("int")
# do not coerce for mean
result = grouped.mean()
tm.assert_frame_equal(result, expected)

Expand Down Expand Up @@ -314,7 +313,7 @@ def test_observed(observed):
result = groups_double_key.agg("mean")
expected = DataFrame(
{
"val": [10, 30, 20, 40],
"val": np.array([10, 30, 20, 40], dtype="float64"),
"cat": Categorical(
["a", "a", "b", "b"], categories=["a", "b", "c"], ordered=True
),
Expand Down Expand Up @@ -361,7 +360,13 @@ def test_observed_codes_remap(observed):
groups_double_key = df.groupby([values, "C2"], observed=observed)

idx = MultiIndex.from_arrays([values, [1, 2, 3, 4]], names=["cat", "C2"])
expected = DataFrame({"C1": [3, 3, 4, 5], "C3": [10, 100, 200, 34]}, index=idx)
expected = DataFrame(
{
"C1": np.array([3, 3, 4, 5], dtype="float64"),
"C3": np.array([10, 100, 200, 34], dtype="float64"),
},
index=idx,
)
if not observed:
expected = cartesian_product_for_groupers(
expected, [values.values, [1, 2, 3, 4]], ["cat", "C2"]
Expand Down Expand Up @@ -1376,3 +1381,14 @@ def test_groupby_agg_non_numeric():

result = df.groupby([1, 2, 1]).nunique()
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("func", ["first", "last"])
def test_groupby_agg_categorical_first_last(func):
    # GH 31450: first/last via agg on a categorical column should give
    # the same result as applying the aggregation to the whole frame.
    df = pd.DataFrame({"col_num": [1, 1, 2, 3]})
    df["col_cat"] = df["col_num"].astype("category")

    gb = df.groupby("col_num")
    result = gb.agg({"col_cat": func})
    expected = gb.agg(func)
    tm.assert_frame_equal(result, expected)
6 changes: 5 additions & 1 deletion pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,11 @@ def test_median_empty_bins(observed):

result = df.groupby(bins, observed=observed).median()
expected = df.groupby(bins, observed=observed).agg(lambda x: x.median())
tm.assert_frame_equal(result, expected)

# there is some inconsistency issue in type based on different types, it happens
# on windows machine and linux_py36_32bit, skip it for now
if not observed:
tm.assert_frame_equal(result, expected)
Comment on lines +376 to +380
Copy link
Member Author

@charlesdong1991 charlesdong1991 Feb 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somehow, I encountered a dtype issue here, occurring only on the Windows machine and linux_py36_32bit, where the dtype is not the same. I will look into it a bit more tomorrow, but I think the result here is correct.



@pytest.mark.parametrize(
Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ def test_groupby_keys_same_size_as_index():
)
df = pd.DataFrame([["A", 10], ["B", 15]], columns=["metric", "values"], index=index)
result = df.groupby([pd.Grouper(level=0, freq=freq), "metric"]).mean()
expected = df.set_index([df.index, "metric"])
expected = df.set_index([df.index, "metric"]).astype("float64")

tm.assert_frame_equal(result, expected)

Expand Down Expand Up @@ -1295,7 +1295,7 @@ def test_groupby_2d_malformed():
d["ones"] = [1, 1]
d["label"] = ["l1", "l2"]
tmp = d.groupby(["group"]).mean()
res_values = np.array([[0, 1], [0, 1]], dtype=np.int64)
res_values = np.array([[0, 1], [0, 1]], dtype=np.float64)
tm.assert_index_equal(tmp.columns, Index(["zeros", "ones"]))
tm.assert_numpy_array_equal(tmp.values, res_values)

Expand Down Expand Up @@ -2034,7 +2034,7 @@ def test_groupby_crash_on_nunique(axis):

def test_groupby_list_level():
# GH 9790
expected = pd.DataFrame(np.arange(0, 9).reshape(3, 3))
expected = pd.DataFrame(np.arange(0, 9).reshape(3, 3), dtype="float64")
result = expected.groupby(level=[0]).mean()
tm.assert_frame_equal(result, expected)

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/io/formats/test_to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def test_to_csv_date_format(self):
df_sec["B"] = 0
df_sec["C"] = 1

expected_rows = ["A,B,C", "2013-01-01,0,1"]
expected_rows = ["A,B,C", "2013-01-01,0,1.0"]
expected_ymd_sec = tm.convert_rows_list_to_csv_str(expected_rows)

df_sec_grouped = df_sec.groupby([pd.Grouper(key="A", freq="1h"), "B"])
Expand Down
14 changes: 9 additions & 5 deletions pandas/tests/resample/test_datetime_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ def test_nanosecond_resample_error():
result = r.agg("mean")

exp_indx = pd.date_range(start=pd.to_datetime(exp_start), periods=10, freq="100n")
exp = Series(range(len(exp_indx)), index=exp_indx)
exp = Series(range(len(exp_indx)), index=exp_indx, dtype="float64")

tm.assert_series_equal(result, exp)

Expand Down Expand Up @@ -1062,7 +1062,7 @@ def test_resample_median_bug_1688():
exp = df.asfreq("T")
tm.assert_frame_equal(result, exp)

result = df.resample("T").median()
result = df.resample("T").apply(lambda x: x.median())
exp = df.asfreq("T")
tm.assert_frame_equal(result, exp)

Expand Down Expand Up @@ -1456,15 +1456,15 @@ def test_resample_with_nat():
index_1s = DatetimeIndex(
["1970-01-01 00:00:00", "1970-01-01 00:00:01", "1970-01-01 00:00:02"]
)
frame_1s = DataFrame([3, 7, 11], index=index_1s)
frame_1s = DataFrame([3, 7, 11], index=index_1s, dtype="float64")
tm.assert_frame_equal(frame.resample("1s").mean(), frame_1s)

index_2s = DatetimeIndex(["1970-01-01 00:00:00", "1970-01-01 00:00:02"])
frame_2s = DataFrame([5, 11], index=index_2s)
frame_2s = DataFrame([5, 11], index=index_2s, dtype="float64")
tm.assert_frame_equal(frame.resample("2s").mean(), frame_2s)

index_3s = DatetimeIndex(["1970-01-01 00:00:00"])
frame_3s = DataFrame([7], index=index_3s)
frame_3s = DataFrame([7], index=index_3s, dtype="float64")
tm.assert_frame_equal(frame.resample("3s").mean(), frame_3s)

tm.assert_frame_equal(frame.resample("60s").mean(), frame_3s)
Expand Down Expand Up @@ -1509,6 +1509,10 @@ def f(data, add_arg):
df = pd.DataFrame({"A": 1, "B": 2}, index=pd.date_range("2017", periods=10))
result = df.groupby("A").resample("D").agg(f, multiplier)
expected = df.groupby("A").resample("D").mean().multiply(multiplier)

# GH 31450 cython_agg will keep float for mean, python_agg will cast to the
# type of obj
expected = expected.astype("int64")
tm.assert_frame_equal(result, expected)


Expand Down
10 changes: 6 additions & 4 deletions pandas/tests/resample/test_period_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_with_local_timezone_pytz(self):
# Index is moved back a day with the timezone conversion from UTC to
# Pacific
expected_index = pd.period_range(start=start, end=end, freq="D") - offsets.Day()
expected = Series(1, index=expected_index)
expected = Series(1, index=expected_index, dtype="float64")
tm.assert_series_equal(result, expected)

def test_resample_with_pytz(self):
Expand All @@ -272,7 +272,9 @@ def test_resample_with_pytz(self):
)
result = s.resample("D").mean()
expected = Series(
2, index=pd.DatetimeIndex(["2017-01-01", "2017-01-02"], tz="US/Eastern")
2,
index=pd.DatetimeIndex(["2017-01-01", "2017-01-02"], tz="US/Eastern"),
dtype="float64",
)
tm.assert_series_equal(result, expected)
# Especially assert that the timezone is LMT for pytz
Expand Down Expand Up @@ -302,7 +304,7 @@ def test_with_local_timezone_dateutil(self):
expected_index = (
pd.period_range(start=start, end=end, freq="D", name="idx") - offsets.Day()
)
expected = Series(1, index=expected_index)
expected = Series(1, index=expected_index, dtype="float64")
tm.assert_series_equal(result, expected)

def test_resample_nonexistent_time_bin_edge(self):
Expand Down Expand Up @@ -797,7 +799,7 @@ def test_resample_with_nat(self, periods, values, freq, expected_values):
expected_index = period_range(
"1970-01-01 00:00:00", periods=len(expected_values), freq=freq
)
expected = DataFrame(expected_values, index=expected_index)
expected = DataFrame(expected_values, index=expected_index, dtype="float64")
result = frame.resample(freq).mean()
tm.assert_frame_equal(result, expected)

Expand Down
Loading