Skip to content

Commit 5c76f33

Browse files
committed
Revert "Consolidate nth / last object Groupby Implementations (#19610)"
This reverts commit d4730e6.
1 parent d4730e6 commit 5c76f33

File tree

3 files changed

+140
-47
lines changed

3 files changed

+140
-47
lines changed

pandas/_libs/groupby.pyx

+99
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,105 @@ cdef double NaN = <double> np.NaN
2626
cdef double nan = NaN
2727

2828

29+
# TODO: aggregate multiple columns in single pass
30+
# ----------------------------------------------------------------------
31+
# first, nth, last
32+
33+
34+
@cython.boundscheck(False)
@cython.wraparound(False)
def group_nth_object(ndarray[object, ndim=2] out,
                     ndarray[int64_t] counts,
                     ndarray[object, ndim=2] values,
                     ndarray[int64_t] labels,
                     int64_t rank,
                     Py_ssize_t min_count=-1):
    """
    Pick, per group and per column, the `rank`-th non-null object.

    Only aggregates on axis=0.

    Parameters
    ----------
    out : ndarray[object, ndim=2]
        Result buffer, written in place. ``out[g, j]`` receives the value
        that made the non-null count of group ``g`` in column ``j`` reach
        `rank`, or NaN when the group has no non-null value in that column.
    counts : ndarray[int64_t]
        Incremented in place once for every row whose label is
        non-negative; its length sets the number of groups written.
    values : ndarray[object, ndim=2]
        Input data, one row per observation.
    labels : ndarray[int64_t]
        Group index of each row of `values`; rows with a negative label
        are skipped.
    rank : int64_t
        Compared against the running non-null count, which starts at 1.
    min_count : Py_ssize_t, default -1
        Accepted for signature compatibility only; must be -1.

    Raises
    ------
    AssertionError
        If `min_count` is not -1, or if `values` and `labels` differ in
        length.
    """
    cdef:
        Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
        object val
        ndarray[int64_t, ndim=2] nobs
        ndarray[object, ndim=2] resx

    assert min_count == -1, "'min_count' only used in add and prod"

    # Bounds checking is disabled for this function, so a `labels` array
    # shorter than `values` would be read past its end; reject it up front,
    # matching the templated group_nth_* implementations.
    if not len(values) == len(labels):
        raise AssertionError("len(index) != len(labels)")

    nobs = np.zeros((<object> out).shape, dtype=np.int64)
    resx = np.empty((<object> out).shape, dtype=object)

    N, K = (<object> values).shape

    for i in range(N):
        lab = labels[i]
        if lab < 0:
            continue

        counts[lab] += 1
        for j in range(K):
            val = values[i, j]

            # not nan: a value that does not compare equal to itself
            # is treated as missing
            if val == val:
                nobs[lab, j] += 1
                if nobs[lab, j] == rank:
                    resx[lab, j] = val

    for i in range(ncounts):
        for j in range(K):
            if nobs[i, j] == 0:
                out[i, j] = <object> nan
            else:
                # NOTE(review): when 0 < nobs < rank this copies the
                # never-assigned slot of the np.empty object buffer
                # (presumably None) -- confirm this is intended.
                out[i, j] = resx[i, j]
