Skip to content

Commit 429ed6e

Browse files
jbrockmendelBlake Hawkins
authored and
Blake Hawkins
committed
CLN: simplify take_2d_multi (pandas-dev#29065)
1 parent b106ef3 commit 429ed6e

File tree

2 files changed

+33
-48
lines changed

2 files changed

+33
-48
lines changed

pandas/core/algorithms.py

+32-47
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ def get_indexer(current_indexer, other_indexer):
13041304
return frame.sort_values(columns, ascending=ascending, kind="mergesort")
13051305

13061306

1307-
# ------- ## ---- #
1307+
# ---- #
13081308
# take #
13091309
# ---- #
13101310

@@ -1712,59 +1712,44 @@ def take_nd(
17121712
take_1d = take_nd
17131713

17141714

1715-
def take_2d_multi(
1716-
arr, indexer, out=None, fill_value=np.nan, mask_info=None, allow_fill=True
1717-
):
1715+
def take_2d_multi(arr, indexer, fill_value=np.nan):
17181716
"""
17191717
Specialized Cython take which sets NaN values in one pass
17201718
"""
1721-
if indexer is None or (indexer[0] is None and indexer[1] is None):
1722-
row_idx = np.arange(arr.shape[0], dtype=np.int64)
1723-
col_idx = np.arange(arr.shape[1], dtype=np.int64)
1724-
indexer = row_idx, col_idx
1725-
dtype, fill_value = arr.dtype, arr.dtype.type()
1726-
else:
1727-
row_idx, col_idx = indexer
1728-
if row_idx is None:
1729-
row_idx = np.arange(arr.shape[0], dtype=np.int64)
1730-
else:
1731-
row_idx = ensure_int64(row_idx)
1732-
if col_idx is None:
1733-
col_idx = np.arange(arr.shape[1], dtype=np.int64)
1734-
else:
1735-
col_idx = ensure_int64(col_idx)
1736-
indexer = row_idx, col_idx
1737-
if not allow_fill:
1719+
# This is only called from one place in DataFrame._reindex_multi,
1720+
# so we know indexer is well-behaved.
1721+
assert indexer is not None
1722+
assert indexer[0] is not None
1723+
assert indexer[1] is not None
1724+
1725+
row_idx, col_idx = indexer
1726+
1727+
row_idx = ensure_int64(row_idx)
1728+
col_idx = ensure_int64(col_idx)
1729+
indexer = row_idx, col_idx
1730+
mask_info = None
1731+
1732+
# check for promotion based on types only (do this first because
1733+
# it's faster than computing a mask)
1734+
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
1735+
if dtype != arr.dtype:
1736+
# check if promotion is actually required based on indexer
1737+
row_mask = row_idx == -1
1738+
col_mask = col_idx == -1
1739+
row_needs = row_mask.any()
1740+
col_needs = col_mask.any()
1741+
mask_info = (row_mask, col_mask), (row_needs, col_needs)
1742+
1743+
if not (row_needs or col_needs):
1744+
# if not, then depromote, set fill_value to dummy
1745+
# (it won't be used but we don't want the cython code
1746+
# to crash when trying to cast it to dtype)
17381747
dtype, fill_value = arr.dtype, arr.dtype.type()
1739-
mask_info = None, False
1740-
else:
1741-
# check for promotion based on types only (do this first because
1742-
# it's faster than computing a mask)
1743-
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
1744-
if dtype != arr.dtype and (out is None or out.dtype != dtype):
1745-
# check if promotion is actually required based on indexer
1746-
if mask_info is not None:
1747-
(row_mask, col_mask), (row_needs, col_needs) = mask_info
1748-
else:
1749-
row_mask = row_idx == -1
1750-
col_mask = col_idx == -1
1751-
row_needs = row_mask.any()
1752-
col_needs = col_mask.any()
1753-
mask_info = (row_mask, col_mask), (row_needs, col_needs)
1754-
if row_needs or col_needs:
1755-
if out is not None and out.dtype != dtype:
1756-
raise TypeError("Incompatible type for fill_value")
1757-
else:
1758-
# if not, then depromote, set fill_value to dummy
1759-
# (it won't be used but we don't want the cython code
1760-
# to crash when trying to cast it to dtype)
1761-
dtype, fill_value = arr.dtype, arr.dtype.type()
17621748

17631749
# at this point, it's guaranteed that dtype can hold both the arr values
17641750
# and the fill_value
1765-
if out is None:
1766-
out_shape = len(row_idx), len(col_idx)
1767-
out = np.empty(out_shape, dtype=dtype)
1751+
out_shape = len(row_idx), len(col_idx)
1752+
out = np.empty(out_shape, dtype=dtype)
17681753

17691754
func = _take_2d_multi_dict.get((arr.dtype.name, out.dtype.name), None)
17701755
if func is None and arr.dtype != out.dtype:

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4606,7 +4606,7 @@ def _needs_reindex_multi(self, axes, method, level):
46064606
)
46074607

46084608
def _reindex_multi(self, axes, copy, fill_value):
4609-
return NotImplemented
4609+
raise AbstractMethodError(self)
46104610

46114611
def _reindex_with_indexers(
46124612
self, reindexers, fill_value=None, copy=False, allow_dups=False

0 commit comments

Comments
 (0)