
Commit 4119e04

Merge pull request #6570 from hayd/groupby_selected_obj
FIX: use selected_obj rather than obj throughout groupby
2 parents 48a6849 + 0c10dad commit 4119e04
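
For context, a minimal sketch of the call pattern this fix targets (not part of the commit; the frame and column names below are made up, mirroring the new test): method calls on a column-selected groupby should act on the selected columns, the same as selecting before grouping.

    import pandas as pd

    # Hypothetical data; only the shape of the calls matters here.
    df = pd.DataFrame({'A': ['x', 'x', 'y'], 'C': [1.0, 2.0, 3.0]})

    g = df.groupby(['A'])[['C']]          # select columns after grouping
    g_exp = df[['C']].groupby(df['A'])    # select columns before grouping

    # With the fix, whitelisted methods dispatch against the selection,
    # so the two forms should agree, e.g.:
    assert g.count().equals(g_exp.count())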

2 files changed: +78 -30 lines


pandas/core/groupby.py

+40 -30
@@ -208,6 +208,8 @@ class GroupBy(PandasObject):
     Number of groups
     """
     _apply_whitelist = _common_apply_whitelist
+    _internal_names = ['_cache']
+    _internal_names_set = set(_internal_names)

     def __init__(self, obj, keys=None, axis=0, level=None,
                  grouper=None, exclusions=None, selection=None, as_index=True,
@@ -288,10 +290,12 @@ def _local_dir(self):
         return sorted(set(self.obj._local_dir() + list(self._apply_whitelist)))

     def __getattr__(self, attr):
+        if attr in self._internal_names_set:
+            return object.__getattribute__(self, attr)
         if attr in self.obj:
             return self[attr]

-        if hasattr(self.obj, attr) and attr != '_cache':
+        if hasattr(self.obj, attr):
             return self._make_wrapper(attr)

         raise AttributeError("%r object has no attribute %r" %
@@ -302,18 +306,18 @@ def __getitem__(self, key):

     def _make_wrapper(self, name):
         if name not in self._apply_whitelist:
-            is_callable = callable(getattr(self.obj, name, None))
+            is_callable = callable(getattr(self._selected_obj, name, None))
             kind = ' callable ' if is_callable else ' '
             msg = ("Cannot access{0}attribute {1!r} of {2!r} objects, try "
                    "using the 'apply' method".format(kind, name,
                                                      type(self).__name__))
             raise AttributeError(msg)

-        f = getattr(self.obj, name)
+        f = getattr(self._selected_obj, name)
         if not isinstance(f, types.MethodType):
             return self.apply(lambda self: getattr(self, name))

-        f = getattr(type(self.obj), name)
+        f = getattr(type(self._selected_obj), name)

         def wrapper(*args, **kwargs):
             # a little trickery for aggregation functions that need an axis
@@ -362,7 +366,7 @@ def get_group(self, name, obj=None):
         group : type of obj
         """
         if obj is None:
-            obj = self.obj
+            obj = self._selected_obj

         inds = self._get_index(name)
         return obj.take(inds, axis=self.axis, convert=False)
@@ -424,7 +428,8 @@ def f(g):
         return self._python_apply_general(f)

     def _python_apply_general(self, f):
-        keys, values, mutated = self.grouper.apply(f, self.obj, self.axis)
+        keys, values, mutated = self.grouper.apply(f, self._selected_obj,
+                                                   self.axis)

         return self._wrap_applied_output(keys, values,
                                          not_indexed_same=mutated)
@@ -437,7 +442,7 @@ def agg(self, func, *args, **kwargs):
         return self.aggregate(func, *args, **kwargs)

     def _iterate_slices(self):
-        yield self.name, self.obj
+        yield self.name, self._selected_obj

     def transform(self, func, *args, **kwargs):
         raise NotImplementedError
@@ -573,7 +578,7 @@ def nth(self, n, dropna=None):
             return self._selected_obj[is_nth]

         if (isinstance(self._selected_obj, DataFrame)
-             and dropna not in ['any', 'all']):
+                and dropna not in ['any', 'all']):
             # Note: when agg-ing picker doesn't raise this, just returns NaN
             raise ValueError("For a DataFrame groupby, dropna must be "
                              "either None, 'any' or 'all', "
@@ -582,6 +587,7 @@ def nth(self, n, dropna=None):
         # old behaviour, but with all and any support for DataFrames.

         max_len = n if n >= 0 else - 1 - n
+
         def picker(x):
             x = x.dropna(how=dropna) # Note: how is ignored if Series
             if len(x) <= max_len:
@@ -591,7 +597,6 @@ def picker(x):

         return self.agg(picker)

-
     def cumcount(self, **kwargs):
         """
         Number each item in each group from 0 to the length of that group - 1.
@@ -638,7 +643,7 @@ def cumcount(self, **kwargs):
         """
         ascending = kwargs.pop('ascending', True)

-        index = self.obj.index
+        index = self._selected_obj.index
         cumcounts = self._cumcount_array(ascending=ascending)
         return Series(cumcounts, index)

@@ -706,8 +711,9 @@ def _cumcount_array(self, arr=None, **kwargs):
         if arr is None:
             arr = np.arange(self.grouper._max_groupsize, dtype='int64')

-        len_index = len(self.obj.index)
+        len_index = len(self._selected_obj.index)
         cumcounts = np.empty(len_index, dtype=arr.dtype)
+
         if ascending:
             for v in self.indices.values():
                 cumcounts[v] = arr[:len(v)]
@@ -722,15 +728,15 @@ def _selected_obj(self):
             return self.obj
         else:
             return self.obj[self._selection]
-
+
     def _index_with_as_index(self, b):
         """
         Take boolean mask of index to be returned from apply, if as_index=True

         """
         # TODO perf, it feels like this should already be somewhere...
         from itertools import chain
-        original = self.obj.index
+        original = self._selected_obj.index
         gp = self.grouper
         levels = chain((gp.levels[i][gp.labels[i][b]]
                         for i in range(len(gp.groupings))),
@@ -812,7 +818,7 @@ def _concat_objects(self, keys, values, not_indexed_same=False):

         if not not_indexed_same:
             result = concat(values, axis=self.axis)
-            ax = self.obj._get_axis(self.axis)
+            ax = self._selected_obj._get_axis(self.axis)

             if isinstance(result, Series):
                 result = result.reindex(ax)
@@ -835,14 +841,14 @@ def _apply_filter(self, indices, dropna):
         else:
             indices = np.sort(np.concatenate(indices))
         if dropna:
-            filtered = self.obj.take(indices)
+            filtered = self._selected_obj.take(indices)
         else:
-            mask = np.empty(len(self.obj.index), dtype=bool)
+            mask = np.empty(len(self._selected_obj.index), dtype=bool)
             mask.fill(False)
             mask[indices.astype(int)] = True
             # mask fails to broadcast when passed to where; broadcast manually.
-            mask = np.tile(mask, list(self.obj.shape[1:]) + [1]).T
-            filtered = self.obj.where(mask) # Fill with NaNs.
+            mask = np.tile(mask, list(self._selected_obj.shape[1:]) + [1]).T
+            filtered = self._selected_obj.where(mask) # Fill with NaNs.
         return filtered


@@ -1908,7 +1914,7 @@ def transform(self, func, *args, **kwargs):
         -------
         transformed : Series
         """
-        result = self.obj.copy()
+        result = self._selected_obj.copy()
         if hasattr(result, 'values'):
             result = result.values
         dtype = result.dtype
@@ -1933,8 +1939,8 @@ def transform(self, func, *args, **kwargs):

         # downcast if we can (and need)
         result = _possibly_downcast_to_dtype(result, dtype)
-        return self.obj.__class__(result, index=self.obj.index,
-                                  name=self.obj.name)
+        return self._selected_obj.__class__(result, index=self._selected_obj.index,
+                                            name=self._selected_obj.name)

     def filter(self, func, dropna=True, *args, **kwargs):
         """
@@ -2082,7 +2088,7 @@ def aggregate(self, arg, *args, **kwargs):
             if self.axis != 0:  # pragma: no cover
                 raise ValueError('Can only pass dict with axis=0')

-            obj = self.obj
+            obj = self._selected_obj

             if any(isinstance(x, (list, tuple, dict)) for x in arg.values()):
                 new_arg = OrderedDict()
@@ -2095,7 +2101,7 @@ def aggregate(self, arg, *args, **kwargs):

             keys = []
             if self._selection is not None:
-                subset = obj[self._selection]
+                subset = obj
                 if isinstance(subset, DataFrame):
                     raise NotImplementedError

@@ -2294,7 +2300,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):

             if isinstance(v, (np.ndarray, Series)):
                 if isinstance(v, Series):
-                    applied_index = self.obj._get_axis(self.axis)
+                    applied_index = self._selected_obj._get_axis(self.axis)
                     all_indexed_same = _all_indexes_same([
                         x.index for x in values
                     ])
@@ -2367,7 +2373,11 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):

                 # if we have date/time like in the original, then coerce dates
                 # as we are stacking can easily have object dtypes here
-                cd = 'coerce' if self.obj.ndim == 2 and self.obj.dtypes.isin(_DATELIKE_DTYPES).any() else True
+                if (self._selected_obj.ndim == 2
+                        and self._selected_obj.dtypes.isin(_DATELIKE_DTYPES).any()):
+                    cd = 'coerce'
+                else:
+                    cd = True
                 return result.convert_objects(convert_dates=cd)

             else:
@@ -2668,8 +2678,8 @@ def _wrap_agged_blocks(self, blocks):
         return result.convert_objects()

     def _iterate_column_groupbys(self):
-        for i, colname in enumerate(self.obj.columns):
-            yield colname, SeriesGroupBy(self.obj.iloc[:, i],
+        for i, colname in enumerate(self._selected_obj.columns):
+            yield colname, SeriesGroupBy(self._selected_obj.iloc[:, i],
                                          selection=colname,
                                          grouper=self.grouper,
                                          exclusions=self.exclusions)
@@ -2679,7 +2689,7 @@ def _apply_to_column_groupbys(self, func):
         return concat(
             (func(col_groupby) for _, col_groupby
              in self._iterate_column_groupbys()),
-            keys=self.obj.columns, axis=1)
+            keys=self._selected_obj.columns, axis=1)

     def ohlc(self):
         """
@@ -2701,10 +2711,10 @@ def _iterate_slices(self):
         if self.axis == 0:
             # kludge
             if self._selection is None:
-                slice_axis = self.obj.items
+                slice_axis = self._selected_obj.items
             else:
                 slice_axis = self._selection_list
-            slicer = lambda x: self.obj[x]
+            slicer = lambda x: self._selected_obj[x]
         else:
             raise NotImplementedError

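The _internal_names change in the first hunk keeps internal attributes such as _cache from being routed through the method-wrapping fallback in __getattr__. A minimal, standalone sketch of that dispatch pattern (the Delegator class and its obj argument are illustrative, not pandas code):

    class Delegator(object):
        # names looked up directly on the instance instead of being delegated
        _internal_names = ['_cache']
        _internal_names_set = set(_internal_names)

        def __init__(self, obj):
            self.obj = obj

        def __getattr__(self, attr):
            # __getattr__ only runs after normal attribute lookup fails
            if attr in self._internal_names_set:
                # raise a plain AttributeError instead of wrapping/delegating
                return object.__getattribute__(self, attr)
            # everything else is delegated to the wrapped object
            return getattr(self.obj, attr)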
pandas/tests/test_groupby.py

+38
@@ -3466,6 +3466,44 @@ def test_index_label_overlaps_location(self):
         expected = ser.take([1, 3, 4])
         assert_series_equal(actual, expected)

+    def test_groupby_selection_with_methods(self):
+        # some methods which require DatetimeIndex
+        rng = pd.date_range('2014', periods=len(self.df))
+        self.df.index = rng
+
+        g = self.df.groupby(['A'])[['C']]
+        g_exp = self.df[['C']].groupby(self.df['A'])
+        # TODO check groupby with > 1 col ?
+
+        # methods which are called as .foo()
+        methods = ['count',
+                   'corr',
+                   'cummax', 'cummin', 'cumprod',
+                   'describe', 'rank',
+                   'quantile',
+                   'diff', 'shift',
+                   'all', 'any',
+                   'idxmin', 'idxmax',
+                   'ffill', 'bfill',
+                   'pct_change',
+                   'tshift'
+                   ]
+
+        for m in methods:
+            res = getattr(g, m)()
+            exp = getattr(g_exp, m)()
+            assert_frame_equal(res, exp) # should always be frames!
+
+        # methods which aren't just .foo()
+        assert_frame_equal(g.fillna(0), g_exp.fillna(0))
+        assert_frame_equal(g.dtypes, g_exp.dtypes)
+        assert_frame_equal(g.apply(lambda x: x.sum()),
+                           g_exp.apply(lambda x: x.sum()))
+
+        assert_frame_equal(g.resample('D'), g_exp.resample('D'))
+
+
+
     def test_groupby_whitelist(self):
         from string import ascii_lowercase
         letters = np.array(list(ascii_lowercase))

0 commit comments
