@@ -5141,6 +5141,118 @@ def test_cumcount_groupby_not_col(self):
5141
5141
assert_series_equal (expected , g .cumcount ())
5142
5142
assert_series_equal (expected , sg .cumcount ())
5143
5143
5144
+ def test_enumerate (self ):
5145
+ df = DataFrame ([['a' ], ['a' ], ['a' ], ['b' ], ['a' ]], columns = ['A' ])
5146
+ g = df .groupby ('A' )
5147
+ sg = g .A
5148
+
5149
+ expected = Series ([0 , 0 , 0 , 1 , 0 ])
5150
+
5151
+ assert_series_equal (expected , g .enumerate ())
5152
+ assert_series_equal (expected , sg .enumerate ())
5153
+
5154
+ def test_enumerate_empty (self ):
5155
+ ge = DataFrame ().groupby (level = 0 )
5156
+ se = Series ().groupby (level = 0 )
5157
+
5158
+ # edge case, as this is usually considered float
5159
+ e = Series (dtype = 'int64' )
5160
+
5161
+ assert_series_equal (e , ge .enumerate ())
5162
+ assert_series_equal (e , se .enumerate ())
5163
+
5164
+ def test_enumerate_dupe_index (self ):
5165
+ df = DataFrame ([['a' ], ['a' ], ['a' ], ['b' ], ['a' ]], columns = ['A' ],
5166
+ index = [0 ] * 5 )
5167
+ g = df .groupby ('A' )
5168
+ sg = g .A
5169
+
5170
+ expected = Series ([0 , 0 , 0 , 1 , 0 ], index = [0 ] * 5 )
5171
+
5172
+ assert_series_equal (expected , g .enumerate ())
5173
+ assert_series_equal (expected , sg .enumerate ())
5174
+
5175
+ def test_enumerate_mi (self ):
5176
+ mi = MultiIndex .from_tuples ([[0 , 1 ], [1 , 2 ], [2 , 2 ], [2 , 2 ], [1 , 0 ]])
5177
+ df = DataFrame ([['a' ], ['a' ], ['a' ], ['b' ], ['a' ]], columns = ['A' ],
5178
+ index = mi )
5179
+ g = df .groupby ('A' )
5180
+ sg = g .A
5181
+
5182
+ expected = Series ([0 , 0 , 0 , 1 , 0 ], index = mi )
5183
+
5184
+ assert_series_equal (expected , g .enumerate ())
5185
+ assert_series_equal (expected , sg .enumerate ())
5186
+
5187
+ def test_enumerate_groupby_not_col (self ):
5188
+ df = DataFrame ([['a' ], ['a' ], ['a' ], ['b' ], ['a' ]], columns = ['A' ],
5189
+ index = [0 ] * 5 )
5190
+ g = df .groupby ([0 , 0 , 0 , 1 , 0 ])
5191
+ sg = g .A
5192
+
5193
+ expected = Series ([0 , 0 , 0 , 1 , 0 ], index = [0 ] * 5 )
5194
+
5195
+ assert_series_equal (expected , g .enumerate ())
5196
+ assert_series_equal (expected , sg .enumerate ())
5197
+
5198
+ def test_enumerate_descending (self ):
5199
+ df = DataFrame (['a' , 'a' , 'b' , 'a' , 'b' ], columns = ['A' ])
5200
+ g = df .groupby (['A' ])
5201
+
5202
+ ascending = Series ([0 , 0 , 1 , 0 , 1 ])
5203
+ descending = Series ([1 , 1 , 0 , 1 , 0 ])
5204
+
5205
+ assert_series_equal (descending , (g .ngroups - 1 ) - ascending )
5206
+ assert_series_equal (ascending , g .enumerate (ascending = True ))
5207
+ assert_series_equal (descending , g .enumerate (ascending = False ))
5208
+
5209
+ def test_enumerate_matches_cumcount (self ):
5210
+ # specific case
5211
+ df = DataFrame ([['a' , 'x' ], ['a' , 'y' ], ['b' , 'x' ],
5212
+ ['a' , 'x' ], ['b' , 'y' ]], columns = ['A' , 'X' ])
5213
+ g = df .groupby (['A' , 'X' ])
5214
+
5215
+ g_enumerate = g .enumerate ()
5216
+ g_cumcount = g .cumcount ()
5217
+ expected_enumerate = pd .Series ([0 , 1 , 2 , 0 , 3 ])
5218
+ expected_cumcount = pd .Series ([0 , 0 , 0 , 1 , 0 ])
5219
+
5220
+ assert_series_equal (g_enumerate , expected_enumerate )
5221
+ assert_series_equal (g_cumcount , expected_cumcount )
5222
+
5223
+ def test_enumerate_cumcount_pair (self ):
5224
+ from itertools import product
5225
+
5226
+ # brute force comparison, inefficient but clear
5227
+ for p in product (range (3 ), repeat = 4 ):
5228
+ df = DataFrame ({'a' : p })
5229
+ g = df .groupby (['a' ])
5230
+
5231
+ order = sorted (set (p ))
5232
+ enumerated = [order .index (val ) for val in p ]
5233
+ cumcounted = [p [:i ].count (val ) for i , val in enumerate (p )]
5234
+
5235
+ assert_series_equal (g .enumerate (), pd .Series (enumerated ))
5236
+ assert_series_equal (g .cumcount (), pd .Series (cumcounted ))
5237
+
5238
+ def test_enumerate_respects_groupby_order (self ):
5239
+ np .random .seed (0 )
5240
+ df = DataFrame ({'a' : np .random .choice (list ('abcdef' ), 100 )})
5241
+ for sort_flag in (False , True ):
5242
+ g = df .groupby (['a' ], sort = sort_flag )
5243
+ df ['group_id' ] = - 1
5244
+ df ['group_index' ] = - 1
5245
+
5246
+ for i , (key , group ) in enumerate (g ):
5247
+ df .loc [group .index , 'group_id' ] = i
5248
+ for j , ind in enumerate (group .index ):
5249
+ df .loc [ind , 'group_index' ] = j
5250
+
5251
+ assert_series_equal (pd .Series (df ['group_id' ].values ),
5252
+ g .enumerate ())
5253
+ assert_series_equal (pd .Series (df ['group_index' ].values ),
5254
+ g .cumcount ())
5255
+
5144
5256
def test_filter_series (self ):
5145
5257
s = pd .Series ([1 , 3 , 20 , 5 , 22 , 24 , 7 ])
5146
5258
expected_odd = pd .Series ([1 , 3 , 5 , 7 ], index = [0 , 1 , 3 , 6 ])
0 commit comments