Skip to content

Commit 9f439f0

Browse files
committed
Merge pull request #9380 from behzadnouri/i8grby
bug in groupby when key space exceeds int64 bounds
2 parents 391f46a + c7f363b commit 9f439f0

File tree

9 files changed

+174
-73
lines changed

9 files changed

+174
-73
lines changed

bench/bench_groupby.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def g():
4747
from pandas.core.groupby import get_group_index
4848
4949
50-
group_index = get_group_index(label_list, shape).astype('i4')
50+
group_index = get_group_index(label_list, shape,
51+
sort=True, xnull=True).astype('i4')
5152
5253
ngroups = np.prod(shape)
5354

doc/source/whatsnew/v0.16.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ Bug Fixes
156156
- Bug in ``pivot`` and ``unstack`` where ``nan`` values would break index alignment (:issue:`4862`, :issue:`7401`, :issue:`7403`, :issue:`7405`, :issue:`7466`)
157157
- Bug in left ``join`` on multi-index with ``sort=True`` or null values (:issue:`9210`).
158158
- Bug in ``MultiIndex`` where inserting new keys would fail (:issue:`9250`).
159+
- Bug in ``groupby`` when key space exceeds ``int64`` bounds (:issue:`9096`).
159160

160161

161162
- Fixed character encoding bug in ``read_stata`` and ``StataReader`` when loading data from a URL (:issue:`9231`).

pandas/core/groupby.py

+56-58
Original file line numberDiff line numberDiff line change
@@ -1367,30 +1367,16 @@ def group_info(self):
13671367

13681368
def _get_compressed_labels(self):
13691369
all_labels = [ping.labels for ping in self.groupings]
1370-
if self._overflow_possible:
1371-
tups = lib.fast_zip(all_labels)
1372-
labs, uniques = algos.factorize(tups)
1370+
if len(all_labels) > 1:
1371+
group_index = get_group_index(all_labels, self.shape,
1372+
sort=True, xnull=True)
1373+
return _compress_group_index(group_index)
13731374

1374-
if self.sort:
1375-
uniques, labs = _reorder_by_uniques(uniques, labs)
1375+
ping = self.groupings[0]
1376+
self.compressed = False
1377+
self._filter_empty_groups = False
13761378

1377-
return labs, uniques
1378-
else:
1379-
if len(all_labels) > 1:
1380-
group_index = get_group_index(all_labels, self.shape)
1381-
comp_ids, obs_group_ids = _compress_group_index(group_index)
1382-
else:
1383-
ping = self.groupings[0]
1384-
comp_ids = ping.labels
1385-
obs_group_ids = np.arange(len(ping.group_index))
1386-
self.compressed = False
1387-
self._filter_empty_groups = False
1388-
1389-
return comp_ids, obs_group_ids
1390-
1391-
@cache_readonly
1392-
def _overflow_possible(self):
1393-
return _int64_overflow_possible(self.shape)
1379+
return ping.labels, np.arange(len(ping.group_index))
13941380

13951381
@cache_readonly
13961382
def ngroups(self):
@@ -1402,15 +1388,13 @@ def result_index(self):
14021388
return MultiIndex.from_arrays(recons, names=self.names)
14031389

14041390
def get_group_levels(self):
1405-
obs_ids = self.group_info[1]
1391+
comp_ids, obs_ids, _ = self.group_info
14061392

14071393
if not self.compressed and len(self.groupings) == 1:
14081394
return [self.groupings[0].group_index]
14091395

1410-
if self._overflow_possible:
1411-
recons_labels = [np.array(x) for x in zip(*obs_ids)]
1412-
else:
1413-
recons_labels = decons_group_index(obs_ids, self.shape)
1396+
recons_labels = decons_obs_group_ids(comp_ids, obs_ids,
1397+
self.shape, (ping.labels for ping in self.groupings))
14141398

14151399
name_list = []
14161400
for ping, labels in zip(self.groupings, recons_labels):
@@ -3490,42 +3474,28 @@ def get_splitter(data, *args, **kwargs):
34903474
# Misc utilities
34913475

34923476

