Commit b86d2ad

Added missing obj support
1 parent 978ef7b commit b86d2ad

3 files changed (+26 -15 lines)

pandas/_libs/groupby.pyx (+24 -12)

@@ -134,29 +134,37 @@ def group_rank_object(ndarray[float64_t, ndim=2] out,
         int tiebreak
         Py_ssize_t i, j, N, K
         int64_t val_start=0, grp_start=0, dups=0, sum_ranks=0, vals_seen=1
+        int64_t grp_na_count=0
         ndarray[int64_t] _as
-        bint pct, ascending
+        ndarray[object] _values
+        bint pct, ascending, keep_na

     tiebreak = tiebreakers[kwargs['ties_method']]
     ascending = kwargs['ascending']
     pct = kwargs['pct']
     keep_na = kwargs['na_option'] == 'keep'
     N, K = (<object> values).shape

-    vals = np.array(values[:, 0], copy=True)
-    mask = missing.isnaobj(vals)
+    _values = np.array(values[:, 0], copy=True)
+    mask = missing.isnaobj(_values)

+    if ascending ^ (kwargs['na_option'] == 'top'):
+        nan_value = np.inf
+        order = (_values, mask, labels)
+    else:
+        nan_value = -np.inf
+        order = (_values, ~mask, labels)
+    np.putmask(_values, mask, nan_value)
     try:
-        _as = np.lexsort((vals, labels))
+        _as = np.lexsort(order)
     except TypeError:
         # lexsort fails when missing data and objects are mixed
         # fallback to argsort
-        order = (vals, mask, labels)
-        _values = np.asarray(list(zip(order[0], order[1], order[2])),
-                             dtype=[('values', 'O'), ('mask', '?'),
-                                    ('labels', 'i8')])
-        _as = np.argsort(_values, kind='mergesort', order=('labels',
-                                                           'mask', 'values'))
+        _arr = np.asarray(list(zip(order[0], order[1], order[2])),
+                          dtype=[('values', 'O'), ('mask', '?'),
+                                 ('labels', 'i8')])
+        _as = np.argsort(_arr, kind='mergesort', order=('labels',
+                                                        'mask', 'values'))

     if not ascending:
         _as = _as[::-1]
@@ -165,7 +173,8 @@ def group_rank_object(ndarray[float64_t, ndim=2] out,
         dups += 1
         sum_ranks += i - grp_start + 1

-        if keep_na and mask[_as[i]]:
+        if keep_na and (values[_as[i], 0] != values[_as[i], 0]):
+            grp_na_count += 1
             out[_as[i], 0] = np.nan
         else:
             if tiebreak == TIEBREAK_AVERAGE:
@@ -198,8 +207,11 @@ def group_rank_object(ndarray[float64_t, ndim=2] out,
         if i == N - 1 or labels[_as[i]] != labels[_as[i+1]]:
             if pct:
                 for j in range(grp_start, i + 1):
-                    out[_as[j], 0] = out[_as[j], 0] / (i - grp_start + 1)
+                    out[_as[j], 0] = out[_as[j], 0] / (i - grp_start + 1
+                                                       - grp_na_count)
+            grp_na_count = 0
             grp_start = i + 1
+            vals_seen = 1


 cdef inline float64_t median_linear(float64_t* a, int n) nogil:
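The ordering change above reads more easily outside Cython: replace missing entries with an inf or -inf sentinel, then lexsort with the group labels as the primary key so missing values land at whichever end of each group ascending and na_option require, falling back to a stable argsort over a structured array when mixed objects cannot be compared. Below is a rough standalone NumPy sketch of that idea; sort_within_groups and its arguments are illustrative names, not pandas API.

import numpy as np

def sort_within_groups(values, labels, ascending=True, na_option='keep'):
    vals = np.array(values, dtype=object, copy=True)
    mask = np.array([v != v for v in vals])      # NaN is the only value unequal to itself

    if ascending ^ (na_option == 'top'):
        nan_value = np.inf                       # missing values sort after every real value
        order = (vals, mask, labels)
    else:
        nan_value = -np.inf                      # missing values sort before every real value
        order = (vals, ~mask, labels)
    np.putmask(vals, mask, nan_value)

    try:
        # lexsort treats the *last* key as primary: group label first,
        # then the NaN flag, then the values themselves.
        _as = np.lexsort(order)
    except TypeError:
        # Mixed objects (e.g. str and the inf sentinel) cannot be compared,
        # so fall back to a stable argsort over a structured array.
        arr = np.asarray(list(zip(order[0], order[1], order[2])),
                         dtype=[('values', 'O'), ('mask', '?'), ('labels', 'i8')])
        _as = np.argsort(arr, kind='mergesort', order=('labels', 'mask', 'values'))

    return _as if ascending else _as[::-1]

# Within each label, real values come first (ascending order) and NaNs last.
idx = sort_within_groups(['foo', np.nan, 'bar', 'bar'], np.array([0, 0, 1, 1]))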

pandas/_libs/groupby_helper.pxi.in (+1 -2)

@@ -471,8 +471,8 @@ def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,
     N, K = (<object> values).shape

     _values = np.array(values[:, 0], copy=True)
-
     mask = np.isnan(_values).astype(np.uint8)
+
     {{if name == 'int64' }}
     order = (_values, labels)
     {{else}}
@@ -487,7 +487,6 @@ def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,

     _as = np.lexsort(order)

-
     if not ascending:
         _as = _as[::-1]

pandas/tests/groupby/test_groupby.py (+1 -1)

@@ -1952,7 +1952,7 @@ def test_rank_args(self, vals, ties_method, ascending, pct, exp):

     @pytest.mark.parametrize("vals", [
         [2, 2, np.nan, 8, 2, 6, np.nan, np.nan], # floats
-        #['bar', 'bar', np.nan, 'foo', 'bar', 'baz', np.nan, np.nan], # objects
+        ['bar', 'bar', np.nan, 'foo', 'bar', 'baz', np.nan, np.nan], # objects
         #[pd.Timestamp('2018-01-02'), pd.Timestamp('2018-01-02'), np.nan,
         # pd.Timestamp('2018-01-08'), pd.Timestamp('2018-01-02'),
         # pd.Timestamp('2018-01-06'), np.nan, np.nan]
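At the user level, the re-enabled parametrization corresponds to ranking an object (string) column that contains missing values within a group. A hedged usage sketch follows; the frame and column names are illustrative, not taken from the test itself.

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'key': ['a'] * 8,
    'val': ['bar', 'bar', np.nan, 'foo', 'bar', 'baz', np.nan, np.nan],
})

# With the default na_option='keep', missing entries keep a NaN rank; with
# pct=True the denominator excludes the group's NaN count, as implemented in
# the groupby.pyx change above.
ranks = df.groupby('key')['val'].rank(method='average')
pct_ranks = df.groupby('key')['val'].rank(method='average', pct=True)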
