
FIX: use _selected_obj rather than obj throughout groupby #6570


Merged 2 commits on Mar 9, 2014
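In short: once a column selection has been made (e.g. `df.groupby('A')[['C']]`), every internal code path should operate on `self._selected_obj` rather than the full `self.obj`. A minimal sketch of the behaviour this PR pins down (the frame and column names here are invented for illustration):

```python
import pandas as pd
from pandas.util.testing import assert_frame_equal

df = pd.DataFrame({'A': ['x', 'x', 'y'],
                   'B': [10, 20, 30],
                   'C': [1.0, 2.0, 3.0]})

# These two spellings should be equivalent. Before this fix, whitelisted
# methods dispatched through _make_wrapper could see the full frame
# (including 'A' and 'B') instead of just the selected column.
res = df.groupby('A')[['C']].count()
exp = df[['C']].groupby(df['A']).count()

assert_frame_equal(res, exp)
```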
70 changes: 40 additions & 30 deletions pandas/core/groupby.py
@@ -208,6 +208,8 @@ class GroupBy(PandasObject):
     Number of groups
     """
     _apply_whitelist = _common_apply_whitelist
+    _internal_names = ['_cache']
+    _internal_names_set = set(_internal_names)
 
     def __init__(self, obj, keys=None, axis=0, level=None,
                  grouper=None, exclusions=None, selection=None, as_index=True,
@@ -288,10 +290,12 @@ def _local_dir(self):
         return sorted(set(self.obj._local_dir() + list(self._apply_whitelist)))
 
     def __getattr__(self, attr):
+        if attr in self._internal_names_set:
+            return object.__getattribute__(self, attr)
         if attr in self.obj:
             return self[attr]
 
-        if hasattr(self.obj, attr) and attr != '_cache':
+        if hasattr(self.obj, attr):
             return self._make_wrapper(attr)
 
         raise AttributeError("%r object has no attribute %r" %
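A note on the `_internal_names` change: `__getattr__` is only consulted when normal lookup fails, so any internal attribute that falls through to it must not be proxied to the held object. A condensed, hypothetical sketch of the lookup order (not the real class, which dispatches step 3 through `_make_wrapper`):

```python
import pandas as pd

class MiniGroupBy(object):
    # internal attributes that must never be delegated to the wrapped object
    _internal_names = ['_cache']
    _internal_names_set = set(_internal_names)

    def __init__(self, obj):
        self.obj = obj  # a DataFrame in the real GroupBy

    def __getattr__(self, attr):
        # 1. internal names resolve (or fail) via plain object lookup
        if attr in self._internal_names_set:
            return object.__getattribute__(self, attr)
        # 2. column names: g.C is sugar for g['C']
        if attr in self.obj:
            return self.obj[attr]
        # 3. anything else the wrapped object has (stand-in for _make_wrapper)
        if hasattr(self.obj, attr):
            return getattr(self.obj, attr)
        raise AttributeError("%r object has no attribute %r" %
                             (type(self).__name__, attr))

g = MiniGroupBy(pd.DataFrame({'C': [1, 2]}))
g.C  # resolves via step 2 to the 'C' column
```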
@@ -302,18 +306,18 @@ def __getitem__(self, key):
 
     def _make_wrapper(self, name):
         if name not in self._apply_whitelist:
-            is_callable = callable(getattr(self.obj, name, None))
+            is_callable = callable(getattr(self._selected_obj, name, None))
             kind = ' callable ' if is_callable else ' '
             msg = ("Cannot access{0}attribute {1!r} of {2!r} objects, try "
                    "using the 'apply' method".format(kind, name,
                                                      type(self).__name__))
             raise AttributeError(msg)
 
-        f = getattr(self.obj, name)
+        f = getattr(self._selected_obj, name)
         if not isinstance(f, types.MethodType):
             return self.apply(lambda self: getattr(self, name))
 
-        f = getattr(type(self.obj), name)
+        f = getattr(type(self._selected_obj), name)
 
         def wrapper(*args, **kwargs):
             # a little trickery for aggregation functions that need an axis
@@ -362,7 +366,7 @@ def get_group(self, name, obj=None):
         group : type of obj
         """
         if obj is None:
-            obj = self.obj
+            obj = self._selected_obj
 
         inds = self._get_index(name)
         return obj.take(inds, axis=self.axis, convert=False)
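With the default now `_selected_obj`, `get_group` on a selected groupby returns only the selected columns. An illustrative call (data made up, not from the test suite):

```python
import pandas as pd

df = pd.DataFrame({'A': ['x', 'x', 'y'], 'C': [1.0, 2.0, 3.0]})
g = df.groupby('A')[['C']]

g.get_group('x')  # a frame with just column 'C', not 'A' and 'C'
```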
@@ -424,7 +428,8 @@ def f(g):
         return self._python_apply_general(f)
 
     def _python_apply_general(self, f):
-        keys, values, mutated = self.grouper.apply(f, self.obj, self.axis)
+        keys, values, mutated = self.grouper.apply(f, self._selected_obj,
+                                                   self.axis)
 
         return self._wrap_applied_output(keys, values,
                                          not_indexed_same=mutated)
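Since `grouper.apply` now receives `_selected_obj`, `apply` on a selection hands the function only the selected columns, which is what the new test asserts via `g.apply(lambda x: x.sum())`. A hedged example with invented data:

```python
import pandas as pd

df = pd.DataFrame({'A': ['x', 'x', 'y'],
                   'B': [10, 20, 30],
                   'C': [1.0, 2.0, 3.0]})

# each group passed to the lambda is df[['C']] restricted to that group's
# rows, not the full frame with 'A' and 'B' still attached
df.groupby('A')[['C']].apply(lambda x: x.sum())
```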
@@ -437,7 +442,7 @@ def agg(self, func, *args, **kwargs):
         return self.aggregate(func, *args, **kwargs)
 
     def _iterate_slices(self):
-        yield self.name, self.obj
+        yield self.name, self._selected_obj
 
     def transform(self, func, *args, **kwargs):
         raise NotImplementedError
@@ -573,7 +578,7 @@ def nth(self, n, dropna=None):
             return self._selected_obj[is_nth]
 
         if (isinstance(self._selected_obj, DataFrame)
-           and dropna not in ['any', 'all']):
+                and dropna not in ['any', 'all']):
             # Note: when agg-ing picker doesn't raise this, just returns NaN
             raise ValueError("For a DataFrame groupby, dropna must be "
                              "either None, 'any' or 'all', "
@@ -582,6 +587,7 @@
         # old behaviour, but with all and any support for DataFrames.
 
         max_len = n if n >= 0 else - 1 - n
+
         def picker(x):
             x = x.dropna(how=dropna)  # Note: how is ignored if Series
             if len(x) <= max_len:
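For reference, the validation above guards calls like these (data invented; the last call is commented out because it raises):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'A': [1, 1, 2], 'B': [np.nan, 4.0, 5.0]})
g = df.groupby('A')

g.nth(0)                 # positional pick, NaNs included
g.nth(0, dropna='any')   # drop rows with any NaN before picking
# g.nth(0, dropna='bogus')  # would raise the ValueError constructed above
```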
@@ -591,7 +597,6 @@ def picker(x):
 
         return self.agg(picker)
 
-
     def cumcount(self, **kwargs):
         """
         Number each item in each group from 0 to the length of that group - 1.
@@ -638,7 +643,7 @@ def cumcount(self, **kwargs):
         """
         ascending = kwargs.pop('ascending', True)
 
-        index = self.obj.index
+        index = self._selected_obj.index
         cumcounts = self._cumcount_array(ascending=ascending)
         return Series(cumcounts, index)
 
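The docstring's contract, for the record (this mirrors the example in the pandas docs):

```python
import pandas as pd

df = pd.DataFrame([['a'], ['a'], ['a'], ['b'], ['b'], ['a']], columns=['A'])

df.groupby('A').cumcount()
# 0    0
# 1    1
# 2    2
# 3    0
# 4    1
# 5    3
# dtype: int64

df.groupby('A').cumcount(ascending=False)
# numbers each group from len(group) - 1 down to 0 instead
```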
@@ -706,8 +711,9 @@ def _cumcount_array(self, arr=None, **kwargs):
         if arr is None:
             arr = np.arange(self.grouper._max_groupsize, dtype='int64')
 
-        len_index = len(self.obj.index)
+        len_index = len(self._selected_obj.index)
         cumcounts = np.empty(len_index, dtype=arr.dtype)
+
         if ascending:
             for v in self.indices.values():
                 cumcounts[v] = arr[:len(v)]
Expand All @@ -722,15 +728,15 @@ def _selected_obj(self):
return self.obj
else:
return self.obj[self._selection]

def _index_with_as_index(self, b):
"""
Take boolean mask of index to be returned from apply, if as_index=True

"""
# TODO perf, it feels like this should already be somewhere...
from itertools import chain
original = self.obj.index
original = self._selected_obj.index
gp = self.grouper
levels = chain((gp.levels[i][gp.labels[i][b]]
for i in range(len(gp.groupings))),
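The `_selected_obj` property shown in this hunk's context is the single point of truth the rest of the diff routes through. What it resolves to, poking at the private attribute purely for illustration:

```python
import pandas as pd

df = pd.DataFrame({'A': ['x', 'y'], 'B': [1, 2], 'C': [3.0, 4.0]})

g_all = df.groupby('A')          # _selection is None -> _selected_obj is df
g_sub = df.groupby('A')[['C']]   # _selection == ['C'] -> _selected_obj is df[['C']]

print(g_sub._selected_obj.columns)  # just 'C'
```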
@@ -812,7 +818,7 @@ def _concat_objects(self, keys, values, not_indexed_same=False):
 
         if not not_indexed_same:
             result = concat(values, axis=self.axis)
-            ax = self.obj._get_axis(self.axis)
+            ax = self._selected_obj._get_axis(self.axis)
 
             if isinstance(result, Series):
                 result = result.reindex(ax)
@@ -835,14 +841,14 @@ def _apply_filter(self, indices, dropna):
         else:
             indices = np.sort(np.concatenate(indices))
         if dropna:
-            filtered = self.obj.take(indices)
+            filtered = self._selected_obj.take(indices)
         else:
-            mask = np.empty(len(self.obj.index), dtype=bool)
+            mask = np.empty(len(self._selected_obj.index), dtype=bool)
             mask.fill(False)
             mask[indices.astype(int)] = True
             # mask fails to broadcast when passed to where; broadcast manually.
-            mask = np.tile(mask, list(self.obj.shape[1:]) + [1]).T
-            filtered = self.obj.where(mask)  # Fill with NaNs.
+            mask = np.tile(mask, list(self._selected_obj.shape[1:]) + [1]).T
+            filtered = self._selected_obj.where(mask)  # Fill with NaNs.
         return filtered
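`_apply_filter` backs the public `filter`; with the fix, both branches (the `take` and the NaN-masking `where`) act on the selected object. An example of the two branches, with made-up data:

```python
import pandas as pd

df = pd.DataFrame({'A': ['x', 'x', 'y'], 'B': [1, 2, 3]})
g = df.groupby('A')

g.filter(lambda grp: len(grp) > 1)                # take: keeps the 'x' rows
g.filter(lambda grp: len(grp) > 1, dropna=False)  # where: 'y' row becomes NaN
```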


@@ -1908,7 +1914,7 @@ def transform(self, func, *args, **kwargs):
         -------
         transformed : Series
         """
-        result = self.obj.copy()
+        result = self._selected_obj.copy()
         if hasattr(result, 'values'):
             result = result.values
         dtype = result.dtype
@@ -1933,8 +1939,8 @@
 
         # downcast if we can (and need)
         result = _possibly_downcast_to_dtype(result, dtype)
-        return self.obj.__class__(result, index=self.obj.index,
-                                  name=self.obj.name)
+        return self._selected_obj.__class__(result, index=self._selected_obj.index,
+                                            name=self._selected_obj.name)
 
     def filter(self, func, dropna=True, *args, **kwargs):
         """
@@ -2082,7 +2088,7 @@ def aggregate(self, arg, *args, **kwargs):
             if self.axis != 0:  # pragma: no cover
                 raise ValueError('Can only pass dict with axis=0')
 
-            obj = self.obj
+            obj = self._selected_obj
 
             if any(isinstance(x, (list, tuple, dict)) for x in arg.values()):
                 new_arg = OrderedDict()
@@ -2095,7 +2101,7 @@
 
             keys = []
             if self._selection is not None:
-                subset = obj[self._selection]
+                subset = obj
                 if isinstance(subset, DataFrame):
                     raise NotImplementedError
 
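The `subset = obj` change matters because `obj` was rebound to `_selected_obj` a few lines up; slicing it again with `self._selection` would re-select an already-selected object (and fail outright when the subset is a Series). A dict-style aggregation in the same spirit as this path, with invented output names:

```python
import pandas as pd

df = pd.DataFrame({'A': ['x', 'x', 'y'], 'C': [1.0, 2.0, 3.0]})

# dict of output name -> aggregation function on a selected column
df.groupby('A')['C'].agg({'cmin': 'min', 'cmax': 'max'})
```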
@@ -2294,7 +2300,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
 
         if isinstance(v, (np.ndarray, Series)):
             if isinstance(v, Series):
-                applied_index = self.obj._get_axis(self.axis)
+                applied_index = self._selected_obj._get_axis(self.axis)
                 all_indexed_same = _all_indexes_same([
                     x.index for x in values
                 ])
@@ -2367,7 +2373,11 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
 
                 # if we have date/time like in the original, then coerce dates
                 # as we are stacking can easily have object dtypes here
-                cd = 'coerce' if self.obj.ndim == 2 and self.obj.dtypes.isin(_DATELIKE_DTYPES).any() else True
+                if (self._selected_obj.ndim == 2
+                        and self._selected_obj.dtypes.isin(_DATELIKE_DTYPES).any()):
+                    cd = 'coerce'
+                else:
+                    cd = True
                 return result.convert_objects(convert_dates=cd)
 
             else:
@@ -2668,8 +2678,8 @@ def _wrap_agged_blocks(self, blocks):
         return result.convert_objects()
 
     def _iterate_column_groupbys(self):
-        for i, colname in enumerate(self.obj.columns):
-            yield colname, SeriesGroupBy(self.obj.iloc[:, i],
+        for i, colname in enumerate(self._selected_obj.columns):
+            yield colname, SeriesGroupBy(self._selected_obj.iloc[:, i],
                                          selection=colname,
                                          grouper=self.grouper,
                                          exclusions=self.exclusions)
@@ -2679,7 +2689,7 @@ def _apply_to_column_groupbys(self, func):
         return concat(
             (func(col_groupby) for _, col_groupby
              in self._iterate_column_groupbys()),
-            keys=self.obj.columns, axis=1)
+            keys=self._selected_obj.columns, axis=1)
 
     def ohlc(self):
         """
@@ -2701,10 +2711,10 @@ def _iterate_slices(self):
         if self.axis == 0:
             # kludge
             if self._selection is None:
-                slice_axis = self.obj.items
+                slice_axis = self._selected_obj.items
             else:
                 slice_axis = self._selection_list
-            slicer = lambda x: self.obj[x]
+            slicer = lambda x: self._selected_obj[x]
         else:
             raise NotImplementedError
 
38 changes: 38 additions & 0 deletions pandas/tests/test_groupby.py
@@ -3466,6 +3466,44 @@ def test_index_label_overlaps_location(self):
         expected = ser.take([1, 3, 4])
         assert_series_equal(actual, expected)
 
+    def test_groupby_selection_with_methods(self):
+        # some methods which require DatetimeIndex
+        rng = pd.date_range('2014', periods=len(self.df))
+        self.df.index = rng
+
+        g = self.df.groupby(['A'])[['C']]
+        g_exp = self.df[['C']].groupby(self.df['A'])
+        # TODO check groupby with > 1 col ?
+
+        # methods which are called as .foo()
+        methods = ['count',
+                   'corr',
+                   'cummax', 'cummin', 'cumprod',
+                   'describe', 'rank',
+                   'quantile',
+                   'diff', 'shift',
+                   'all', 'any',
+                   'idxmin', 'idxmax',
+                   'ffill', 'bfill',
+                   'pct_change',
+                   'tshift']
+
+        for m in methods:
+            res = getattr(g, m)()
+            exp = getattr(g_exp, m)()
+            assert_frame_equal(res, exp)  # should always be frames!
+
+        # methods which aren't just .foo()
+        assert_frame_equal(g.fillna(0), g_exp.fillna(0))
+        assert_frame_equal(g.dtypes, g_exp.dtypes)
+        assert_frame_equal(g.apply(lambda x: x.sum()),
+                           g_exp.apply(lambda x: x.sum()))
+
+        assert_frame_equal(g.resample('D'), g_exp.resample('D'))
+
     def test_groupby_whitelist(self):
         from string import ascii_lowercase
         letters = np.array(list(ascii_lowercase))