Skip to content

Commit 5936bea

Browse files
haydgouthambs
authored andcommitted
FIX use selected_obj rather the obj throughout groupby
TST dont ignore subselection in groupby
1 parent 4df5bd2 commit 5936bea

File tree

2 files changed

+72
-29
lines changed

2 files changed

+72
-29
lines changed

pandas/core/groupby.py

+34-29
Original file line numberDiff line numberDiff line change
@@ -302,18 +302,18 @@ def __getitem__(self, key):
302302

303303
def _make_wrapper(self, name):
304304
if name not in self._apply_whitelist:
305-
is_callable = callable(getattr(self.obj, name, None))
305+
is_callable = callable(getattr(self._selected_obj, name, None))
306306
kind = ' callable ' if is_callable else ' '
307307
msg = ("Cannot access{0}attribute {1!r} of {2!r} objects, try "
308308
"using the 'apply' method".format(kind, name,
309309
type(self).__name__))
310310
raise AttributeError(msg)
311311

312-
f = getattr(self.obj, name)
312+
f = getattr(self._selected_obj, name)
313313
if not isinstance(f, types.MethodType):
314314
return self.apply(lambda self: getattr(self, name))
315315

316-
f = getattr(type(self.obj), name)
316+
f = getattr(type(self._selected_obj), name)
317317

318318
def wrapper(*args, **kwargs):
319319
# a little trickery for aggregation functions that need an axis
@@ -362,7 +362,7 @@ def get_group(self, name, obj=None):
362362
group : type of obj
363363
"""
364364
if obj is None:
365-
obj = self.obj
365+
obj = self._selected_obj
366366

367367
inds = self._get_index(name)
368368
return obj.take(inds, axis=self.axis, convert=False)
@@ -424,7 +424,7 @@ def f(g):
424424
return self._python_apply_general(f)
425425

426426
def _python_apply_general(self, f):
427-
keys, values, mutated = self.grouper.apply(f, self.obj, self.axis)
427+
keys, values, mutated = self.grouper.apply(f, self._selected_obj, self.axis)
428428

429429
return self._wrap_applied_output(keys, values,
430430
not_indexed_same=mutated)
@@ -437,7 +437,7 @@ def agg(self, func, *args, **kwargs):
437437
return self.aggregate(func, *args, **kwargs)
438438

439439
def _iterate_slices(self):
440-
yield self.name, self.obj
440+
yield self.name, self._selected_obj
441441

442442
def transform(self, func, *args, **kwargs):
443443
raise NotImplementedError
@@ -573,7 +573,7 @@ def nth(self, n, dropna=None):
573573
return self._selected_obj[is_nth]
574574

575575
if (isinstance(self._selected_obj, DataFrame)
576-
and dropna not in ['any', 'all']):
576+
and dropna not in ['any', 'all']):
577577
# Note: when agg-ing picker doesn't raise this, just returns NaN
578578
raise ValueError("For a DataFrame groupby, dropna must be "
579579
"either None, 'any' or 'all', "
@@ -582,6 +582,7 @@ def nth(self, n, dropna=None):
582582
# old behaviour, but with all and any support for DataFrames.
583583

584584
max_len = n if n >= 0 else - 1 - n
585+
585586
def picker(x):
586587
x = x.dropna(how=dropna) # Note: how is ignored if Series
587588
if len(x) <= max_len:
@@ -591,7 +592,6 @@ def picker(x):
591592

592593
return self.agg(picker)
593594

594-
595595
def cumcount(self, **kwargs):
596596
"""
597597
Number each item in each group from 0 to the length of that group - 1.
@@ -638,7 +638,7 @@ def cumcount(self, **kwargs):
638638
"""
639639
ascending = kwargs.pop('ascending', True)
640640

641-
index = self.obj.index
641+
index = self._selected_obj.index
642642
cumcounts = self._cumcount_array(ascending=ascending)
643643
return Series(cumcounts, index)
644644

@@ -706,8 +706,9 @@ def _cumcount_array(self, arr=None, **kwargs):
706706
if arr is None:
707707
arr = np.arange(self.grouper._max_groupsize, dtype='int64')
708708

709-
len_index = len(self.obj.index)
709+
len_index = len(self._selected_obj.index)
710710
cumcounts = np.empty(len_index, dtype=arr.dtype)
711+
711712
if ascending:
712713
for v in self.indices.values():
713714
cumcounts[v] = arr[:len(v)]
@@ -722,15 +723,15 @@ def _selected_obj(self):
722723
return self.obj
723724
else:
724725
return self.obj[self._selection]
725-
726+
726727
def _index_with_as_index(self, b):
727728
"""
728729
Take boolean mask of index to be returned from apply, if as_index=True
729730
730731
"""
731732
# TODO perf, it feels like this should already be somewhere...
732733
from itertools import chain
733-
original = self.obj.index
734+
original = self._selected_obj.index
734735
gp = self.grouper
735736
levels = chain((gp.levels[i][gp.labels[i][b]]
736737
for i in range(len(gp.groupings))),
@@ -812,7 +813,7 @@ def _concat_objects(self, keys, values, not_indexed_same=False):
812813

813814
if not not_indexed_same:
814815
result = concat(values, axis=self.axis)
815-
ax = self.obj._get_axis(self.axis)
816+
ax = self._selected_obj._get_axis(self.axis)
816817

817818
if isinstance(result, Series):
818819
result = result.reindex(ax)
@@ -835,14 +836,14 @@ def _apply_filter(self, indices, dropna):
835836
else:
836837
indices = np.sort(np.concatenate(indices))
837838
if dropna:
838-
filtered = self.obj.take(indices)
839+
filtered = self._selected_obj.take(indices)
839840
else:
840-
mask = np.empty(len(self.obj.index), dtype=bool)
841+
mask = np.empty(len(self._selected_obj.index), dtype=bool)
841842
mask.fill(False)
842843
mask[indices.astype(int)] = True
843844
# mask fails to broadcast when passed to where; broadcast manually.
844-
mask = np.tile(mask, list(self.obj.shape[1:]) + [1]).T
845-
filtered = self.obj.where(mask) # Fill with NaNs.
845+
mask = np.tile(mask, list(self._selected_obj.shape[1:]) + [1]).T
846+
filtered = self._selected_obj.where(mask) # Fill with NaNs.
846847
return filtered
847848

848849

@@ -1908,7 +1909,7 @@ def transform(self, func, *args, **kwargs):
19081909
-------
19091910
transformed : Series
19101911
"""
1911-
result = self.obj.copy()
1912+
result = self._selected_obj.copy()
19121913
if hasattr(result, 'values'):
19131914
result = result.values
19141915
dtype = result.dtype
@@ -1933,8 +1934,8 @@ def transform(self, func, *args, **kwargs):
19331934

