@@ -1008,17 +1008,8 @@ def _concat_objects(
1008
1008
):
1009
1009
from pandas .core .reshape .concat import concat
1010
1010
1011
- def reset_identity (values ):
1012
- # reset the identities of the components
1013
- # of the values to prevent aliasing
1014
- for v in com .not_none (* values ):
1015
- ax = v ._get_axis (self .axis )
1016
- ax ._reset_identity ()
1017
- return values
1018
-
1019
1011
if self .group_keys and not is_transform :
1020
1012
1021
- values = reset_identity (values )
1022
1013
if self .as_index :
1023
1014
1024
1015
# possible MI return case
@@ -1063,7 +1054,6 @@ def reset_identity(values):
1063
1054
result = result .reindex (ax , axis = self .axis , copy = False )
1064
1055
1065
1056
else :
1066
- values = reset_identity (values )
1067
1057
result = concat (values , axis = self .axis )
1068
1058
1069
1059
name = self .obj .name if self .obj .ndim == 1 else self ._selection
@@ -1123,6 +1113,17 @@ def _indexed_output_to_ndframe(
1123
1113
) -> Series | DataFrame :
1124
1114
raise AbstractMethodError (self )
1125
1115
1116
+ @final
1117
+ def _maybe_transpose_result (self , result : NDFrameT ) -> NDFrameT :
1118
+ if self .axis == 1 :
1119
+ # Only relevant for DataFrameGroupBy, no-op for SeriesGroupBy
1120
+ result = result .T
1121
+ if result .index .equals (self .obj .index ):
1122
+ # Retain e.g. DatetimeIndex/TimedeltaIndex freq
1123
+ # e.g. test_groupby_crash_on_nunique
1124
+ result .index = self .obj .index .copy ()
1125
+ return result
1126
+
1126
1127
@final
1127
1128
def _wrap_aggregated_output (
1128
1129
self ,
@@ -1160,15 +1161,10 @@ def _wrap_aggregated_output(
1160
1161
1161
1162
result .index = index
1162
1163
1163
- if self .axis == 1 :
1164
- # Only relevant for DataFrameGroupBy, no-op for SeriesGroupBy
1165
- result = result .T
1166
- if result .index .equals (self .obj .index ):
1167
- # Retain e.g. DatetimeIndex/TimedeltaIndex freq
1168
- result .index = self .obj .index .copy ()
1169
- # TODO: Do this more systematically
1170
-
1171
- return self ._reindex_output (result , qs = qs )
1164
+ # error: Argument 1 to "_maybe_transpose_result" of "GroupBy" has
1165
+ # incompatible type "Union[Series, DataFrame]"; expected "NDFrameT"
1166
+ res = self ._maybe_transpose_result (result ) # type: ignore[arg-type]
1167
+ return self ._reindex_output (res , qs = qs )
1172
1168
1173
1169
def _wrap_applied_output (
1174
1170
self ,
@@ -1242,17 +1238,18 @@ def _numba_agg_general(
1242
1238
return data ._constructor (result , index = index , ** result_kwargs )
1243
1239
1244
1240
@final
1245
- def _transform_with_numba (
1246
- self , data : DataFrame , func , * args , engine_kwargs = None , ** kwargs
1247
- ):
1241
+ def _transform_with_numba (self , func , * args , engine_kwargs = None , ** kwargs ):
1248
1242
"""
1249
1243
Perform groupby transform routine with the numba engine.
1250
1244
1251
1245
This routine mimics the data splitting routine of the DataSplitter class
1252
1246
to generate the indices of each group in the sorted data and then passes the
1253
1247
data and indices into a Numba jitted function.
1254
1248
"""
1255
- starts , ends , sorted_index , sorted_data = self ._numba_prep (data )
1249
+ data = self ._obj_with_exclusions
1250
+ df = data if data .ndim == 2 else data .to_frame ()
1251
+
1252
+ starts , ends , sorted_index , sorted_data = self ._numba_prep (df )
1256
1253
numba_ .validate_udf (func )
1257
1254
numba_transform_func = numba_ .generate_numba_transform_func (
1258
1255
func , ** get_jit_arguments (engine_kwargs , kwargs )
@@ -1262,25 +1259,33 @@ def _transform_with_numba(
1262
1259
sorted_index ,
1263
1260
starts ,
1264
1261
ends ,
1265
- len (data .columns ),
1262
+ len (df .columns ),
1266
1263
* args ,
1267
1264
)
1268
1265
# result values needs to be resorted to their original positions since we
1269
1266
# evaluated the data sorted by group
1270
- return result .take (np .argsort (sorted_index ), axis = 0 )
1267
+ result = result .take (np .argsort (sorted_index ), axis = 0 )
1268
+ index = data .index
1269
+ if data .ndim == 1 :
1270
+ result_kwargs = {"name" : data .name }
1271
+ result = result .ravel ()
1272
+ else :
1273
+ result_kwargs = {"columns" : data .columns }
1274
+ return data ._constructor (result , index = index , ** result_kwargs )
1271
1275
1272
1276
@final
1273
- def _aggregate_with_numba (
1274
- self , data : DataFrame , func , * args , engine_kwargs = None , ** kwargs
1275
- ):
1277
+ def _aggregate_with_numba (self , func , * args , engine_kwargs = None , ** kwargs ):
1276
1278
"""
1277
1279
Perform groupby aggregation routine with the numba engine.
1278
1280
1279
1281
This routine mimics the data splitting routine of the DataSplitter class
1280
1282
to generate the indices of each group in the sorted data and then passes the
1281
1283
data and indices into a Numba jitted function.
1282
1284
"""
1283
- starts , ends , sorted_index , sorted_data = self ._numba_prep (data )
1285
+ data = self ._obj_with_exclusions
1286
+ df = data if data .ndim == 2 else data .to_frame ()
1287
+
1288
+ starts , ends , sorted_index , sorted_data = self ._numba_prep (df )
1284
1289
numba_ .validate_udf (func )
1285
1290
numba_agg_func = numba_ .generate_numba_agg_func (
1286
1291
func , ** get_jit_arguments (engine_kwargs , kwargs )
@@ -1290,10 +1295,20 @@ def _aggregate_with_numba(
1290
1295
sorted_index ,
1291
1296
starts ,
1292
1297
ends ,
1293
- len (data .columns ),
1298
+ len (df .columns ),
1294
1299
* args ,
1295
1300
)
1296
- return result
1301
+ index = self .grouper .result_index
1302
+ if data .ndim == 1 :
1303
+ result_kwargs = {"name" : data .name }
1304
+ result = result .ravel ()
1305
+ else :
1306
+ result_kwargs = {"columns" : data .columns }
1307
+ res = data ._constructor (result , index = index , ** result_kwargs )
1308
+ if not self .as_index :
1309
+ res = self ._insert_inaxis_grouper (res )
1310
+ res .index = default_index (len (res ))
1311
+ return res
1297
1312
1298
1313
# -----------------------------------------------------------------
1299
1314
# apply/agg/transform
@@ -1536,19 +1551,9 @@ def _cython_transform(
1536
1551
def _transform (self , func , * args , engine = None , engine_kwargs = None , ** kwargs ):
1537
1552
1538
1553
if maybe_use_numba (engine ):
1539
- data = self ._obj_with_exclusions
1540
- df = data if data .ndim == 2 else data .to_frame ()
1541
- result = self ._transform_with_numba (
1542
- df , func , * args , engine_kwargs = engine_kwargs , ** kwargs
1554
+ return self ._transform_with_numba (
1555
+ func , * args , engine_kwargs = engine_kwargs , ** kwargs
1543
1556
)
1544
- if self .obj .ndim == 2 :
1545
- return cast (DataFrame , self .obj )._constructor (
1546
- result , index = data .index , columns = data .columns
1547
- )
1548
- else :
1549
- return cast (Series , self .obj )._constructor (
1550
- result .ravel (), index = data .index , name = data .name
1551
- )
1552
1557
1553
1558
# optimized transforms
1554
1559
func = com .get_cython_func (func ) or func
0 commit comments