Skip to content

Commit f9f198e

Browse files
committed
ENH: groupby refactoring to use khash, add sort option, GH #595
1 parent 330886f commit f9f198e

File tree

9 files changed

+99
-154
lines changed

9 files changed

+99
-154
lines changed

pandas/core/common.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -383,21 +383,19 @@ def set_printoptions(precision=None, column_space=None, max_rows=None,
383383
out how big the terminal is and will not display more rows and/or
384384
columns than can fit on it.
385385
"""
386-
global GlobalPrintConfig
387386
if precision is not None:
388-
GlobalPrintConfig.precision = precision
387+
print_config.precision = precision
389388
if column_space is not None:
390-
GlobalPrintConfig.column_space = column_space
389+
print_config.column_space = column_space
391390
if max_rows is not None:
392-
GlobalPrintConfig.max_rows = max_rows
391+
print_config.max_rows = max_rows
393392
if max_columns is not None:
394-
GlobalPrintConfig.max_columns = max_columns
393+
print_config.max_columns = max_columns
395394
if colheader_justify is not None:
396-
GlobalPrintConfig.colheader_justify = colheader_justify
395+
print_config.colheader_justify = colheader_justify
397396

398397
def reset_printoptions():
399-
global GlobalPrintConfig
400-
GlobalPrintConfig.reset()
398+
print_config.reset()
401399

402400
class EngFormatter(object):
403401
"""
@@ -503,9 +501,8 @@ def set_eng_float_format(precision=None, accuracy=3, use_eng_prefix=False):
503501
"being renamed to 'accuracy'" , FutureWarning)
504502
accuracy = precision
505503

506-
global GlobalPrintConfig
507-
GlobalPrintConfig.float_format = EngFormatter(accuracy, use_eng_prefix)
508-
GlobalPrintConfig.column_space = max(12, accuracy + 9)
504+
print_config.float_format = EngFormatter(accuracy, use_eng_prefix)
505+
print_config.column_space = max(12, accuracy + 9)
509506

510507
#_float_format = None
511508
#_column_space = 12
@@ -526,7 +523,7 @@ def _float_format_default(v, width=None):
526523
to fit the width, reformat it to that width.
527524
"""
528525

529-
fmt_str = '%% .%dg' % GlobalPrintConfig.precision
526+
fmt_str = '%% .%dg' % print_config.precision
530527
formatted = fmt_str % v
531528

532529
if width is None:
@@ -588,8 +585,8 @@ def _make_float_format(x):
588585

589586
if float_format:
590587
formatted = float_format(x)
591-
elif GlobalPrintConfig.float_format:
592-
formatted = GlobalPrintConfig.float_format(x)
588+
elif print_config.float_format:
589+
formatted = print_config.float_format(x)
593590
else:
594591
formatted = _float_format_default(x, col_width)
595592

@@ -621,7 +618,7 @@ def __init__(self):
621618
def reset(self):
622619
self.__init__()
623620

624-
GlobalPrintConfig = _GlobalPrintConfig()
621+
print_config = _GlobalPrintConfig()
625622

626623
#------------------------------------------------------------------------------
627624
# miscellaneous python tools

pandas/core/format.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, frame, buf=None, columns=None, col_space=None,
6464
self.index = index
6565

6666
if justify is None:
67-
self.justify = com.GlobalPrintConfig.colheader_justify
67+
self.justify = com.print_config.colheader_justify
6868
else:
6969
self.justify = justify
7070

pandas/core/frame.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -413,10 +413,12 @@ def __repr__(self):
413413
"""
414414
Return a string representation for a particular DataFrame
415415
"""
416+
config = com.print_config
417+
416418
terminal_width, terminal_height = get_terminal_size()
417-
max_rows = (terminal_height if com.GlobalPrintConfig.max_rows == 0
418-
else com.GlobalPrintConfig.max_rows)
419-
max_columns = com.GlobalPrintConfig.max_columns
419+
max_rows = (terminal_height if config.max_rows == 0
420+
else config.max_rows)
421+
max_columns = config.max_columns
420422

421423
if max_columns > 0:
422424
buf = StringIO()

pandas/core/generic.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get(self, key, default=None):
7979
except KeyError:
8080
return default
8181

82-
def groupby(self, by=None, axis=0, level=None, as_index=True):
82+
def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True):
8383
"""
8484
Group series using mapper (dict or key function, apply given function
8585
to group, return result as series) or by a series of columns
@@ -99,6 +99,8 @@ def groupby(self, by=None, axis=0, level=None, as_index=True):
9999
For aggregated output, return object with group labels as the
100100
index. Only relevant for DataFrame input. as_index=False is
101101
effectively "SQL-style" grouped output
102+
sort : boolean, default True
103+
Sort group keys. Get better performance by turning this off
102104
103105
Examples
104106
--------
@@ -116,7 +118,8 @@ def groupby(self, by=None, axis=0, level=None, as_index=True):
116118
GroupBy object
117119
"""
118120
from pandas.core.groupby import groupby
119-
return groupby(self, by, axis=axis, level=level, as_index=as_index)
121+
return groupby(self, by, axis=axis, level=level, as_index=as_index,
122+
sort=sort)
120123

121124
def select(self, crit, axis=0):
122125
"""

