Skip to content

Commit 6421b0b

Browse files
committed
Handle categorical values
Fixes old errors on production
1 parent 899aad1 commit 6421b0b

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

groupby.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,16 @@ def groupby(
294294
# (hopefully) the least computationally-intense.
295295
agg_sets = "size"
296296

297+
# Got categoricals? Order the categories, so min/max work
298+
category_colnames = {
299+
agg.colname
300+
for agg in aggregations
301+
if agg.operation in {Operation.MIN, Operation.MAX}
302+
and hasattr(table[colname], "cat")
303+
}
304+
for colname in category_colnames:
305+
table[colname] = table[colname].cat.as_ordered()
306+
297307
if group_specs:
298308
# aggs: DataFrame indexed by group
299309
# out: just the group colnames, no values yet (we'll add them later)
@@ -336,6 +346,14 @@ def groupby(
336346
except AttributeError:
337347
out[outname] = series
338348

349+
# Remember those category colnames we converted to ordered? Now we need to
350+
# undo that (and remove newly-unused categories).
351+
for colname in out.columns:
352+
column = out[colname]
353+
if hasattr(column, "cat") and column.cat.ordered:
354+
column.cat.remove_unused_categories(inplace=True)
355+
column.cat.as_unordered(inplace=True)
356+
339357
return out
340358

341359

test_groupby.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def test_aggregate_numbers(self):
341341
),
342342
)
343343

344-
def test_aggregate_strings(self):
344+
def test_aggregate_text_values(self):
345345
result = groupby(
346346
pd.DataFrame({"A": [1, 1, 1], "B": ["a", "b", "a"]}),
347347
[Group("A", None)],
@@ -367,6 +367,60 @@ def test_aggregate_strings(self):
367367
),
368368
)
369369

370+
def test_aggregate_text_category_values(self):
371+
result = groupby(
372+
pd.DataFrame(
373+
{"A": [1, 1, 1], "B": pd.Series(["a", "b", "a"], dtype="category")}
374+
),
375+
[Group("A", None)],
376+
[
377+
Aggregation(Operation.SIZE, "B", "size"),
378+
Aggregation(Operation.NUNIQUE, "B", "nunique"),
379+
Aggregation(Operation.MIN, "B", "min"),
380+
Aggregation(Operation.MAX, "B", "max"),
381+
Aggregation(Operation.FIRST, "B", "first"),
382+
],
383+
)
384+
assert_frame_equal(
385+
result,
386+
pd.DataFrame(
387+
{
388+
"A": [1],
389+
"size": [3],
390+
"nunique": [2],
391+
"min": pd.Series(["a"], dtype="category"),
392+
"max": pd.Series(["b"], dtype="category"),
393+
"first": pd.Series(["a"], dtype="category"),
394+
}
395+
),
396+
)
397+
398+
def test_aggregate_text_category_values_empty_still_has_object_dtype(self):
399+
result = groupby(
400+
pd.DataFrame({"A": [None]}, dtype=str).astype("category"),
401+
[Group("A", None)],
402+
[
403+
Aggregation(Operation.SIZE, "A", "size"),
404+
Aggregation(Operation.NUNIQUE, "A", "nunique"),
405+
Aggregation(Operation.MIN, "A", "min"),
406+
Aggregation(Operation.MAX, "A", "max"),
407+
Aggregation(Operation.FIRST, "A", "first"),
408+
],
409+
)
410+
assert_frame_equal(
411+
result,
412+
pd.DataFrame(
413+
{
414+
"A": pd.Series([], dtype=str).astype("category"),
415+
"size": pd.Series([], dtype=int),
416+
"nunique": pd.Series([], dtype=int),
417+
"min": pd.Series([], dtype=str).astype("category"),
418+
"max": pd.Series([], dtype=str).astype("category"),
419+
"first": pd.Series([], dtype=str).astype("category"),
420+
}
421+
),
422+
)
423+
370424
def test_aggregate_datetime_no_granularity(self):
371425
result = groupby(
372426
pd.DataFrame({"A": [dt(2018, 1, 4), dt(2018, 1, 5), dt(2018, 1, 4)]}),

0 commit comments

Comments
 (0)