@@ -1219,6 +1219,53 @@ class GroupBy(_GroupBy):
1219
1219
"""
1220
1220
_apply_whitelist = _common_apply_whitelist
1221
1221
1222
+ def _bool_agg (self , val_test , skipna ):
1223
+ """Shared func to call any / all Cython GroupBy implementations"""
1224
+
1225
+ def objs_to_bool (vals ):
1226
+ try :
1227
+ vals = vals .astype (np .bool )
1228
+ except ValueError : # for objects
1229
+ vals = np .array ([bool (x ) for x in vals ])
1230
+
1231
+ return vals .view (np .uint8 )
1232
+
1233
+ def result_to_bool (result ):
1234
+ return result .astype (np .bool , copy = False )
1235
+
1236
+ return self ._get_cythonized_result ('group_any_all' , self .grouper ,
1237
+ aggregate = True ,
1238
+ cython_dtype = np .uint8 ,
1239
+ needs_values = True ,
1240
+ needs_mask = True ,
1241
+ pre_processing = objs_to_bool ,
1242
+ post_processing = result_to_bool ,
1243
+ val_test = val_test , skipna = skipna )
1244
+
1245
+ @Substitution (name = 'groupby' )
1246
+ @Appender (_doc_template )
1247
+ def any (self , skipna = True ):
1248
+ """Returns True if any value in the group is truthful, else False
1249
+
1250
+ Parameters
1251
+ ----------
1252
+ skipna : bool, default True
1253
+ Flag to ignore nan values during truth testing
1254
+ """
1255
+ return self ._bool_agg ('any' , skipna )
1256
+
1257
+ @Substitution (name = 'groupby' )
1258
+ @Appender (_doc_template )
1259
+ def all (self , skipna = True ):
1260
+ """Returns True if all values in the group are truthful, else False
1261
+
1262
+ Parameters
1263
+ ----------
1264
+ skipna : bool, default True
1265
+ Flag to ignore nan values during truth testing
1266
+ """
1267
+ return self ._bool_agg ('all' , skipna )
1268
+
1222
1269
@Substitution (name = 'groupby' )
1223
1270
@Appender (_doc_template )
1224
1271
def count (self ):
@@ -1485,6 +1532,8 @@ def _fill(self, direction, limit=None):
1485
1532
1486
1533
return self ._get_cythonized_result ('group_fillna_indexer' ,
1487
1534
self .grouper , needs_mask = True ,
1535
+ cython_dtype = np .int64 ,
1536
+ result_is_index = True ,
1488
1537
direction = direction , limit = limit )
1489
1538
1490
1539
@Substitution (name = 'groupby' )
@@ -1873,33 +1922,81 @@ def cummax(self, axis=0, **kwargs):
1873
1922
1874
1923
return self ._cython_transform ('cummax' , numeric_only = False )
1875
1924
1876
- def _get_cythonized_result (self , how , grouper , needs_mask = False ,
1877
- needs_ngroups = False , ** kwargs ):
1925
+ def _get_cythonized_result (self , how , grouper , aggregate = False ,
1926
+ cython_dtype = None , needs_values = False ,
1927
+ needs_mask = False , needs_ngroups = False ,
1928
+ result_is_index = False ,
1929
+ pre_processing = None , post_processing = None ,
1930
+ ** kwargs ):
1878
1931
"""Get result for Cythonized functions
1879
1932
1880
1933
Parameters
1881
1934
----------
1882
1935
how : str, Cythonized function name to be called
1883
1936
grouper : Grouper object containing pertinent group info
1937
+ aggregate : bool, default False
1938
+ Whether the result should be aggregated to match the number of
1939
+ groups
1940
+ cython_dtype : default None
1941
+ Type of the array that will be modified by the Cython call. If
1942
+ `None`, the type will be inferred from the values of each slice
1943
+ needs_values : bool, default False
1944
+ Whether the values should be a part of the Cython call
1945
+ signature
1884
1946
needs_mask : bool, default False
1885
- Whether boolean mask needs to be part of the Cython call signature
1947
+ Whether boolean mask needs to be part of the Cython call
1948
+ signature
1886
1949
needs_ngroups : bool, default False
1887
- Whether number of groups part of the Cython call signature
1950
+ Whether number of groups is part of the Cython call signature
1951
+ result_is_index : bool, default False
1952
+ Whether the result of the Cython operation is an index of
1953
+ values to be retrieved, instead of the actual values themselves
1954
+ pre_processing : function, default None
1955
+ Function to be applied to `values` prior to passing to Cython
1956
+ Raises if `needs_values` is False
1957
+ post_processing : function, default None
1958
+ Function to be applied to result of Cython function
1888
1959
**kwargs : dict
1889
1960
Extra arguments to be passed back to Cython funcs
1890
1961
1891
1962
Returns
1892
1963
-------
1893
1964
`Series` or `DataFrame` with filled values
1894
1965
"""
1966
+ if result_is_index and aggregate :
1967
+ raise ValueError ("'result_is_index' and 'aggregate' cannot both "
1968
+ "be True!" )
1969
+ if post_processing :
1970
+ if not callable (pre_processing ):
1971
+ raise ValueError ("'post_processing' must be a callable!" )
1972
+ if pre_processing :
1973
+ if not callable (pre_processing ):
1974
+ raise ValueError ("'pre_processing' must be a callable!" )
1975
+ if not needs_values :
1976
+ raise ValueError ("Cannot use 'pre_processing' without "
1977
+ "specifying 'needs_values'!" )
1895
1978
1896
1979
labels , _ , ngroups = grouper .group_info
1897
1980
output = collections .OrderedDict ()
1898
1981
base_func = getattr (libgroupby , how )
1899
1982
1900
1983
for name , obj in self ._iterate_slices ():
1901
- indexer = np .zeros_like (labels , dtype = np .int64 )
1902
- func = partial (base_func , indexer , labels )
1984
+ if aggregate :
1985
+ result_sz = ngroups
1986
+ else :
1987
+ result_sz = len (obj .values )
1988
+
1989
+ if not cython_dtype :
1990
+ cython_dtype = obj .values .dtype
1991
+
1992
+ result = np .zeros (result_sz , dtype = cython_dtype )
1993
+ func = partial (base_func , result , labels )
1994
+ if needs_values :
1995
+ vals = obj .values
1996
+ if pre_processing :
1997
+ vals = pre_processing (vals )
1998
+ func = partial (func , vals )
1999
+
1903
2000
if needs_mask :
1904
2001
mask = isnull (obj .values ).view (np .uint8 )
1905
2002
func = partial (func , mask )
@@ -1908,9 +2005,19 @@ def _get_cythonized_result(self, how, grouper, needs_mask=False,
1908
2005
func = partial (func , ngroups )
1909
2006
1910
2007
func (** kwargs ) # Call func to modify indexer values in place
1911
- output [name ] = algorithms .take_nd (obj .values , indexer )
1912
2008
1913
- return self ._wrap_transformed_output (output )
2009
+ if result_is_index :
2010
+ result = algorithms .take_nd (obj .values , result )
2011
+
2012
+ if post_processing :
2013
+ result = post_processing (result )
2014
+
2015
+ output [name ] = result
2016
+
2017
+ if aggregate :
2018
+ return self ._wrap_aggregated_output (output )
2019
+ else :
2020
+ return self ._wrap_transformed_output (output )
1914
2021
1915
2022
@Substitution (name = 'groupby' )
1916
2023
@Appender (_doc_template )
@@ -1930,7 +2037,9 @@ def shift(self, periods=1, freq=None, axis=0):
1930
2037
return self .apply (lambda x : x .shift (periods , freq , axis ))
1931
2038
1932
2039
return self ._get_cythonized_result ('group_shift_indexer' ,
1933
- self .grouper , needs_ngroups = True ,
2040
+ self .grouper , cython_dtype = np .int64 ,
2041
+ needs_ngroups = True ,
2042
+ result_is_index = True ,
1934
2043
periods = periods )
1935
2044
1936
2045
@Substitution (name = 'groupby' )
0 commit comments