pandas/core/groupby.py

+56-50
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ class GroupBy(object):
8585
"""
8686

8787
def __init__(self, obj, grouper=None, axis=0, level=None,
88-
groupings=None, exclusions=None, column=None, as_index=True):
88+
groupings=None, exclusions=None, column=None, as_index=True,
89+
sort=True):
8990
self._column = column
9091

9192
if isinstance(obj, NDFrame):
@@ -105,10 +106,11 @@ def __init__(self, obj, grouper=None, axis=0, level=None,
105106

106107
self.as_index = as_index
107108
self.grouper = grouper
109+
self.sort = sort
108110

109111
if groupings is None:
110112
groupings, exclusions = _get_groupings(obj, grouper, axis=axis,
111-
level=level)
113+
level=level, sort=sort)
112114

113115
self.groupings = groupings
114116
self.exclusions = set(exclusions) if exclusions else set()
@@ -132,6 +134,7 @@ def indices(self):
132134
if len(self.groupings) == 1:
133135
return self.primary.indices
134136
else:
137+
# TODO: this is massively inefficient
135138
to_groupby = zip(*(ping.grouper for ping in self.groupings))
136139
to_groupby = Index(to_groupby)
137140
return lib.groupby_indices(to_groupby)
@@ -149,7 +152,7 @@ def _obj_with_exclusions(self):
149152

150153
@property
151154
def _group_shape(self):
152-
return tuple(len(ping.counts) for ping in self.groupings)
155+
return tuple(ping.ngroups for ping in self.groupings)
153156

154157
def __getattr__(self, attr):
155158
if hasattr(self.obj, attr):
@@ -525,11 +528,13 @@ class Grouping(object):
525528
* group_index : unique groups
526529
* groups : dict of {group -> label_list}
527530
"""
528-
def __init__(self, index, grouper=None, name=None, level=None):
531+
def __init__(self, index, grouper=None, name=None, level=None,
532+
sort=True):
529533
self.name = name
530534
self.level = level
531535
self.grouper = _convert_grouper(index, grouper)
532536
self.index = index
537+
self.sort = sort
533538

534539
# right place for this?
535540
if isinstance(grouper, Series) and name is None:
@@ -576,6 +581,10 @@ def __iter__(self):
576581
_counts = None
577582
_group_index = None
578583

584+
@property
585+
def ngroups(self):
586+
return len(self.group_index)
587+
579588
@cache_readonly
580589
def indices(self):
581590
return _groupby_indices(self.grouper)
@@ -589,38 +598,58 @@ def labels(self):
589598
@property
590599
def ids(self):
591600
if self._ids is None:
592-
if self._was_factor:
593-
index = self._group_index
594-
self._ids = dict(zip(range(len(index)), index))
595-
else:
596-
self._make_labels()
601+
index = self.group_index
602+
self._ids = dict(zip(range(len(index)), index))
597603
return self._ids
598604

599605
@property
600606
def counts(self):
601607
if self._counts is None:
602-
self._make_labels()
608+
if self._was_factor:
609+
self._counts = lib.group_count(self.labels, self.ngroups)
610+
else:
611+
self._make_labels()
603612
return self._counts
604613

605614
@property
606615
def group_index(self):
607616
if self._group_index is None:
608-
ids = self.ids
609-
values = np.arange(len(self.ids), dtype='O')
610-
self._group_index = Index(lib.lookup_values(values, ids),
611-
name=self.name)
617+
self._make_labels()
618+
619+
# ids = self.ids
620+
# values = np.arange(len(self.ids), dtype='O')
621+
# self._group_index = Index(lib.lookup_values(values, ids),
622+
# name=self.name)
612623
return self._group_index
613624

614625
def _make_labels(self):
615626
if self._was_factor: # pragma: no cover
616627
raise Exception('Should not call this method grouping by level')
617628
else:
618-
ids, labels, counts = _group_labels(self.grouper)
619-
sids, slabels, scounts = sort_group_labels(ids, labels, counts)
629+
values = self.grouper
630+
if values.dtype != np.object_:
631+
values = values.astype('O')
632+
633+
# khash
634+
rizer = lib.Factorizer(len(values))
635+
labels, counts = rizer.factorize(values, sort=False)
636+
637+
uniques = Index(rizer.uniques, name=self.name)
638+
if self.sort and len(counts) > 0:
639+
sorter = uniques.argsort()
640+
reverse_indexer = np.empty(len(sorter), dtype=np.int32)
641+
reverse_indexer.put(sorter, np.arange(len(sorter)))
642+
643+
mask = labels < 0
644+
labels = reverse_indexer.take(labels)
645+
np.putmask(labels, mask, -1)
620646