3493-
def get_group_index(label_list, shape):
3477+
def get_group_index(labels, shape, sort, xnull):
34943478
"""
34953479
For the particular label_list, gets the offsets into the hypothetical list
34963480
representing the totally ordered cartesian product of all possible label
3497-
combinations.
3498-
"""
3499-
if len(label_list) == 1:
3500-
return label_list[0]
3501-
3502-
n = len(label_list[0])
3503-
group_index = np.zeros(n, dtype=np.int64)
3504-
mask = np.zeros(n, dtype=bool)
3505-
for i in range(len(shape)):
3506-
stride = np.prod([x for x in shape[i + 1:]], dtype=np.int64)
3507-
group_index += com._ensure_int64(label_list[i]) * stride
3508-
mask |= label_list[i] < 0
3509-
3510-
np.putmask(group_index, mask, -1)
3511-
return group_index
3512-
3513-
3514-
def get_flat_ids(labels, shape, retain_lex_rank):
3515-
"""
3516-
Given a list of labels at each level, returns a flat array of int64 ids
3517-
corresponding to unique tuples across the labels. If `retain_lex_rank`,
3518-
rank of returned ids preserve lexical ranks of labels.
3481+
combinations, *as long as* this space fits within int64 bounds;
3482+
otherwise, though group indices identify unique combinations of
3483+
labels, they cannot be deconstructed.
3484+
- If `sort`, rank of returned ids preserve lexical ranks of labels.
3485+
i.e. returned ids can be used to do lexical sort on labels;
3486+
- If `xnull` nulls (-1 labels) are passed through.
35193487
35203488
Parameters
35213489
----------
35223490
labels: sequence of arrays
35233491
Integers identifying levels at each location
35243492
shape: sequence of ints same length as labels
35253493
Number of unique levels at each location
3526-
retain_lex_rank: boolean
3494+
sort: boolean
35273495
If the ranks of returned ids should match lexical ranks of labels
3528-
3496+
xnull: boolean
3497+
If true nulls are eXcluded. i.e. -1 values in the labels are
3498+
passed through
35293499
Returns
35303500
-------
35313501
An array of type int64 where two elements are equal if their corresponding
@@ -3544,12 +3514,18 @@ def loop(labels, shape):
35443514
stride //= shape[i]
35453515
out += labels[i] * stride
35463516

3517+
if xnull: # exclude nulls
3518+
mask = labels[0] == -1
3519+
for lab in labels[1:nlev]:
3520+
mask |= lab == -1
3521+
out[mask] = -1
3522+
35473523
if nlev == len(shape): # all levels done!
35483524
return out
35493525

35503526
# compress what has been done so far in order to avoid overflow
35513527
# to retain lexical ranks, obs_ids should be sorted
3552-
comp_ids, obs_ids = _compress_group_index(out, sort=retain_lex_rank)
3528+
comp_ids, obs_ids = _compress_group_index(out, sort=sort)
35533529

35543530
labels = [comp_ids] + labels[nlev:]
35553531
shape = [len(obs_ids)] + shape[nlev:]
@@ -3560,9 +3536,10 @@ def maybe_lift(lab, size): # promote nan values
35603536
return (lab + 1, size + 1) if (lab == -1).any() else (lab, size)
35613537

35623538
labels = map(com._ensure_int64, labels)
3563-
labels, shape = map(list, zip(*map(maybe_lift, labels, shape)))
3539+
if not xnull:
3540+
labels, shape = map(list, zip(*map(maybe_lift, labels, shape)))
35643541

3565-
return loop(labels, shape)
3542+
return loop(list(labels), list(shape))
35663543

35673544

35683545
_INT64_MAX = np.iinfo(np.int64).max
@@ -3578,6 +3555,11 @@ def _int64_overflow_possible(shape):
35783555

35793556
def decons_group_index(comp_labels, shape):
35803557
# reconstruct labels
3558+
if _int64_overflow_possible(shape):
3559+
# at some point group indices are factorized,
3560+
# and may not be deconstructed here! wrong path!
3561+
raise ValueError('cannot deconstruct factorized group indices!')
3562+
35813563
label_list = []
35823564
factor = 1
35833565
y = 0
@@ -3591,12 +3573,25 @@ def decons_group_index(comp_labels, shape):
35913573
return label_list[::-1]
35923574

35933575

3576+
def decons_obs_group_ids(comp_ids, obs_ids, shape, labels):
3577+
"""reconstruct labels from observed ids"""
3578+
from pandas.hashtable import unique_label_indices
3579+
3580+
if not _int64_overflow_possible(shape):
3581+
# obs ids are deconstructable! take the fast route!
3582+
return decons_group_index(obs_ids, shape)
3583+
3584+
i = unique_label_indices(comp_ids)
3585+
i8copy = lambda a: a.astype('i8', subok=False, copy=True)
3586+
return [i8copy(lab[i]) for lab in labels]
3587+
3588+
35943589
def _indexer_from_factorized(labels, shape, compress=True):
35953590
if _int64_overflow_possible(shape):
35963591
indexer = np.lexsort(np.array(labels[::-1]))
35973592
return indexer
35983593

3599-
group_index = get_group_index(labels, shape)
3594+
group_index = get_group_index(labels, shape, sort=True, xnull=True)
36003595

36013596
if compress:
36023597
comp_ids, obs_ids = _compress_group_index(group_index)
@@ -3712,9 +3707,12 @@ def get_key(self, comp_id):
37123707

37133708
def _get_indices_dict(label_list, keys):
37143709
shape = list(map(len, keys))
3715-
ngroups = np.prod(shape)
37163710

3717-
group_index = get_group_index(label_list, shape)
3711+
group_index = get_group_index(label_list, shape, sort=True, xnull=True)
3712+
ngroups = ((group_index.size and group_index.max()) + 1) \
3713+
if _int64_overflow_possible(shape) \
3714+
else np.prod(shape, dtype='i8')
3715+
37183716
sorter = _get_group_index_sorter(group_index, ngroups)
37193717

37203718
sorted_labels = [lab.take(sorter) for lab in label_list]

pandas/core/index.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3229,11 +3229,11 @@ def is_unique(self):
32293229

32303230
@Appender(_shared_docs['duplicated'] % _index_doc_kwargs)
32313231
def duplicated(self, take_last=False):
3232-
from pandas.core.groupby import get_flat_ids
3232+
from pandas.core.groupby import get_group_index
32333233
from pandas.hashtable import duplicated_int64
32343234

32353235
shape = map(len, self.levels)
3236-
ids = get_flat_ids(self.labels, shape, False)
3236+
ids = get_group_index(self.labels, shape, sort=False, xnull=False)
32373237

32383238
return duplicated_int64(ids, take_last)
32393239

pandas/core/reshape.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from pandas.core.categorical import Categorical
1313
from pandas.core.common import (notnull, _ensure_platform_int, _maybe_promote,
1414
isnull)
15-
from pandas.core.groupby import (get_group_index, _compress_group_index,
16-
decons_group_index)
15+
from pandas.core.groupby import get_group_index, _compress_group_index
16+
1717
import pandas.core.common as com
1818
import pandas.algos as algos
1919

@@ -103,10 +103,6 @@ def _make_sorted_values_labels(self):
103103
sizes = [len(x) for x in levs[:v] + levs[v + 1:] + [levs[v]]]
104104

105105
comp_index, obs_ids = get_compressed_ids(to_sort, sizes)
106-
107-
# group_index = get_group_index(to_sort, sizes)
108-
# comp_index, obs_ids = _compress_group_index(group_index)
109-
110106
ngroups = len(obs_ids)
111107

112108
indexer = algos.groupsort_indexer(comp_index, ngroups)[0]
@@ -252,6 +248,8 @@ def _make_new_index(lev, lab):
252248

253249

254250
def _unstack_multiple(data, clocs):
251+
from pandas.core.groupby import decons_obs_group_ids
252+
255253
if len(clocs) == 0:
256254
return data
257255

@@ -271,10 +269,10 @@ def _unstack_multiple(data, clocs):
271269
rnames = [index.names[i] for i in rlocs]
272270

273271
shape = [len(x) for x in clevels]
274-
group_index = get_group_index(clabels, shape)
272+
group_index = get_group_index(clabels, shape, sort=False, xnull=False)
275273

276274
comp_ids, obs_ids = _compress_group_index(group_index, sort=False)
277-
recons_labels = decons_group_index(obs_ids, shape)
275+
recons_labels = decons_obs_group_ids(comp_ids, obs_ids, shape, clabels)
278276

279277
dummy_index = MultiIndex(levels=rlevels + [obs_ids],
280278
labels=rlabels + [comp_ids],
@@ -449,9 +447,9 @@ def _unstack_frame(obj, level):
449447

450448

451449
def get_compressed_ids(labels, sizes):
452-
from pandas.core.groupby import get_flat_ids
450+
from pandas.core.groupby import get_group_index
453451

454-
ids = get_flat_ids(labels, sizes, True)
452+
ids = get_group_index(labels, sizes, sort=True, xnull=False)
455453
return _compress_group_index(ids, sort=True)
456454

457455

pandas/hashtable.pyx

+32-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cnp.import_array()
1717
cnp.import_ufunc()
1818

1919
cdef int64_t iNaT = util.get_nat()
20+
_SIZE_HINT_LIMIT = (1 << 20) + 7
2021

2122
cdef extern from "datetime.h":
2223
bint PyDateTime_Check(object o)
@@ -1073,7 +1074,7 @@ def duplicated_int64(ndarray[int64_t, ndim=1] values, int take_last):
10731074
kh_int64_t * table = kh_init_int64()
10741075
ndarray[uint8_t, ndim=1, cast=True] out = np.empty(n, dtype='bool')
10751076

1076-
kh_resize_int64(table, min(1 << 20, n))
1077+
kh_resize_int64(table, min(n, _SIZE_HINT_LIMIT))
10771078

10781079
if take_last:
10791080
for i from n > i >=0:
@@ -1086,3 +1087,33 @@ def duplicated_int64(ndarray[int64_t, ndim=1] values, int take_last):
10861087

10871088
kh_destroy_int64(table)
10881089
return out
1090+
1091+
1092+
@cython.wraparound(False)
1093+
@cython.boundscheck(False)
1094+
def unique_label_indices(ndarray[int64_t, ndim=1] labels):
1095+
"""
1096+
indices of the first occurrences of the unique labels
1097+
*excluding* -1. equivalent to:
1098+
np.unique(labels, return_index=True)[1]
1099+
"""
1100+
cdef:
1101+
int ret = 0
1102+
Py_ssize_t i, n = len(labels)
1103+
kh_int64_t * table = kh_init_int64()
1104+
Int64Vector idx = Int64Vector()
1105+
ndarray[int64_t, ndim=1] arr
1106+
1107+
kh_resize_int64(table, min(n, _SIZE_HINT_LIMIT))
1108+
1109+
for i in range(n):
1110+
kh_put_int64(table, labels[i], &ret)
1111+
if ret != 0:
1112+
idx.append(i)
1113+
1114+
kh_destroy_int64(table)
1115+
1116+
arr = idx.to_array()
1117+
arr = arr[labels[arr].argsort()]
1118+
1119+
return arr[1:] if arr.size != 0 and labels[arr[0]] == -1 else arr

pandas/tests/test_algos.py

+15
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,21 @@ def test_quantile():
261261
expected = algos.quantile(s.values, [0, .25, .5, .75, 1.])
262262
tm.assert_almost_equal(result, expected)
263263

264+
def test_unique_label_indices():
265+
from pandas.hashtable import unique_label_indices
266+
267+
a = np.random.randint(1, 1 << 10, 1 << 15).astype('i8')
268+
269+
left = unique_label_indices(a)
270+
right = np.unique(a, return_index=True)[1]
271+
272+
tm.assert_array_equal(left, right)
273+
274+
a[np.random.choice(len(a), 10)] = -1
275+
left= unique_label_indices(a)
276+
right = np.unique(a, return_index=True)[1][1:]
277+
tm.assert_array_equal(left, right)
278+
264279
if __name__ == '__main__':
265280
import nose
266281
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],

0 commit comments

Comments
 (0)