Skip to content

Commit c607655

Browse files
committed
Added support for timestamps mixed with nan
1 parent b86d2ad commit c607655

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

pandas/_libs/groupby_helper.pxi.in

+23-6
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,
463463
ndarray[{{c_type}}] _values
464464
ndarray[uint8_t] mask
465465
bint pct, ascending, keep_na
466+
{{c_type}} nan_value
466467

467468
tiebreak = tiebreakers[kwargs['ties_method']]
468469
ascending = kwargs['ascending']
@@ -471,19 +472,27 @@ def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,
471472
N, K = (<object> values).shape
472473

473474
_values = np.array(values[:, 0], copy=True)
475+
{{if name=='int64'}}
476+
mask = (_values == {{nan_val}}).astype(np.uint8)
477+
{{else}}
474478
mask = np.isnan(_values).astype(np.uint8)
479+
{{endif}}
475480

476-
{{if name == 'int64' }}
477-
order = (_values, labels)
478-
{{else}}
479481
if ascending ^ (kwargs['na_option'] == 'top'):
482+
{{if name == 'int64'}}
483+
nan_value = np.iinfo(np.int64).max
484+
{{else}}
480485
nan_value = np.inf
486+
{{endif}}
481487
order = (_values, mask, labels)
482488
else:
489+
{{if name == 'int64'}}
490+
nan_value = np.iinfo(np.int64).min
491+
{{else}}
483492
nan_value = -np.inf
493+
{{endif}}
484494
order = (_values, ~mask, labels)
485495
np.putmask(_values, mask, nan_value)
486-
{{endif}}
487496

488497
_as = np.lexsort(order)
489498

@@ -495,9 +504,9 @@ def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,
495504
dups += 1
496505
sum_ranks += i - grp_start + 1
497506

498-
if keep_na and (values[_as[i], 0] != values[_as[i], 0]):
507+
if keep_na and _values[_as[i]] == nan_value:
499508
grp_na_count += 1
500-
out[_as[i], 0] = {{nan_val}}
509+
out[_as[i], 0] = nan
501510
else:
502511
if tiebreak == TIEBREAK_AVERAGE:
503512
for j in range(i - dups + 1, i + 1):
@@ -518,11 +527,19 @@ def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,
518527
for j in range(i - dups + 1, i + 1):
519528
out[_as[j], 0] = vals_seen
520529

530+
{{if name=='int64'}}
531+
if (i == N - 1 or (
532+
(_values[_as[i]] != _values[_as[i+1]]) and not
533+
(_values[_as[i]] == nan_value and
534+
_values[_as[i+1]] == nan_value
535+
))):
536+
{{else}}
521537
if (i == N - 1 or (
522538
(_values[_as[i]] != _values[_as[i+1]]) and not
523539
(isnan(_values[_as[i]]) and
524540
isnan(_values[_as[i+1]])
525541
))):
542+
{{endif}}
526543
dups = sum_ranks = 0
527544
val_start = i
528545
vals_seen += 1

pandas/tests/groupby/test_groupby.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1953,9 +1953,9 @@ def test_rank_args(self, vals, ties_method, ascending, pct, exp):
19531953
@pytest.mark.parametrize("vals", [
19541954
[2, 2, np.nan, 8, 2, 6, np.nan, np.nan], # floats
19551955
['bar', 'bar', np.nan, 'foo', 'bar', 'baz', np.nan, np.nan], # objects
1956-
#[pd.Timestamp('2018-01-02'), pd.Timestamp('2018-01-02'), np.nan,
1957-
# pd.Timestamp('2018-01-08'), pd.Timestamp('2018-01-02'),
1958-
# pd.Timestamp('2018-01-06'), np.nan, np.nan]
1956+
[pd.Timestamp('2018-01-02'), pd.Timestamp('2018-01-02'), np.nan,
1957+
pd.Timestamp('2018-01-08'), pd.Timestamp('2018-01-02'),
1958+
pd.Timestamp('2018-01-06'), np.nan, np.nan]
19591959
])
19601960
@pytest.mark.parametrize("ties_method,ascending,na_option,pct,exp", [
19611961
('average', True, 'keep', False, DataFrame(

0 commit comments

Comments
 (0)