diff --git a/pandas/core/base.py b/pandas/core/base.py index 8c8037091559d..f236fea93278c 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -634,20 +634,19 @@ def _is_builtin_func(self, arg): class ShallowMixin: _attributes = [] # type: List[str] - def _shallow_copy(self, obj=None, obj_type=None, **kwargs): + def _shallow_copy(self, obj=None, **kwargs): """ return a new object with the replacement attributes """ if obj is None: obj = self._selected_obj.copy() - if obj_type is None: - obj_type = self._constructor - if isinstance(obj, obj_type): + + if isinstance(obj, self._constructor): obj = obj.obj for attr in self._attributes: if attr not in kwargs: kwargs[attr] = getattr(self, attr) - return obj_type(obj, **kwargs) + return self._constructor(obj, **kwargs) class IndexOpsMixin: diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index fc3bb69afd0cb..fed387cbeade4 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -11,22 +11,6 @@ class GroupByMixin: Provide the groupby facilities to the mixed object. """ - @staticmethod - def _dispatch(name, *args, **kwargs): - """ - Dispatch to apply. - """ - - def outer(self, *args, **kwargs): - def f(x): - x = self._shallow_copy(x, groupby=self._groupby) - return getattr(x, name)(*args, **kwargs) - - return self._groupby.apply(f) - - outer.__name__ = name - return outer - def _gotitem(self, key, ndim, subset=None): """ Sub-classes to define. Return a sliced object. diff --git a/pandas/core/window/common.py b/pandas/core/window/common.py index 2ad5a1eb6faed..3fd567f97edae 100644 --- a/pandas/core/window/common.py +++ b/pandas/core/window/common.py @@ -26,7 +26,23 @@ """ -class _GroupByMixin(GroupByMixin): +def _dispatch(name: str, *args, **kwargs): + """ + Dispatch to apply. + """ + + def outer(self, *args, **kwargs): + def f(x): + x = self._shallow_copy(x, groupby=self._groupby) + return getattr(x, name)(*args, **kwargs) + + return self._groupby.apply(f) + + outer.__name__ = name + return outer + + +class WindowGroupByMixin(GroupByMixin): """ Provide the groupby facilities. """ @@ -41,9 +57,9 @@ def __init__(self, obj, *args, **kwargs): self._groupby.grouper.mutated = True super().__init__(obj, *args, **kwargs) - count = GroupByMixin._dispatch("count") - corr = GroupByMixin._dispatch("corr", other=None, pairwise=None) - cov = GroupByMixin._dispatch("cov", other=None, pairwise=None) + count = _dispatch("count") + corr = _dispatch("corr", other=None, pairwise=None) + cov = _dispatch("cov", other=None, pairwise=None) def _apply( self, func, name=None, window=None, center=None, check_minp=None, **kwargs @@ -53,6 +69,7 @@ def _apply( performing the original function call on the grouped object. """ + # TODO: can we de-duplicate with _dispatch? def f(x, name=name, *args): x = self._shallow_copy(x) diff --git a/pandas/core/window/expanding.py b/pandas/core/window/expanding.py index 47bd8f2ec593b..55389d2fc7d9f 100644 --- a/pandas/core/window/expanding.py +++ b/pandas/core/window/expanding.py @@ -3,7 +3,7 @@ from pandas.compat.numpy import function as nv from pandas.util._decorators import Appender, Substitution -from pandas.core.window.common import _doc_template, _GroupByMixin, _shared_docs +from pandas.core.window.common import WindowGroupByMixin, _doc_template, _shared_docs from pandas.core.window.rolling import _Rolling_and_Expanding @@ -250,7 +250,7 @@ def corr(self, other=None, pairwise=None, **kwargs): return super().corr(other=other, pairwise=pairwise, **kwargs) -class ExpandingGroupby(_GroupByMixin, Expanding): +class ExpandingGroupby(WindowGroupByMixin, Expanding): """ Provide a expanding groupby implementation. """ diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index bf5ea9c457e8a..75f9a1c628d72 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -39,9 +39,9 @@ import pandas.core.common as com from pandas.core.index import Index, ensure_index from pandas.core.window.common import ( + WindowGroupByMixin, _doc_template, _flex_binary_moment, - _GroupByMixin, _offset, _require_min_periods, _shared_docs, @@ -1917,7 +1917,7 @@ def corr(self, other=None, pairwise=None, **kwargs): Rolling.__doc__ = Window.__doc__ -class RollingGroupby(_GroupByMixin, Rolling): +class RollingGroupby(WindowGroupByMixin, Rolling): """ Provide a rolling groupby implementation. """