80+
81+
82+
@cython.boundscheck(False)
@cython.wraparound(False)
def group_last_object(ndarray[object, ndim=2] out,
                      ndarray[int64_t] counts,
                      ndarray[object, ndim=2] values,
                      ndarray[int64_t] labels,
                      Py_ssize_t min_count=-1):
    """
    Pick, per group and per column, the last non-null object.

    Only aggregates on axis=0.

    Parameters
    ----------
    out : ndarray[object, ndim=2]
        Result buffer, written in place. ``out[g, j]`` receives the last
        non-null value seen for group ``g`` in column ``j`` (in row
        order), or NaN when the group has no non-null value there.
    counts : ndarray[int64_t]
        Incremented in place once for every row whose label is
        non-negative; its length sets the number of groups written.
    values : ndarray[object, ndim=2]
        Input data, one row per observation.
    labels : ndarray[int64_t]
        Group index of each row of `values`; rows with a negative label
        are skipped.
    min_count : Py_ssize_t, default -1
        Accepted for signature compatibility only; must be -1.

    Raises
    ------
    AssertionError
        If `min_count` is not -1, or if `values` and `labels` differ in
        length.
    """
    cdef:
        Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
        object val
        ndarray[object, ndim=2] resx
        ndarray[int64_t, ndim=2] nobs

    assert min_count == -1, "'min_count' only used in add and prod"

    # Bounds checking is disabled for this function, so a `labels` array
    # shorter than `values` would be read past its end; reject it up front,
    # matching the templated group_last_* implementations.
    if not len(values) == len(labels):
        raise AssertionError("len(index) != len(labels)")

    nobs = np.zeros((<object> out).shape, dtype=np.int64)
    resx = np.empty((<object> out).shape, dtype=object)

    N, K = (<object> values).shape

    for i in range(N):
        lab = labels[i]
        if lab < 0:
            continue

        counts[lab] += 1
        for j in range(K):
            val = values[i, j]

            # not nan: a value that does not compare equal to itself
            # is treated as missing
            if val == val:
                nobs[lab, j] += 1
                # later rows overwrite earlier ones, so the last
                # non-null value wins
                resx[lab, j] = val

    for i in range(ncounts):
        for j in range(K):
            if nobs[i, j] == 0:
                # explicit cast, consistent with group_nth_object
                out[i, j] = <object> nan
            else:
                out[i, j] = resx[i, j]
126+
127+
29128
cdef inline float64_t median_linear(float64_t* a, int n) nogil:
30129
cdef int i, j, na_count = 0
31130
cdef float64_t result

pandas/_libs/groupby_helper.pxi.in

+12-20
Original file line number | Diff line number | Diff line change
@@ -325,8 +325,7 @@ def group_ohlc_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
325325
# name, c_type, dest_type2, nan_val
326326
dtypes = [('float64', 'float64_t', 'float64_t', 'NAN'),
327327
('float32', 'float32_t', 'float32_t', 'NAN'),
328-
('int64', 'int64_t', 'int64_t', 'iNaT'),
329-
('object', 'object', 'object', 'NAN')]
328+
('int64', 'int64_t', 'int64_t', 'iNaT')]
330329

331330
def get_dispatch(dtypes):
332331

@@ -351,7 +350,7 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
351350
"""
352351
cdef:
353352
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
354-
{{dest_type2}} val
353+
{{dest_type2}} val, count
355354
ndarray[{{dest_type2}}, ndim=2] resx
356355
ndarray[int64_t, ndim=2] nobs
357356

@@ -361,19 +360,11 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
361360
raise AssertionError("len(index) != len(labels)")
362361

363362
nobs = np.zeros((<object> out).shape, dtype=np.int64)
364-
{{if name=='object'}}
365-
resx = np.empty((<object> out).shape, dtype=object)
366-
{{else}}
367363
resx = np.empty_like(out)
368-
{{endif}}
369364

370365
N, K = (<object> values).shape
371366

372-
{{if name == "object"}}
373-
if True: # make templating happy
374-
{{else}}
375367
with nogil:
376-
{{endif}}
377368
for i in range(N):
378369
lab = labels[i]
379370
if lab < 0:
@@ -384,7 +375,11 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
384375
val = values[i, j]
385376

386377
# not nan
378+
{{if name == 'int64'}}
379+
if val != {{nan_val}}:
380+
{{else}}
387381
if val == val and val != {{nan_val}}:
382+
{{endif}}
388383
nobs[lab, j] += 1
389384
resx[lab, j] = val
390385

@@ -395,6 +390,7 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
395390
else:
396391
out[i, j] = resx[i, j]
397392

393+
398394
@cython.wraparound(False)
399395
@cython.boundscheck(False)
400396
def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
@@ -407,7 +403,7 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
407403
"""
408404
cdef:
409405
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
410-
{{dest_type2}} val
406+
{{dest_type2}} val, count
411407
ndarray[{{dest_type2}}, ndim=2] resx
412408
ndarray[int64_t, ndim=2] nobs
413409

@@ -417,19 +413,11 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
417413
raise AssertionError("len(index) != len(labels)")
418414

419415
nobs = np.zeros((<object> out).shape, dtype=np.int64)
420-
{{if name=='object'}}
421-
resx = np.empty((<object> out).shape, dtype=object)
422-
{{else}}
423416
resx = np.empty_like(out)
424-
{{endif}}
425417