19341935
# downcast if we can (and need)
19351936
result = _possibly_downcast_to_dtype(result, dtype)
1936-
return self.obj.__class__(result, index=self.obj.index,
1937-
name=self.obj.name)
1937+
return self._selected_obj.__class__(result, index=self._selected_obj.index,
1938+
name=self._selected_obj.name)
19381939

19391940
def filter(self, func, dropna=True, *args, **kwargs):
19401941
"""
@@ -2082,7 +2083,7 @@ def aggregate(self, arg, *args, **kwargs):
20822083
if self.axis != 0: # pragma: no cover
20832084
raise ValueError('Can only pass dict with axis=0')
20842085

2085-
obj = self.obj
2086+
obj = self._selected_obj
20862087

20872088
if any(isinstance(x, (list, tuple, dict)) for x in arg.values()):
20882089
new_arg = OrderedDict()
@@ -2095,7 +2096,7 @@ def aggregate(self, arg, *args, **kwargs):
20952096

20962097
keys = []
20972098
if self._selection is not None:
2098-
subset = obj[self._selection]
2099+
subset = obj
20992100
if isinstance(subset, DataFrame):
21002101
raise NotImplementedError
21012102

@@ -2294,7 +2295,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
22942295

22952296
if isinstance(v, (np.ndarray, Series)):
22962297
if isinstance(v, Series):
2297-
applied_index = self.obj._get_axis(self.axis)
2298+
applied_index = self._selected_obj._get_axis(self.axis)
22982299
all_indexed_same = _all_indexes_same([
22992300
x.index for x in values
23002301
])
@@ -2367,7 +2368,11 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
23672368

23682369
# if we have date/time like in the original, then coerce dates
23692370
# as we are stacking can easily have object dtypes here
2370-
cd = 'coerce' if self.obj.ndim == 2 and self.obj.dtypes.isin(_DATELIKE_DTYPES).any() else True
2371+
if (self._selected_obj.ndim == 2
2372+
and self._selected_obj.dtypes.isin(_DATELIKE_DTYPES).any()):
2373+
cd = 'coerce'
2374+
else:
2375+
cd = True
23712376
return result.convert_objects(convert_dates=cd)
23722377

23732378
else:
@@ -2668,8 +2673,8 @@ def _wrap_agged_blocks(self, blocks):
26682673
return result.convert_objects()
26692674

26702675
def _iterate_column_groupbys(self):
2671-
for i, colname in enumerate(self.obj.columns):
2672-
yield colname, SeriesGroupBy(self.obj.iloc[:, i],
2676+
for i, colname in enumerate(self._selected_obj.columns):
2677+
yield colname, SeriesGroupBy(self._selected_obj.iloc[:, i],
26732678
selection=colname,
26742679
grouper=self.grouper,
26752680
exclusions=self.exclusions)
@@ -2679,7 +2684,7 @@ def _apply_to_column_groupbys(self, func):
26792684
return concat(
26802685
(func(col_groupby) for _, col_groupby
26812686
in self._iterate_column_groupbys()),
2682-
keys=self.obj.columns, axis=1)
2687+
keys=self._selected_obj.columns, axis=1)
26832688

26842689
def ohlc(self):
26852690
"""
@@ -2701,10 +2706,10 @@ def _iterate_slices(self):
27012706
if self.axis == 0:
27022707
# kludge
27032708
if self._selection is None:
2704-
slice_axis = self.obj.items
2709+
slice_axis = self._selected_obj.items
27052710
else:
27062711
slice_axis = self._selection_list
2707-
slicer = lambda x: self.obj[x]
2712+
slicer = lambda x: self._selected_obj[x]
27082713
else:
27092714
raise NotImplementedError
27102715

