Skip to content

Commit 049bd03

Browse files
jbrockmendelproost
authored andcommitted
REF: avoid result=None case in _python_agg_general (pandas-dev#29499)
1 parent 62ee69e commit 049bd03

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

pandas/core/groupby/groupby.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,10 @@ def _python_agg_general(self, func, *args, **kwargs):
898898
# iterate through "columns" ex exclusions to populate output dict
899899
output = {}
900900
for name, obj in self._iterate_slices():
901+
if self.grouper.ngroups == 0:
902+
# agg_series below assumes ngroups > 0
903+
continue
904+
901905
try:
902906
# if this function is invalid for this dtype, we will ignore it.
903907
func(obj[:0])
@@ -911,10 +915,8 @@ def _python_agg_general(self, func, *args, **kwargs):
911915
pass
912916

913917
result, counts = self.grouper.agg_series(obj, f)
914-
if result is not None:
915-
# TODO: only 3 test cases get None here, do something
916-
# in those cases
917-
output[name] = self._try_cast(result, obj, numeric_only=True)
918+
assert result is not None
919+
output[name] = self._try_cast(result, obj, numeric_only=True)
918920

919921
if len(output) == 0:
920922
return self._python_apply_general(f)

pandas/core/groupby/ops.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,9 @@ def _transform(
601601
return result
602602

603603
def agg_series(self, obj: Series, func):
604+
# Caller is responsible for checking ngroups != 0
605+
assert self.ngroups != 0
606+
604607
if is_extension_array_dtype(obj.dtype) and obj.dtype.kind != "M":
605608
# _aggregate_series_fast would raise TypeError when
606609
# calling libreduction.Slider
@@ -626,8 +629,10 @@ def agg_series(self, obj: Series, func):
626629
return self._aggregate_series_pure_python(obj, func)
627630

628631
def _aggregate_series_fast(self, obj, func):
629-
# At this point we have already checked that obj.index is not a MultiIndex
630-
# and that obj is backed by an ndarray, not ExtensionArray
632+
# At this point we have already checked that
633+
# - obj.index is not a MultiIndex
634+
# - obj is backed by an ndarray, not ExtensionArray
635+
# - ngroups != 0
631636
func = self._is_builtin_func(func)
632637

633638
group_index, _, ngroups = self.group_info
@@ -660,11 +665,9 @@ def _aggregate_series_pure_python(self, obj, func):
660665
counts[label] = group.shape[0]
661666
result[label] = res
662667

663-
if result is not None:
664-
# if splitter is empty, result can be None, in which case
665-
# maybe_convert_objects would raise TypeError
666-
result = lib.maybe_convert_objects(result, try_float=0)
667-
# TODO: try_cast back to EA?
668+
assert result is not None
669+
result = lib.maybe_convert_objects(result, try_float=0)
670+
# TODO: try_cast back to EA?
668671

669672
return result, counts
670673

@@ -815,6 +818,9 @@ def groupings(self):
815818
]
816819

817820
def agg_series(self, obj: Series, func):
821+
# Caller is responsible for checking ngroups != 0
822+
assert self.ngroups != 0
823+
818824
if is_extension_array_dtype(obj.dtype):
819825
# pre-empty SeriesBinGrouper from raising TypeError
820826
# TODO: watch out, this can return None

0 commit comments

Comments
 (0)