Skip to content

Commit d4730e6

Browse files
WillAyd authored and jreback committed
Consolidate nth / last object Groupby Implementations (#19610)
1 parent c1068d9 commit d4730e6

File tree

3 files changed

+47
-140
lines changed

3 files changed

+47
-140
lines changed

pandas/_libs/groupby.pyx

-99
Original file line numberDiff line numberDiff line change
@@ -26,105 +26,6 @@ cdef double NaN = <double> np.NaN
2626
cdef double nan = NaN
2727

2828

29-
# TODO: aggregate multiple columns in single pass
30-
# ----------------------------------------------------------------------
31-
# first, nth, last
32-
33-
34-
@cython.boundscheck(False)
35-
@cython.wraparound(False)
36-
def group_nth_object(ndarray[object, ndim=2] out,
37-
ndarray[int64_t] counts,
38-
ndarray[object, ndim=2] values,
39-
ndarray[int64_t] labels,
40-
int64_t rank,
41-
Py_ssize_t min_count=-1):
42-
"""
43-
Only aggregates on axis=0
44-
"""
45-
cdef:
46-
Py_ssize_t i, j, N, K, lab
47-
object val
48-
float64_t count
49-
ndarray[int64_t, ndim=2] nobs
50-
ndarray[object, ndim=2] resx
51-
52-
assert min_count == -1, "'min_count' only used in add and prod"
53-
54-
nobs = np.zeros((<object> out).shape, dtype=np.int64)
55-
resx = np.empty((<object> out).shape, dtype=object)
56-
57-
N, K = (<object> values).shape
58-
59-
for i in range(N):
60-
lab = labels[i]
61-
if lab < 0:
62-
continue
63-
64-
counts[lab] += 1
65-
for j in range(K):
66-
val = values[i, j]
67-
68-
# not nan
69-
if val == val:
70-
nobs[lab, j] += 1
71-
if nobs[lab, j] == rank:
72-
resx[lab, j] = val
73-
74-
for i in range(len(counts)):
75-
for j in range(K):
76-
if nobs[i, j] == 0:
77-
out[i, j] = <object> nan
78-
else:
79-
out[i, j] = resx[i, j]
80-
81-
82-
@cython.boundscheck(False)
83-
@cython.wraparound(False)
84-
def group_last_object(ndarray[object, ndim=2] out,
85-
ndarray[int64_t] counts,
86-
ndarray[object, ndim=2] values,
87-
ndarray[int64_t] labels,
88-
Py_ssize_t min_count=-1):
89-
"""
90-
Only aggregates on axis=0
91-
"""
92-
cdef:
93-
Py_ssize_t i, j, N, K, lab
94-
object val
95-
float64_t count
96-
ndarray[object, ndim=2] resx
97-
ndarray[int64_t, ndim=2] nobs
98-
99-
assert min_count == -1, "'min_count' only used in add and prod"
100-
101-
nobs = np.zeros((<object> out).shape, dtype=np.int64)
102-
resx = np.empty((<object> out).shape, dtype=object)
103-
104-
N, K = (<object> values).shape
105-
106-
for i in range(N):
107-
lab = labels[i]
108-
if lab < 0:
109-
continue
110-
111-
counts[lab] += 1
112-
for j in range(K):
113-
val = values[i, j]
114-
115-
# not nan
116-
if val == val:
117-
nobs[lab, j] += 1
118-
resx[lab, j] = val
119-
120-
for i in range(len(counts)):
121-
for j in range(K):
122-
if nobs[i, j] == 0:
123-
out[i, j] = nan
124-
else:
125-
out[i, j] = resx[i, j]
126-
127-
12829
cdef inline float64_t median_linear(float64_t* a, int n) nogil:
12930
cdef int i, j, na_count = 0
13031
cdef float64_t result

pandas/_libs/groupby_helper.pxi.in

+20 -12
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ def group_ohlc_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
325325
# name, c_type, dest_type2, nan_val
326326
dtypes = [('float64', 'float64_t', 'float64_t', 'NAN'),
327327
('float32', 'float32_t', 'float32_t', 'NAN'),
328-
('int64', 'int64_t', 'int64_t', 'iNaT')]
328+
('int64', 'int64_t', 'int64_t', 'iNaT'),
329+
('object', 'object', 'object', 'NAN')]
329330

330331
def get_dispatch(dtypes):
331332

@@ -350,7 +351,7 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
350351
"""
351352
cdef:
352353
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
353-
{{dest_type2}} val, count
354+
{{dest_type2}} val
354355
ndarray[{{dest_type2}}, ndim=2] resx
355356
ndarray[int64_t, ndim=2] nobs
356357

@@ -360,11 +361,19 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
360361
raise AssertionError("len(index) != len(labels)")
361362

362363
nobs = np.zeros((<object> out).shape, dtype=np.int64)
364+
{{if name=='object'}}
365+
resx = np.empty((<object> out).shape, dtype=object)
366+
{{else}}
363367
resx = np.empty_like(out)
368+
{{endif}}
364369

365370
N, K = (<object> values).shape
366371

372+
{{if name == "object"}}
373+
if True: # make templating happy
374+
{{else}}
367375
with nogil:
376+
{{endif}}
368377
for i in range(N):
369378
lab = labels[i]
370379
if lab < 0:
@@ -375,11 +384,7 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
375384
val = values[i, j]
376385

377386
# not nan
378-
{{if name == 'int64'}}
379-
if val != {{nan_val}}:
380-
{{else}}
381387
if val == val and val != {{nan_val}}:
382-
{{endif}}
383388
nobs[lab, j] += 1
384389
resx[lab, j] = val
385390

@@ -390,7 +395,6 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
390395
else:
391396
out[i, j] = resx[i, j]
392397

393-
394398
@cython.wraparound(False)
395399
@cython.boundscheck(False)
396400
def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
@@ -403,7 +407,7 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
403407
"""
404408
cdef:
405409
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
406-
{{dest_type2}} val, count
410+
{{dest_type2}} val
407411
ndarray[{{dest_type2}}, ndim=2] resx
408412
ndarray[int64_t, ndim=2] nobs
409413