621-
self._labels = slabels
622-
self._ids = sids
623-
self._counts = scounts
647+
uniques = uniques.take(sorter)
648+
counts = counts.take(sorter)
649+
650+
self._labels = labels
651+
self._group_index = uniques
652+
self._counts = counts
624653

625654
_groups = None
626655
@property
@@ -629,7 +658,8 @@ def groups(self):
629658
self._groups = self.index.groupby(self.grouper)
630659
return self._groups
631660

632-
def _get_groupings(obj, grouper=None, axis=0, level=None):
661+
662+
def _get_groupings(obj, grouper=None, axis=0, level=None, sort=True):
633663
group_axis = obj._get_axis(axis)
634664

635665
if level is not None and not isinstance(group_axis, MultiIndex):
@@ -655,7 +685,7 @@ def _get_groupings(obj, grouper=None, axis=0, level=None):
655685
exclusions.append(gpr)
656686
name = gpr
657687
gpr = obj[gpr]
658-
ping = Grouping(group_axis, gpr, name=name, level=level)
688+
ping = Grouping(group_axis, gpr, name=name, level=level, sort=sort)
659689
if ping.name is None:
660690
ping.name = 'key_%d' % i
661691
groupings.append(ping)
@@ -785,7 +815,7 @@ def _get_index():
785815
index = MultiIndex.from_tuples(keys, names=key_names)
786816
else:
787817
ping = self.groupings[0]
788-
if len(keys) == len(ping.counts):
818+
if len(keys) == ping.ngroups:
789819
index = ping.group_index
790820
index.name = key_names[0]
791821
else:
@@ -1056,7 +1086,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
10561086
key_index = MultiIndex.from_tuples(keys, names=key_names)
10571087
else:
10581088
ping = self.groupings[0]
1059-
if len(keys) == len(ping.counts):
1089+
if len(keys) == ping.ngroups:
10601090
key_index = ping.group_index
10611091
key_index.name = key_names[0]
10621092

@@ -1235,6 +1265,9 @@ def slicer(data, slob):
12351265
yield i, slicer(sorted_data, slice(start, end))
12361266

12371267
def get_group_index(label_list, shape):
1268+
if len(label_list) == 1:
1269+
return label_list[0]
1270+
12381271
n = len(label_list[0])
12391272
group_index = np.zeros(n, dtype=int)
12401273
mask = np.zeros(n, dtype=bool)
@@ -1353,11 +1386,6 @@ def _groupby_indices(values):
13531386
values = values.astype('O')
13541387
return lib.groupby_indices(values)
13551388

1356-
def _group_labels(values):
1357-
if values.dtype != np.object_:
1358-
values = values.astype('O')
1359-
return lib.group_labels(values)
1360-
13611389
def _ensure_platform_int(labels):
13621390
if labels.dtype != np.int_:
13631391
labels = labels.astype(np.int_)
@@ -1367,25 +1395,3 @@ def _ensure_int64(labels):
13671395
if labels.dtype != np.int64:
13681396
labels = labels.astype(np.int64)
13691397
return labels
1370-
1371-
def sort_group_labels(ids, labels, counts):
1372-
n = len(ids)
1373-
1374-
# corner all NA case
1375-
if n == 0:
1376-
return ids, labels, counts
1377-
1378-
rng = np.arange(n)
1379-
values = Series(ids, index=rng, dtype=object).values
1380-
indexer = values.argsort()
1381-
1382-
reverse_indexer = np.empty(n, dtype=np.int32)
1383-
reverse_indexer.put(indexer, np.arange(n))
1384-
1385-
new_labels = reverse_indexer.take(labels)
1386-
np.putmask(new_labels, labels == -1, -1)
1387-
1388-
new_ids = dict(izip(rng, values.take(indexer)))
1389-
new_counts = counts.take(indexer)
1390-
1391-
return new_ids, new_labels, new_counts

pandas/core/series.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,8 @@ def __setslice__(self, i, j, value):
456456
def __repr__(self):
457457
"""Clean string representation of a Series"""
458458
width, height = get_terminal_size()
459-
max_rows = (height if com.GlobalPrintConfig.max_rows == 0
460-
else com.GlobalPrintConfig.max_rows)
459+
max_rows = (height if com.print_config.max_rows == 0
460+
else com.print_config.max_rows)
461461
if len(self.index) > max_rows:
462462
result = self._tidy_repr(min(30, max_rows - 4))
463463
elif len(self.index) > 0:
@@ -518,7 +518,7 @@ def _get_repr(self, name=False, print_header=False, length=True,
518518
padSpace = min(maxlen, 60)
519519

520520
if float_format is None:
521-
float_format = com.GlobalPrintConfig.float_format
521+
float_format = com.print_config.float_format
522522
if float_format is None:
523523
float_format = com._float_format_default
524524

0 commit comments

Comments
 (0)