Skip to content

Commit 759b0db

Browse files
committed
ENH: Consolidate and further improve performance of take functions
1 parent f2cd3ba commit 759b0db

File tree

9 files changed: 974 additions and 1832 deletions.

9 files changed: 974 additions and 1832 deletions.

pandas/core/common.py

+174 additions, -198 deletions
Large diffs are not rendered by default.

pandas/core/frame.py

+3 additions, -2 deletions
Original file line number | Diff line number | Diff line change
@@ -2644,8 +2644,9 @@ def _reindex_multi(self, new_index, new_columns, copy, fill_value):
26442644
new_columns, col_indexer = self.columns.reindex(new_columns)
26452645

26462646
if row_indexer is not None and col_indexer is not None:
2647-
new_values = com.take_2d_multi(self.values, row_indexer,
2648-
col_indexer, fill_value=fill_value)
2647+
indexer = row_indexer, col_indexer
2648+
new_values = com.take_2d_multi(self.values, indexer,
2649+
fill_value=fill_value)
26492650
return DataFrame(new_values, index=new_index, columns=new_columns)
26502651
elif row_indexer is not None:
26512652
return self._reindex_with_indexers(new_index, row_indexer,

pandas/core/groupby.py

+7 additions, -5 deletions
Original file line number | Diff line number | Diff line change
@@ -897,7 +897,7 @@ def _aggregate_series_fast(self, obj, func):
897897
dummy = obj[:0].copy()
898898
indexer = _algos.groupsort_indexer(group_index, ngroups)[0]
899899
obj = obj.take(indexer)
900-
group_index = com.ndtake(group_index, indexer)
900+
group_index = com.take_nd(group_index, indexer, allow_fill=False)
901901
grouper = lib.SeriesGrouper(obj, func, group_index, ngroups,
902902
dummy)
903903
result, counts = grouper.get_result()
@@ -1686,7 +1686,9 @@ def aggregate(self, arg, *args, **kwargs):
16861686
zipped = zip(result.index.levels, result.index.labels,
16871687
result.index.names)
16881688
for i, (lev, lab, name) in enumerate(zipped):
1689-
result.insert(i, name, com.ndtake(lev.values, lab))
1689+
result.insert(i, name,
1690+
com.take_nd(lev.values, lab,
1691+
allow_fill=False))
16901692
result = result.consolidate()
16911693
else:
16921694
values = result.index.values
@@ -2133,7 +2135,7 @@ def __init__(self, data, labels, ngroups, axis=0, keep_internal=False):
21332135
@cache_readonly
21342136
def slabels(self):
21352137
# Sorted labels
2136-
return com.ndtake(self.labels, self.sort_idx)
2138+
return com.take_nd(self.labels, self.sort_idx, allow_fill=False)
21372139

21382140
@cache_readonly
21392141
def sort_idx(self):
@@ -2411,11 +2413,11 @@ def _reorder_by_uniques(uniques, labels):
24112413
mask = labels < 0
24122414

24132415
# move labels to right locations (ie, unsort ascending labels)
2414-
labels = com.ndtake(reverse_indexer, labels)
2416+
labels = com.take_nd(reverse_indexer, labels, allow_fill=False)
24152417
np.putmask(labels, mask, -1)
24162418

24172419
# sort observed ids
2418-
uniques = com.ndtake(uniques, sorter)
2420+
uniques = com.take_nd(uniques, sorter, allow_fill=False)
24192421

24202422
return uniques, labels
24212423

pandas/core/index.py

+11 additions, -7 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
import pandas.index as _index
1313
from pandas.lib import Timestamp
1414

15-
from pandas.core.common import ndtake
1615
from pandas.util.decorators import cache_readonly
1716
import pandas.core.common as com
1817
from pandas.util import py3compat
@@ -608,7 +607,8 @@ def union(self, other):
608607
indexer = (indexer == -1).nonzero()[0]
609608

610609
if len(indexer) > 0:
611-
other_diff = ndtake(other.values, indexer)
610+
other_diff = com.take_nd(other.values, indexer,
611+
allow_fill=False)
612612
result = com._concat_compat((self.values, other_diff))
613613
try:
614614
result.sort()
@@ -1037,7 +1037,8 @@ def _join_level(self, other, level, how='left', return_indexers=False):
10371037
rev_indexer = lib.get_reverse_indexer(left_lev_indexer,
10381038
len(old_level))
10391039

1040-
new_lev_labels = ndtake(rev_indexer, left.labels[level])
1040+
new_lev_labels = com.take_nd(rev_indexer, left.labels[level],
1041+
allow_fill=False)
10411042
omit_mask = new_lev_labels != -1
10421043

10431044
new_labels = list(left.labels)
@@ -1057,8 +1058,9 @@ def _join_level(self, other, level, how='left', return_indexers=False):
10571058
left_indexer = None
10581059

10591060
if right_lev_indexer is not None:
1060-
right_indexer = ndtake(right_lev_indexer,
1061-
join_index.labels[level])
1061+
right_indexer = com.take_nd(right_lev_indexer,
1062+
join_index.labels[level],
1063+
allow_fill=False)
10621064
else:
10631065
right_indexer = join_index.labels[level]
10641066

@@ -2369,8 +2371,10 @@ def equals(self, other):
23692371
return False
23702372

23712373
for i in xrange(self.nlevels):
2372-
svalues = ndtake(self.levels[i].values, self.labels[i])
2373-
ovalues = ndtake(other.levels[i].values, other.labels[i])
2374+
svalues = com.take_nd(self.levels[i].values, self.labels[i],
2375+
allow_fill=False)
2376+
ovalues = com.take_nd(other.levels[i].values, other.labels[i],
2377+
allow_fill=False)
23742378
if not np.array_equal(svalues, ovalues):
23752379
return False
23762380

pandas/core/internals.py

+28 additions, -34 deletions
Original file line number | Diff line number | Diff line change
@@ -120,14 +120,14 @@ def merge(self, other):
120120
# union_ref = self.ref_items + other.ref_items
121121
return _merge_blocks([self, other], self.ref_items)
122122

123-
def reindex_axis(self, indexer, mask, needs_masking, axis=0,
124-
fill_value=np.nan):
123+
def reindex_axis(self, indexer, axis=1, fill_value=np.nan, mask_info=None):
125124
"""
126125
Reindex using pre-computed indexer information
127126
"""
128-
new_values = com.take_fast(self.values, indexer,
129-
mask, needs_masking, axis=axis,
130-
fill_value=fill_value)
127+
if axis < 1:
128+
raise AssertionError('axis must be at least 1, got %d' % axis)
129+
new_values = com.take_nd(self.values, indexer, axis,
130+
fill_value=fill_value, mask_info=mask_info)
131131
return make_block(new_values, self.items, self.ref_items)
132132

133133
def reindex_items_from(self, new_ref_items, copy=True):
@@ -146,12 +146,9 @@ def reindex_items_from(self, new_ref_items, copy=True):
146146
new_items = new_ref_items
147147
new_values = self.values.copy() if copy else self.values
148148
else:
149-
mask = indexer != -1
150-
masked_idx = indexer[mask]
151-
152-
new_values = com.take_fast(self.values, masked_idx,
153-
mask=None, needs_masking=False,
154-
axis=0)
149+
masked_idx = indexer[indexer != -1]
150+
new_values = com.take_nd(self.values, masked_idx, axis=0,
151+
allow_fill=False)
155152
new_items = self.items.take(masked_idx)
156153
return make_block(new_values, new_items, new_ref_items)
157154

@@ -221,7 +218,10 @@ def fillna(self, value, inplace=False):
221218
return make_block(new_values, self.items, self.ref_items)
222219

223220
def astype(self, dtype, copy = True, raise_on_error = True):
224-
""" coerce to the new type (if copy=True, return a new copy) raise on an except if raise == True """
221+
"""
222+
Coerce to the new type (if copy=True, return a new copy)
223+
raise on an except if raise == True
224+
"""
225225
try:
226226
newb = make_block(com._astype_nansafe(self.values, dtype, copy = copy),
227227
self.items, self.ref_items)
@@ -231,12 +231,12 @@ def astype(self, dtype, copy = True, raise_on_error = True):
231231
newb = self.copy() if copy else self
232232

233233
if newb.is_numeric and self.is_numeric:
234-
if newb.shape != self.shape or (not copy and newb.itemsize < self.itemsize):
235-
raise TypeError("cannot set astype for copy = [%s] for dtype (%s [%s]) with smaller itemsize that current (%s [%s])" % (copy,
236-
self.dtype.name,
237-
self.itemsize,
238-
newb.dtype.name,
239-
newb.itemsize))
234+
if (newb.shape != self.shape or
235+
(not copy and newb.itemsize < self.itemsize)):
236+
raise TypeError("cannot set astype for copy = [%s] for dtype "
237+
"(%s [%s]) with smaller itemsize that current "
238+
"(%s [%s])" % (copy, self.dtype.name,
239+
self.itemsize, newb.dtype.name, newb.itemsize))
240240
return newb
241241

242242
def convert(self, copy = True, **kwargs):
@@ -356,11 +356,11 @@ def interpolate(self, method='pad', axis=0, inplace=False,
356356

357357
return make_block(values, self.items, self.ref_items)
358358

359-
def take(self, indexer, axis=1, fill_value=np.nan):
359+
def take(self, indexer, axis=1):
360360
if axis < 1:
361361
raise AssertionError('axis must be at least 1, got %d' % axis)
362-
new_values = com.take_fast(self.values, indexer, None, False,
363-
axis=axis, fill_value=fill_value)
362+
new_values = com.take_nd(self.values, indexer, axis=axis,
363+
allow_fill=False)
364364
return make_block(new_values, self.items, self.ref_items)
365365

366366
def get_values(self, dtype):
@@ -1320,15 +1320,9 @@ def reindex_indexer(self, new_axis, indexer, axis=1, fill_value=np.nan):
13201320
if axis == 0:
13211321
return self._reindex_indexer_items(new_axis, indexer, fill_value)
13221322

1323-
mask = indexer == -1
1324-
1325-
# TODO: deal with length-0 case? or does it fall out?
1326-
needs_masking = len(new_axis) > 0 and mask.any()
1327-
13281323
new_blocks = []
13291324
for block in self.blocks:
1330-
newb = block.reindex_axis(indexer, mask, needs_masking,
1331-
axis=axis, fill_value=fill_value)
1325+
newb = block.reindex_axis(indexer, axis=axis, fill_value=fill_value)
13321326
new_blocks.append(newb)
13331327

13341328
new_axes = list(self.axes)
@@ -1354,8 +1348,8 @@ def _reindex_indexer_items(self, new_items, indexer, fill_value):
13541348
continue
13551349

13561350
new_block_items = new_items.take(selector.nonzero()[0])
1357-
new_values = com.take_fast(blk.values, blk_indexer[selector],
1358-
None, False, axis=0)
1351+
new_values = com.take_nd(blk.values, blk_indexer[selector], axis=0,
1352+
allow_fill=False)
13591353
new_blocks.append(make_block(new_values, new_block_items,
13601354
new_items))
13611355

@@ -1419,8 +1413,8 @@ def _make_na_block(self, items, ref_items, fill_value=np.nan):
14191413
return na_block
14201414

14211415
def take(self, indexer, axis=1):
1422-
if axis == 0:
1423-
raise NotImplementedError
1416+
if axis < 1:
1417+
raise AssertionError('axis must be at least 1, got %d' % axis)
14241418

14251419
indexer = com._ensure_platform_int(indexer)
14261420

@@ -1433,8 +1427,8 @@ def take(self, indexer, axis=1):
14331427
new_axes[axis] = self.axes[axis].take(indexer)
14341428
new_blocks = []
14351429
for blk in self.blocks:
1436-
new_values = com.take_fast(blk.values, indexer, None, False,
1437-
axis=axis)
1430+
new_values = com.take_nd(blk.values, indexer, axis=axis,
1431+
allow_fill=False)
14381432
newb = make_block(new_values, blk.items, self.items)
14391433
new_blocks.append(newb)
14401434

pandas/core/panel.py

+1 addition, -1 deletion
Original file line number | Diff line number | Diff line change
@@ -838,7 +838,7 @@ def _reindex_multi(self, items, major, minor):
838838
indexer2 = range(len(new_minor))
839839

840840
for i, ind in enumerate(indexer0):
841-
com.take_2d_multi(values[ind], indexer1, indexer2,
841+
com.take_2d_multi(values[ind], (indexer1, indexer2),
842842
out=new_values[i])
843843

844844
return Panel(new_values, items=new_items, major_axis=new_major,

0 commit comments

Comments
 (0)