diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 9c5806a3fe945..6302499b6d153 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -1,6 +1,6 @@ import abc import inspect -from typing import TYPE_CHECKING, Iterator, Type +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Type, Union import numpy as np @@ -18,6 +18,8 @@ if TYPE_CHECKING: from pandas import DataFrame, Series, Index +ResType = Dict[int, Any] + def frame_apply( obj: "DataFrame", @@ -64,10 +66,15 @@ def result_index(self) -> "Index": def result_columns(self) -> "Index": pass + @property @abc.abstractmethod def series_generator(self) -> Iterator["Series"]: pass + @abc.abstractmethod + def wrap_results_for_axis(self, results: ResType) -> Union["Series", "DataFrame"]: + pass + # --------------------------------------------------------------- def __init__( @@ -107,8 +114,16 @@ def f(x): # results self.result = None - self.res_index = None - self.res_columns = None + self._res_index: Optional["Index"] = None + + @property + def res_index(self) -> "Index": + assert self._res_index is not None + return self._res_index + + @property + def res_columns(self) -> "Index": + return self.result_columns @property def columns(self) -> "Index": @@ -298,12 +313,12 @@ def apply_standard(self): return self.obj._constructor_sliced(result, index=labels) # compute the result using the series generator - self.apply_series_generator() + results = self.apply_series_generator() # wrap results - return self.wrap_results() + return self.wrap_results(results) - def apply_series_generator(self): + def apply_series_generator(self) -> ResType: series_gen = self.series_generator res_index = self.result_index @@ -330,17 +345,15 @@ def apply_series_generator(self): results[i] = self.f(v) keys.append(v.name) - self.results = results - self.res_index = res_index - self.res_columns = self.result_columns + self._res_index = res_index + return results - def wrap_results(self): - results = self.results + def wrap_results(self, results: ResType) -> Union["Series", "DataFrame"]: # see if we can infer the results if len(results) > 0 and 0 in results and is_sequence(results[0]): - return self.wrap_results_for_axis() + return self.wrap_results_for_axis(results) # dict of scalars result = self.obj._constructor_sliced(results) @@ -367,10 +380,9 @@ def result_index(self) -> "Index": def result_columns(self) -> "Index": return self.index - def wrap_results_for_axis(self): + def wrap_results_for_axis(self, results: ResType) -> "DataFrame": """ return the results for the rows """ - results = self.results result = self.obj._constructor(data=results) if not isinstance(results[0], ABCSeries): @@ -406,13 +418,13 @@ def result_index(self) -> "Index": def result_columns(self) -> "Index": return self.columns - def wrap_results_for_axis(self): + def wrap_results_for_axis(self, results: ResType) -> Union["Series", "DataFrame"]: """ return the results for the columns """ - results = self.results + result: Union["Series", "DataFrame"] # we have requested to expand if self.result_type == "expand": - result = self.infer_to_same_shape() + result = self.infer_to_same_shape(results) # we have a non-series and don't want inference elif not isinstance(results[0], ABCSeries): @@ -423,13 +435,12 @@ def wrap_results_for_axis(self): # we may want to infer results else: - result = self.infer_to_same_shape() + result = self.infer_to_same_shape(results) return result - def infer_to_same_shape(self) -> "DataFrame": + def infer_to_same_shape(self, results: ResType) -> "DataFrame": """ infer the results to the same shape as the input object """ - results = self.results result = self.obj._constructor(data=results) result = result.T