Skip to content

Commit 6fa398e

Browse files
committed
ENH: infer selection_obj on groupby with an applied method (GH5610)
1 parent ccab72d commit 6fa398e

File tree

5 files changed

+117
-36
lines changed

5 files changed

+117
-36
lines changed

doc/source/release.rst

+2
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ API Changes
179179
validation warnings in :func:`read_csv`/:func:`read_table` (:issue:`6607`)
180180
- Raise a ``TypeError`` when ``DataFrame`` is passed an iterator as the
181181
``data`` argument (:issue:`5357`)
182+
- groupby will now not return the grouped column for non-cython functions (:issue:`5610`),
183+
as its already the index
182184

183185
Deprecations
184186
~~~~~~~~~~~~

doc/source/v0.14.0.txt

+11-1
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,22 @@ API changes
110110

111111
.. ipython:: python
112112

113-
DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=['A', 'B'])
113+
df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=['A', 'B'])
114114
g = df.groupby('A')
115115
g.nth(0) # can also use negative ints
116116

117117
g.nth(0, dropna='any') # similar to old behaviour
118118

119+
groupby will now not return the grouped column for non-cython functions (:issue:`5610`),
120+
as its already the index
121+
122+
.. ipython:: python
123+
124+
df = DataFrame([[1, np.nan], [1, 4], [5, 6], [5, 8]], columns=['A', 'B'])
125+
g = df.groupby('A')
126+
g.count()
127+
g.describe()
128+
119129
- Allow specification of a more complex groupby via ``pd.Grouper``, such as grouping
120130
by a Time and a string field simultaneously. See :ref:`the docs <groupby.specify>`. (:issue:`3794`)
121131

pandas/core/groupby.py

+52-26
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,23 @@ def _selection_list(self):
445445
return [self._selection]
446446
return self._selection
447447

448+
@cache_readonly
449+
def _selected_obj(self):
450+
451+
if self._selection is None or isinstance(self.obj, Series):
452+
return self.obj
453+
else:
454+
return self.obj[self._selection]
455+
456+
def _set_selection_from_grouper(self):
457+
""" we may need create a selection if we have non-level groupers """
458+
grp = self.grouper
459+
if self._selection is None and getattr(grp,'groupings',None) is not None:
460+
ax = self.obj._info_axis
461+
groupers = [ g.name for g in grp.groupings if g.level is None and g.name is not None and g.name in ax ]
462+
if len(groupers):
463+
self._selection = (ax-Index(groupers)).tolist()
464+
448465
def _local_dir(self):
449466
return sorted(set(self.obj._local_dir() + list(self._apply_whitelist)))
450467

@@ -453,7 +470,6 @@ def __getattr__(self, attr):
453470
return object.__getattribute__(self, attr)
454471
if attr in self.obj:
455472
return self[attr]
456-
457473
if hasattr(self.obj, attr):
458474
return self._make_wrapper(attr)
459475

@@ -472,6 +488,10 @@ def _make_wrapper(self, name):
472488
type(self).__name__))
473489
raise AttributeError(msg)
474490

491+
# need to setup the selection
492+
# as are not passed directly but in the grouper
493+
self._set_selection_from_grouper()
494+
475495
f = getattr(self._selected_obj, name)
476496
if not isinstance(f, types.MethodType):
477497
return self.apply(lambda self: getattr(self, name))
@@ -503,7 +523,19 @@ def curried(x):
503523
try:
504524
return self.apply(curried_with_axis)
505525
except Exception:
506-
return self.apply(curried)
526+
try:
527+
return self.apply(curried)
528+
except Exception:
529+
530+
# related to : GH3688
531+
# try item-by-item
532+
# this can be called recursively, so need to raise ValueError if
533+
# we don't have this method to indicated to aggregate to
534+
# mark this column as an error
535+
try:
536+
return self._aggregate_item_by_item(name, *args, **kwargs)
537+
except (AttributeError):
538+
raise ValueError
507539

508540
return wrapper
509541

@@ -624,6 +656,7 @@ def mean(self):
624656
except GroupByError:
625657
raise
626658
except Exception: # pragma: no cover
659+
self._set_selection_from_grouper()
627660
f = lambda x: x.mean(axis=self.axis)
628661
return self._python_agg_general(f)
629662

@@ -639,6 +672,7 @@ def median(self):
639672
raise
640673
except Exception: # pragma: no cover
641674