@@ -413,11 +417,19 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
413417
raise AssertionError("len(index) != len(labels)")
414418

415419
nobs = np.zeros((<object> out).shape, dtype=np.int64)
420+
{{if name=='object'}}
421+
resx = np.empty((<object> out).shape, dtype=object)
422+
{{else}}
416423
resx = np.empty_like(out)
424+
{{endif}}
417425

418426
N, K = (<object> values).shape
419427

428+
{{if name == "object"}}
429+
if True: # make templating happy
430+
{{else}}
420431
with nogil:
432+
{{endif}}
421433
for i in range(N):
422434
lab = labels[i]
423435
if lab < 0:
@@ -428,11 +440,7 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
428440
val = values[i, j]
429441

430442
# not nan
431-
{{if name == 'int64'}}
432-
if val != {{nan_val}}:
433-
{{else}}
434443
if val == val and val != {{nan_val}}:
435-
{{endif}}
436444
nobs[lab, j] += 1
437445
if nobs[lab, j] == rank:
438446
resx[lab, j] = val

pandas/tests/groupby/test_groupby.py

+27 -29
Original file line numberDiff line numberDiff line change
@@ -2252,47 +2252,45 @@ def test_median_empty_bins(self):
22522252
expected = df.groupby(bins).agg(lambda x: x.median())
22532253
assert_frame_equal(result, expected)
22542254

2255-
def test_groupby_non_arithmetic_agg_types(self):
2255+
@pytest.mark.parametrize("dtype", [
2256+
'int8', 'int16', 'int32', 'int64', 'float32', 'float64'])
2257+
@pytest.mark.parametrize("method,data", [
2258+
('first', {'df': [{'a': 1, 'b': 1}, {'a': 2, 'b': 3}]}),
2259+
('last', {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}]}),
2260+
('min', {'df': [{'a': 1, 'b': 1}, {'a': 2, 'b': 3}]}),
2261+
('max', {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}]}),
2262+
('nth', {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}],
2263+
'args': [1]}),
2264+
('count', {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 2}],
2265+
'out_type': 'int64'})
2266+
])
2267+
def test_groupby_non_arithmetic_agg_types(self, dtype, method, data):
22562268
# GH9311, GH6620
22572269
df = pd.DataFrame(
22582270
[{'a': 1, 'b': 1},
22592271
{'a': 1, 'b': 2},
22602272
{'a': 2, 'b': 3},
22612273
{'a': 2, 'b': 4}])
22622274

2263-
dtypes = ['int8', 'int16', 'int32', 'int64', 'float32', 'float64']
2264-
2265-
grp_exp = {'first': {'df': [{'a': 1, 'b': 1}, {'a': 2, 'b': 3}]},
2266-
'last': {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}]},
2267-
'min': {'df': [{'a': 1, 'b': 1}, {'a': 2, 'b': 3}]},
2268-
'max': {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}]},
2269-
'nth': {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 4}],
2270-
'args': [1]},
2271-
'count': {'df': [{'a': 1, 'b': 2}, {'a': 2, 'b': 2}],
2272-
'out_type': 'int64'}}
2275+
df['b'] = df.b.astype(dtype)
22732276

2274-
for dtype in dtypes:
2275-
df_in = df.copy()
2276-
df_in['b'] = df_in.b.astype(dtype)
2277+
if 'args' not in data:
2278+
data['args'] = []
22772279

2278-
for method, data in compat.iteritems(grp_exp):
2279-
if 'args' not in data:
2280-
data['args'] = []
2281-
2282-
if 'out_type' in data:
2283-
out_type = data['out_type']
2284-
else:
2285-
out_type = dtype
2280+
if 'out_type' in data:
2281+
out_type = data['out_type']
2282+
else:
2283+
out_type = dtype
22862284

2287-
exp = data['df']
2288-
df_out = pd.DataFrame(exp)
2285+
exp = data['df']
2286+
df_out = pd.DataFrame(exp)
22892287

2290-
df_out['b'] = df_out.b.astype(out_type)
2291-
df_out.set_index('a', inplace=True)
2288+
df_out['b'] = df_out.b.astype(out_type)
2289+
df_out.set_index('a', inplace=True)
22922290

2293-
grpd = df_in.groupby('a')
2294-
t = getattr(grpd, method)(*data['args'])
2295-
assert_frame_equal(t, df_out)
2291+
grpd = df.groupby('a')
2292+
t = getattr(grpd, method)(*data['args'])
2293+
assert_frame_equal(t, df_out)
22962294

22972295
def test_groupby_non_arithmetic_agg_intlike_precision(self):
22982296
# GH9311, GH6620

0 commit comments

Comments
 (0)