pandas/tests/test_groupby.py

+38
Original file line numberDiff line numberDiff line change
@@ -3466,6 +3466,44 @@ def test_index_label_overlaps_location(self):
34663466
expected = ser.take([1, 3, 4])
34673467
assert_series_equal(actual, expected)
34683468

3469+
def test_groupby_selection_with_methods(self):
3470+
# some methods which require DatetimeIndex
3471+
rng = pd.date_range('2014', periods=len(self.df))
3472+
self.df.index = rng
3473+
3474+
g = self.df.groupby(['A'])[['C']]
3475+
g_exp = self.df[['C']].groupby(self.df['A'])
3476+
# TODO check groupby with > 1 col ?
3477+
3478+
# methods which are called as .foo()
3479+
methods = ['count',
3480+
'corr',
3481+
'cummax', 'cummin', 'cumprod',
3482+
'describe', 'rank',
3483+
'quantile',
3484+
'diff', 'shift',
3485+
'all', 'any',
3486+
'idxmin', 'idxmax',
3487+
'ffill', 'bfill',
3488+
'pct_change',
3489+
'tshift'
3490+
]
3491+
3492+
for m in methods:
3493+
res = getattr(g, m)()
3494+
exp = getattr(g_exp, m)()
3495+
assert_frame_equal(res, exp) # should always be frames!
3496+
3497+
# methods which aren't just .foo()
3498+
assert_frame_equal(g.fillna(0), g_exp.fillna(0))
3499+
assert_frame_equal(g.dtypes, g_exp.dtypes)
3500+
assert_frame_equal(g.apply(lambda x: x.sum()),
3501+
g_exp.apply(lambda x: x.sum()))
3502+
3503+
assert_frame_equal(g.resample('D'), g_exp.resample('D'))
3504+
3505+
3506+
34693507
def test_groupby_whitelist(self):
34703508
from string import ascii_lowercase
34713509
letters = np.array(list(ascii_lowercase))

0 commit comments

Comments
 (0)