426418
N, K = (<object> values).shape
427419

428-
{{if name == "object"}}
429-
if True: # make templating happy
430-
{{else}}
431420
with nogil:
432-
{{endif}}
433421
for i in range(N):
434422
lab = labels[i]
435423
if lab < 0:
@@ -440,7 +428,11 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
440428
val = values[i, j]
441429

442430
# not nan
431+
{{if name == 'int64'}}
432+
if val != {{nan_val}}:
433+
{{else}}
443434
if val == val and val != {{nan_val}}:
435+
{{endif}}
444436
nobs[lab, j] += 1
445437
if nobs[lab, j] == rank:
446438
resx[lab, j] = val

pandas/tests/groupby/test_groupby.py

+29-27
Original file line number | Diff line number | Diff line change
@@ -2252,45 +2252,47 @@ def test_median_empty_bins(self):
22522252
expected = df.groupby(bins).agg(lambda x: x.median())
22532253
assert_frame_equal(result, expected)
22542254

2255-
@pytest.mark.parametrize("dtype", [
2256-
'int8', 'int16', 'int32', 'int64', 'float32', 'float64'])
2257-
@pytest.mark.parametrize("method,data", [
2258-
('first', {'df': [{'a': 1, 'b': 1}, {'a': 2, 'b': 3}]}),
2259-
('last', {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}]}),
2260-
('min', {'df': [{'a': 1, 'b': 1}, {'a': 2, 'b': 3}]}),
2261-
('max', {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}]}),
2262-
('nth', {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}],
2263-
'args': [1]}),
2264-
('count', {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 2}],
2265-
'out_type': 'int64'})
2266-
])
2267-
def test_groupby_non_arithmetic_agg_types(self, dtype, method, data):
2255+
def test_groupby_non_arithmetic_agg_types(self):
22682256
# GH9311, GH6620
22692257
df = pd.DataFrame(
22702258
[{'a': 1, 'b': 1},
22712259
{'a': 1, 'b': 2},
22722260
{'a': 2, 'b': 3},
22732261
{'a': 2, 'b': 4}])
22742262

2275-
df['b'] = df.b.astype(dtype)
2263+
dtypes = ['int8', 'int16', 'int32', 'int64', 'float32', 'float64']
22762264

2277-
if 'args' not in data:
2278-
data['args'] = []
2265+
grp_exp = {'first': {'df': [{'a': 1, 'b': 1}, {'a': 2, 'b': 3}]},
2266+
'last': {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}]},
2267+
'min': {'df': [{'a': 1, 'b': 1}, {'a': 2, 'b': 3}]},
2268+
'max': {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}]},
2269+
'nth': {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}],
2270+
'args': [1]},
2271+
'count': {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 2}],
2272+
'out_type': 'int64'}}
22792273

2280-
if 'out_type' in data:
2281-
out_type = data['out_type']
2282-
else:
2283-
out_type = dtype
2274+
for dtype in dtypes:
2275+
df_in = df.copy()
2276+
df_in['b'] = df_in.b.astype(dtype)
2277+
2278+
for method, data in compat.iteritems(grp_exp):
2279+
if 'args' not in data:
2280+
data['args'] = []
2281+
2282+
if 'out_type' in data:
2283+
out_type = data['out_type']
2284+
else:
2285+
out_type = dtype
22842286

2285-
exp = data['df']
2286-
df_out = pd.DataFrame(exp)
2287+
exp = data['df']
2288+
df_out = pd.DataFrame(exp)
22872289

2288-
df_out['b'] = df_out.b.astype(out_type)
2289-
df_out.set_index('a', inplace=True)
2290+
df_out['b'] = df_out.b.astype(out_type)
2291+
df_out.set_index('a', inplace=True)
22902292

2291-
grpd = df.groupby('a')
2292-
t = getattr(grpd, method)(*data['args'])
2293-
assert_frame_equal(t, df_out)
2293+
grpd = df_in.groupby('a')
2294+
t = getattr(grpd, method)(*data['args'])
2295+
assert_frame_equal(t, df_out)
22942296

22952297
def test_groupby_non_arithmetic_agg_intlike_precision(self):
22962298
# GH9311, GH6620

0 commit comments

Comments
 (0)