Skip to content

Commit 6ed42d7

Browse files
Fix type mismatch in groupby reduction for empty objects (#13942)
Closes #13941. This PR preserves column types for groupby reduction operations performed on empty objects. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Bradley Dice (https://github.com/bdice) URL: #13942
1 parent 171fc91 commit 6ed42d7

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

python/cudf/cudf/core/groupby/groupby.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,8 @@ def agg(self, func):
553553
orig_dtypes,
554554
):
555555
for agg, col in zip(aggs, cols):
556+
agg_name = agg.__name__ if callable(agg) else agg
556557
if multilevel:
557-
agg_name = agg.__name__ if callable(agg) else agg
558558
key = (col_name, agg_name)
559559
else:
560560
key = col_name
@@ -564,7 +564,26 @@ def agg(self, func):
564564
):
565565
# Structs lose their labels which we reconstruct here
566566
col = col._with_type_metadata(cudf.ListDtype(orig_dtype))
567-
data[key] = col
567+
568+
if (
569+
self.obj.empty
570+
and (
571+
isinstance(agg_name, str)
572+
and agg_name in Reducible._SUPPORTED_REDUCTIONS
573+
)
574+
and len(col) == 0
575+
and not isinstance(
576+
col,
577+
(
578+
cudf.core.column.ListColumn,
579+
cudf.core.column.StructColumn,
580+
cudf.core.column.DecimalBaseColumn,
581+
),
582+
)
583+
):
584+
data[key] = col.astype(orig_dtype)
585+
else:
586+
data[key] = col
568587
data = ColumnAccessor(data, multiindex=multilevel)
569588
if not multilevel:
570589
data = data.rename_levels({np.nan: None}, level=0)

python/cudf/cudf/tests/test_groupby.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,6 +3342,33 @@ def test_group_by_pandas_sort_order(groups, sort):
33423342
)
33433343

33443344

3345+
@pytest.mark.parametrize(
3346+
"dtype",
3347+
["int32", "int64", "float64", "datetime64[ns]", "timedelta64[ns]", "bool"],
3348+
)
3349+
@pytest.mark.parametrize(
3350+
"reduce_op",
3351+
[
3352+
"min",
3353+
"max",
3354+
"idxmin",
3355+
"idxmax",
3356+
"first",
3357+
"last",
3358+
],
3359+
)
3360+
def test_group_by_empty_reduction(dtype, reduce_op):
3361+
gdf = cudf.DataFrame({"a": [], "b": [], "c": []}, dtype=dtype)
3362+
pdf = gdf.to_pandas()
3363+
3364+
gg = gdf.groupby("a")["c"]
3365+
pg = pdf.groupby("a")["c"]
3366+
3367+
assert_eq(
3368+
getattr(gg, reduce_op)(), getattr(pg, reduce_op)(), check_dtype=True
3369+
)
3370+
3371+
33453372
@pytest.mark.parametrize(
33463373
"dtype",
33473374
["int32", "int64", "float64", "datetime64[ns]", "timedelta64[ns]", "bool"],
@@ -3357,6 +3384,7 @@ def test_group_by_empty_apply(request, dtype, apply_op):
33573384
reason=("sum isn't supported for datetime64[ns]"),
33583385
)
33593386
)
3387+
33603388
gdf = cudf.DataFrame({"a": [], "b": [], "c": []}, dtype=dtype)
33613389
pdf = gdf.to_pandas()
33623390

0 commit comments

Comments
 (0)