675+
self._set_selection_from_grouper()
642676
def f(x):
643677
if isinstance(x, np.ndarray):
644678
x = Series(x)
@@ -655,6 +689,7 @@ def std(self, ddof=1):
655689
if ddof == 1:
656690
return self._cython_agg_general('std')
657691
else:
692+
self._set_selection_from_grouper()
658693
f = lambda x: x.std(ddof=ddof)
659694
return self._python_agg_general(f)
660695

@@ -667,6 +702,7 @@ def var(self, ddof=1):
667702
if ddof == 1:
668703
return self._cython_agg_general('var')
669704
else:
705+
self._set_selection_from_grouper()
670706
f = lambda x: x.var(ddof=ddof)
671707
return self._python_agg_general(f)
672708

@@ -677,12 +713,14 @@ def size(self):
677713
"""
678714
return self.grouper.size()
679715

680-
def count(self):
716+
def count(self, axis=0):
681717
"""
682718
Number of non-null items in each group.
683-
719+
axis : axis number, default 0
720+
the grouping axis
684721
"""
685-
return self._python_agg_general(lambda x: notnull(x).sum())
722+
self._set_selection_from_grouper()
723+
return self._python_agg_general(lambda x: notnull(x).sum(axis=axis)).astype('int64')
686724

687725
sum = _groupby_function('sum', 'add', np.sum)
688726
prod = _groupby_function('prod', 'prod', np.prod)
@@ -693,12 +731,14 @@ def count(self):
693731
last = _groupby_function('last', 'last', _last_compat, numeric_only=False,
694732
_convert=True)
695733

734+
696735
def ohlc(self):
697736
"""
698-
Deprecated, use .resample(how="ohlc") instead.
699-
737+
Compute sum of values, excluding missing values
738+
For multiple groupings, the result index will be a MultiIndex
700739
"""
701-
raise AttributeError('ohlc is deprecated, use resample(how="ohlc").')
740+
return self._apply_to_column_groupbys(
741+
lambda x: x._cython_agg_general('ohlc'))
702742

703743
def nth(self, n, dropna=None):
704744
"""
@@ -894,13 +934,6 @@ def _cumcount_array(self, arr=None, **kwargs):
894934
cumcounts[v] = arr[len(v)-1::-1]
895935
return cumcounts
896936

