@@ -1219,6 +1219,29 @@ class GroupBy(_GroupBy):
1219
1219
"""
1220
1220
_apply_whitelist = _common_apply_whitelist
1221
1221
1222
+ def _bool_agg (self , how , skipna ):
1223
+ """Shared func to call any / all Cython GroupBy implementations"""
1224
+
1225
+ def objs_to_bool (vals ):
1226
+ try :
1227
+ vals = vals .astype (np .bool )
1228
+ except ValueError : # for objects
1229
+ vals = np .array ([bool (x ) for x in vals ])
1230
+
1231
+ return vals .view (np .uint8 )
1232
+
1233
+ def result_to_bool (result ):
1234
+ return result .astype (np .bool , copy = False )
1235
+
1236
+ return self ._get_cythonized_result (how , self .grouper ,
1237
+ aggregate = True ,
1238
+ cython_dtype = np .uint8 ,
1239
+ needs_values = True ,
1240
+ needs_mask = True ,
1241
+ pre_processing = objs_to_bool ,
1242
+ post_processing = result_to_bool ,
1243
+ skipna = skipna )
1244
+
1222
1245
@Substitution (name = 'groupby' )
1223
1246
@Appender (_doc_template )
1224
1247
def any (self , skipna = True ):
@@ -1229,15 +1252,19 @@ def any(self, skipna=True):
1229
1252
skipna : bool, default True
1230
1253
Flag to ignore nan values during truth testing
1231
1254
"""
1232
- labels , _ , _ = self .grouper .group_info
1233
- output = collections .OrderedDict ()
1255
+ return self ._bool_agg ('group_any' , skipna )
1234
1256
1235
- for name , obj in self . _iterate_slices ():
1236
- result = np . zeros ( self . ngroups , dtype = np . int64 )
1237
- libgroupby . group_any ( result , obj . values , labels , skipna )
1238
- output [ name ] = result . astype ( np . bool )
1257
+ @ Substitution ( name = 'groupby' )
1258
+ @ Appender ( _doc_template )
1259
+ def all ( self , skipna = True ):
1260
+ """Returns True if all values in the group are truthful, else False
1239
1261
1240
- return self ._wrap_aggregated_output (output )
1262
+ Parameters
1263
+ ----------
1264
+ skipna : bool, default True
1265
+ Flag to ignore nan values during truth testing
1266
+ """
1267
+ return self ._bool_agg ('group_all' , skipna )
1241
1268
1242
1269
@Substitution (name = 'groupby' )
1243
1270
@Appender (_doc_template )
@@ -1505,6 +1532,8 @@ def _fill(self, direction, limit=None):
1505
1532
1506
1533
return self ._get_cythonized_result ('group_fillna_indexer' ,
1507
1534
self .grouper , needs_mask = True ,
1535
+ cython_dtype = np .int64 ,
1536
+ result_is_index = True ,
1508
1537
direction = direction , limit = limit )
1509
1538
1510
1539
@Substitution (name = 'groupby' )
@@ -1893,33 +1922,81 @@ def cummax(self, axis=0, **kwargs):
1893
1922
1894
1923
return self ._cython_transform ('cummax' , numeric_only = False )
1895
1924
1896
- def _get_cythonized_result (self , how , grouper , needs_mask = False ,
1897
- needs_ngroups = False , ** kwargs ):
1925
+ def _get_cythonized_result (self , how , grouper , aggregate = False ,
1926
+ cython_dtype = None , needs_values = False ,
1927
+ needs_mask = False , needs_ngroups = False ,
1928
+ result_is_index = False ,
1929
+ pre_processing = None , post_processing = None ,
1930
+ ** kwargs ):
1898
1931
"""Get result for Cythonized functions
1899
1932
1900
1933
Parameters
1901
1934
----------
1902
1935
how : str, Cythonized function name to be called
1903
1936
grouper : Grouper object containing pertinent group info
1937
+ aggregate : bool, default False
1938
+ Whether the result should be aggregated to match the number of
1939
+ groups
1940
+ cython_dtype : default None
1941
+ Type of the array that will be modified by the Cython call. If
1942
+ `None`, the type will be inferred from the values of each slice
1943
+ needs_values : bool, default False
1944
+ Whether the values should be a part of the Cython call
1945
+ signature
1904
1946
needs_mask : bool, default False
1905
- Whether boolean mask needs to be part of the Cython call signature
1947
+ Whether boolean mask needs to be part of the Cython call
1948
+ signature
1906
1949
needs_ngroups : bool, default False
1907
- Whether number of groups part of the Cython call signature
1950
+ Whether number of groups is part of the Cython call signature
1951
+ result_is_index : bool, default False
1952
+ Whether the result of the Cython operation is an index of
1953
+ values to be retrieved, instead of the actual values themselves
1954
+ pre_processing : function, default None
1955
+ Function to be applied to `values` prior to passing to Cython
1956
+ Raises if `needs_values` is False
1957
+ post_processing : function, default None
1958
+ Function to be applied to result of Cython function
1908
1959
**kwargs : dict
1909
1960
Extra arguments to be passed back to Cython funcs
1910
1961
1911
1962
Returns
1912
1963
-------
1913
1964
`Series` or `DataFrame` with filled values
1914
1965
"""
1966
+ if result_is_index and aggregate :
1967
+ raise ValueError ("'result_is_index' and 'aggregate' cannot both "
1968
+ "be True!" )
1969
+ if post_processing :
1970
+ if not callable (pre_processing ):
1971
+ raise ValueError ("'post_processing' must be a callable!" )
1972
+ if pre_processing :
1973
+ if not callable (pre_processing ):
1974
+ raise ValueError ("'pre_processing' must be a callable!" )
1975
+ if not needs_values :
1976
+ raise ValueError ("Cannot use 'pre_processing' without "
1977
+ "specifying 'needs_values'!" )
1915
1978
1916
1979
labels , _ , ngroups = grouper .group_info
1917
1980
output = collections .OrderedDict ()
1918
1981
base_func = getattr (libgroupby , how )
1919
1982
1920
1983
for name , obj in self ._iterate_slices ():
1921
- indexer = np .zeros_like (labels , dtype = np .int64 )
1922
- func = partial (base_func , indexer , labels )
1984
+ if aggregate :
1985
+ result_sz = ngroups
1986
+ else :
1987
+ result_sz = len (obj .values )
1988
+
1989
+ if not cython_dtype :
1990
+ cython_dtype = obj .values .dtype
1991
+
1992
+ result = np .zeros (result_sz , dtype = cython_dtype )
1993
+ func = partial (base_func , result , labels )
1994
+ if needs_values :
1995
+ vals = obj .values
1996
+ if pre_processing :
1997
+ vals = pre_processing (vals )
1998
+ func = partial (func , vals )
1999
+
1923
2000
if needs_mask :
1924
2001
mask = isnull (obj .values ).view (np .uint8 )
1925
2002
func = partial (func , mask )
@@ -1928,9 +2005,19 @@ def _get_cythonized_result(self, how, grouper, needs_mask=False,
1928
2005
func = partial (func , ngroups )
1929
2006
1930
2007
func (** kwargs ) # Call func to modify indexer values in place
1931
- output [name ] = algorithms .take_nd (obj .values , indexer )
1932
2008
1933
- return self ._wrap_transformed_output (output )
2009
+ if result_is_index :
2010
+ result = algorithms .take_nd (obj .values , result )
2011
+
2012
+ if post_processing :
2013
+ result = post_processing (result )
2014
+
2015
+ output [name ] = result
2016
+
2017
+ if aggregate :
2018
+ return self ._wrap_aggregated_output (output )
2019
+ else :
2020
+ return self ._wrap_transformed_output (output )
1934
2021
1935
2022
@Substitution (name = 'groupby' )
1936
2023
@Appender (_doc_template )
@@ -1950,7 +2037,9 @@ def shift(self, periods=1, freq=None, axis=0):
1950
2037
return self .apply (lambda x : x .shift (periods , freq , axis ))
1951
2038
1952
2039
return self ._get_cythonized_result ('group_shift_indexer' ,
1953
- self .grouper , needs_ngroups = True ,
2040
+ self .grouper , cython_dtype = np .int64 ,
2041
+ needs_ngroups = True ,
2042
+ result_is_index = True ,
1954
2043
periods = periods )
1955
2044
1956
2045
@Substitution (name = 'groupby' )
0 commit comments