897-
@cache_readonly
898-
def _selected_obj(self):
899-
if self._selection is None or isinstance(self.obj, Series):
900-
return self.obj
901-
else:
902-
return self.obj[self._selection]
903-
904937
def _index_with_as_index(self, b):
905938
"""
906939
Take boolean mask of index to be returned from apply, if as_index=True
@@ -945,7 +978,6 @@ def _cython_agg_general(self, how, numeric_only=True):
945978
result, names = self.grouper.aggregate(obj.values, how)
946979
except AssertionError as e:
947980
raise GroupByError(str(e))
948-
# infer old dytpe
949981
output[name] = self._try_cast(result, obj)
950982

951983
if len(output) == 0:
@@ -954,8 +986,6 @@ def _cython_agg_general(self, how, numeric_only=True):
954986
return self._wrap_aggregated_output(output, names)
955987

956988
def _python_agg_general(self, func, *args, **kwargs):
957-
_dtype = kwargs.pop("_dtype", None)
958-
959989
func = _intercept_function(func)
960990
f = lambda x: func(x, *args, **kwargs)
961991

@@ -964,14 +994,7 @@ def _python_agg_general(self, func, *args, **kwargs):
964994
for name, obj in self._iterate_slices():
965995
try:
966996
result, counts = self.grouper.agg_series(obj, f)
967-
968-
if _dtype is None: # infer old dytpe
969-
output[name] = self._try_cast(result, obj)
970-
elif _dtype is False:
971-
output[name] = result
972-
else:
973-
output[name] = _possibly_downcast_to_dtype(result, _dtype)
974-
997+
output[name] = self._try_cast(result, obj)
975998
except TypeError:
976999
continue
9771000

@@ -2203,6 +2226,9 @@ def true_and_notnull(x, *args, **kwargs):
22032226
filtered = self._apply_filter(indices, dropna)
22042227
return filtered
22052228

2229+
def _apply_to_column_groupbys(self, func):
2230+
""" return a pass thru """
2231+
return func(self)
22062232

22072233
class NDFrameGroupBy(GroupBy):
22082234

pandas/tests/test_groupby.py

+50-7
Original file line numberDiff line numberDiff line change
@@ -1971,16 +1971,53 @@ def test_size(self):
19711971
self.assertEquals(result[key], len(group))
19721972

19731973
def test_count(self):
1974-
df = pd.DataFrame([[1, 2], [1, nan], [3, nan]], columns=['A', 'B'])
1974+
1975+
# GH5610
1976+
# count counts non-nulls
1977+
df = pd.DataFrame([[1, 2, 'foo'], [1, nan, 'bar'], [3, nan, nan]], columns=['A', 'B', 'C'])
1978+
19751979
count_as = df.groupby('A').count()
19761980
count_not_as = df.groupby('A', as_index=False).count()
19771981

1978-
res = pd.DataFrame([[1, 1], [3, 0]], columns=['A', 'B'])
1979-
assert_frame_equal(count_not_as, res)
1980-
assert_frame_equal(count_as, res.set_index('A'))
1982+
expected = DataFrame([[1, 2], [0, 0]], columns=['B', 'C'], index=[1,3])
1983+
expected.index.name='A'
1984+
assert_frame_equal(count_not_as, expected.reset_index())
1985+
assert_frame_equal(count_as, expected)
19811986

19821987
count_B = df.groupby('A')['B'].count()
1983-
assert_series_equal(count_B, res['B'])
1988+
assert_series_equal(count_B, expected['B'])
1989+
1990+
def test_non_cython_api(self):
1991+
1992+
# GH5610
1993+
# non-cython calls should not include the grouper
1994+
1995+
df = DataFrame([[1, 2, 'foo'], [1, nan, 'bar',], [3, nan, 'baz']], columns=['A', 'B','C'])
1996+
g = df.groupby('A')
1997+
1998+
# mad
1999+
expected = DataFrame([[0],[nan]],columns=['B'],index=[1,3])
2000+
expected.index.name = 'A'
2001+
result = g.mad()
2002+
assert_frame_equal(result,expected)
2003+
2004+
# describe
2005+
expected = DataFrame(dict(B = concat([df.loc[[0,1],'B'].describe(),df.loc[[2],'B'].describe()],keys=[1,3])))
2006+
expected.index.names = ['A',None]
2007+
result = g.describe()
2008+
assert_frame_equal(result,expected)
2009+
2010+
# any
2011+
expected = DataFrame([[True, True],[False, True]],columns=['B','C'],index=[1,3])
2012+
expected.index.name = 'A'
2013+
result = g.any()
2014+
assert_frame_equal(result,expected)
2015+
2016+
# idxmax
2017+
expected = DataFrame([[0],[nan]],columns=['B'],index=[1,3])
2018+
expected.index.name = 'A'
2019+
result = g.idxmax()
2020+
assert_frame_equal(result,expected)
19842021

19852022
def test_grouping_ndarray(self):
19862023
grouped = self.df.groupby(self.df['A'].values)
@@ -2937,7 +2974,7 @@ def test_groupby_with_timegrouper(self):
29372974
DT.datetime(2013,12,2,12,0),
29382975
DT.datetime(2013,9,2,14,0),
29392976
]})
2940-
2977+
29412978
# GH 6908 change target column's order
29422979
df_reordered = df_original.sort(columns='Quantity')
29432980

@@ -3949,8 +3986,14 @@ def test_frame_groupby_plot_boxplot(self):
39493986
self.assertEqual(len(res), 2)
39503987
tm.close()
39513988

3989+
# now works with GH 5610 as gender is excluded
3990+
res = df.groupby('gender').hist()
3991+
tm.close()
3992+
3993+
df2 = df.copy()
3994+
df2['gender2'] = df['gender']
39523995
with tm.assertRaisesRegexp(TypeError, '.*str.+float'):
3953-
gb.hist()
3996+
df2.groupby('gender').hist()
39543997

39553998
@slow
39563999
def test_frame_groupby_hist(self):

pandas/tseries/tests/test_resample.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1126,9 +1126,9 @@ def test_evenly_divisible_with_no_extra_bins(self):
11261126
expected = DataFrame(
11271127
[{'REST_KEY': 14, 'DLY_TRN_QT': 14, 'DLY_SLS_AMT': 14,
11281128
'COOP_DLY_TRN_QT': 14, 'COOP_DLY_SLS_AMT': 14}] * 4,
1129-
index=index).unstack().swaplevel(1,0).sortlevel()
1129+
index=index)
11301130
result = df.resample('7D', how='count')
1131-
assert_series_equal(result,expected)
1131+
assert_frame_equal(result,expected)
11321132

11331133
expected = DataFrame(
11341134
[{'REST_KEY': 21, 'DLY_TRN_QT': 1050, 'DLY_SLS_AMT': 700,

0 commit comments

Comments
 (0)