Skip to content

Commit b17379b

Browse files
authored
REF: reindex_indexer up-front to simplify JoinUnit (#43384)
1 parent 8b4c4b4 commit b17379b

File tree

1 file changed

+44
-72
lines changed

1 file changed

+44
-72
lines changed

pandas/core/internals/concat.py

+44-72
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565

6666
if TYPE_CHECKING:
6767
from pandas import Index
68+
from pandas.core.internals.blocks import Block
6869

6970

7071
def _concatenate_array_managers(
@@ -300,76 +301,62 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
300301
mgr_shape_list[ax] = len(indexer)
301302
mgr_shape = tuple(mgr_shape_list)
302303

303-
has_column_indexer = False
304+
assert 0 not in indexers
304305

305-
if 0 in indexers:
306-
has_column_indexer = True
307-
ax0_indexer = indexers.pop(0)
308-
blknos = algos.take_nd(mgr.blknos, ax0_indexer, fill_value=-1)
309-
blklocs = algos.take_nd(mgr.blklocs, ax0_indexer, fill_value=-1)
310-
else:
311-
312-
if mgr.is_single_block:
313-
blk = mgr.blocks[0]
314-
return [(blk.mgr_locs, JoinUnit(blk, mgr_shape, indexers))]
306+
if mgr.is_single_block:
307+
blk = mgr.blocks[0]
308+
return [(blk.mgr_locs, JoinUnit(blk, mgr_shape, indexers))]
315309

316-
blknos = mgr.blknos
317-
blklocs = mgr.blklocs
310+
blknos = mgr.blknos
311+
blklocs = mgr.blklocs
318312

319313
plan = []
320314
for blkno, placements in libinternals.get_blkno_placements(blknos, group=False):
321315

322316
assert placements.is_slice_like
317+
assert blkno != -1
323318

324319
join_unit_indexers = indexers.copy()
325320

326321
shape_list = list(mgr_shape)
327322
shape_list[0] = len(placements)
328323
shape = tuple(shape_list)
329324

330-
if blkno == -1:
331-
# only reachable in the `0 in indexers` case
332-
unit = JoinUnit(None, shape)
333-
else:
334-
blk = mgr.blocks[blkno]
335-
ax0_blk_indexer = blklocs[placements.indexer]
336-
337-
unit_no_ax0_reindexing = (
338-
len(placements) == len(blk.mgr_locs)
339-
and
340-
# Fastpath detection of join unit not
341-
# needing to reindex its block: no ax0
342-
# reindexing took place and block
343-
# placement was sequential before.
344-
(
345-
(
346-
not has_column_indexer
347-
and blk.mgr_locs.is_slice_like
348-
and blk.mgr_locs.as_slice.step == 1
349-
)
350-
or
351-
# Slow-ish detection: all indexer locs
352-
# are sequential (and length match is
353-
# checked above).
354-
(np.diff(ax0_blk_indexer) == 1).all()
355-
)
325+
blk = mgr.blocks[blkno]
326+
ax0_blk_indexer = blklocs[placements.indexer]
327+
328+
unit_no_ax0_reindexing = (
329+
len(placements) == len(blk.mgr_locs)
330+
and
331+
# Fastpath detection of join unit not
332+
# needing to reindex its block: no ax0
333+
# reindexing took place and block
334+
# placement was sequential before.
335+
(
336+
(blk.mgr_locs.is_slice_like and blk.mgr_locs.as_slice.step == 1)
337+
or
338+
# Slow-ish detection: all indexer locs
339+
# are sequential (and length match is
340+
# checked above).
341+
(np.diff(ax0_blk_indexer) == 1).all()
356342
)
343+
)
357344

358-
# Omit indexer if no item reindexing is required.
359-
if unit_no_ax0_reindexing:
360-
join_unit_indexers.pop(0, None)
361-
else:
362-
join_unit_indexers[0] = ax0_blk_indexer
345+
# Omit indexer if no item reindexing is required.
346+
if unit_no_ax0_reindexing:
347+
join_unit_indexers.pop(0, None)
348+
else:
349+
join_unit_indexers[0] = ax0_blk_indexer
363350

364-
unit = JoinUnit(blk, shape, join_unit_indexers)
351+
unit = JoinUnit(blk, shape, join_unit_indexers)
365352

366353
plan.append((placements, unit))
367354

368355
return plan
369356

370357

371358
class JoinUnit:
372-
def __init__(self, block, shape: Shape, indexers=None):
359+
def __init__(self, block: Block, shape: Shape, indexers=None):
373360
# Passing shape explicitly is required for cases when block is None.
374361
# Note: block is None implies indexers is None, but not vice-versa
375362
if indexers is None:
@@ -393,7 +380,7 @@ def needs_filling(self) -> bool:
393380
@cache_readonly
394381
def dtype(self):
395382
blk = self.block
396-
if blk is None:
383+
if blk.values.dtype.kind == "V":
397384
raise AssertionError("Block is None, no dtype")
398385

399386
if not self.needs_filling:
@@ -407,8 +394,6 @@ def _is_valid_na_for(self, dtype: DtypeObj) -> bool:
407394
"""
408395
if not self.is_na:
409396
return False
410-
if self.block is None:
411-
return True
412397
if self.block.dtype.kind == "V":
413398
return True
414399

@@ -435,8 +420,6 @@ def _is_valid_na_for(self, dtype: DtypeObj) -> bool:
435420
@cache_readonly
436421
def is_na(self) -> bool:
437422
blk = self.block
438-
if blk is None:
439-
return True
440423
if blk.dtype.kind == "V":
441424
return True
442425

@@ -464,6 +447,8 @@ def is_na(self) -> bool:
464447
return all(isna_all(row) for row in values)
465448

466449
def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
450+
values: ArrayLike
451+
467452
if upcasted_na is None and self.block.dtype.kind != "V":
468453
# No upcasting is necessary
469454
fill_value = self.block.fill_value
@@ -472,9 +457,8 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
472457
fill_value = upcasted_na
473458

474459
if self._is_valid_na_for(empty_dtype):
475-
# note: always holds when self.block is None
476-
# or self.block.dtype.kind == "V"
477-
blk_dtype = getattr(self.block, "dtype", None)
460+
# note: always holds when self.block.dtype.kind == "V"
461+
blk_dtype = self.block.dtype
478462

479463
if blk_dtype == np.dtype("object"):
480464
# we want to avoid filling with np.nan if we are
@@ -551,9 +535,7 @@ def _concatenate_join_units(
551535

552536
empty_dtype = _get_empty_dtype(join_units)
553537

554-
has_none_blocks = any(
555-
unit.block is None or unit.block.dtype.kind == "V" for unit in join_units
556-
)
538+
has_none_blocks = any(unit.block.dtype.kind == "V" for unit in join_units)
557539
upcasted_na = _dtype_to_na_value(empty_dtype, has_none_blocks)
558540

559541
to_concat = [
@@ -629,28 +611,18 @@ def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj:
629611
"""
630612
if len(join_units) == 1:
631613
blk = join_units[0].block
632-
if blk is None:
633-
return np.dtype(np.float64)
634614
return blk.dtype
635615

636616
if _is_uniform_reindex(join_units):
637617
# FIXME: integrate property
638618
empty_dtype = join_units[0].block.dtype
639619
return empty_dtype
640620

641-
has_none_blocks = any(
642-
unit.block is None or unit.block.dtype.kind == "V" for unit in join_units
643-
)
621+
has_none_blocks = any(unit.block.dtype.kind == "V" for unit in join_units)
644622

645-
dtypes = [
646-
unit.dtype for unit in join_units if unit.block is not None and not unit.is_na
647-
]
623+
dtypes = [unit.dtype for unit in join_units if not unit.is_na]
648624
if not len(dtypes):
649-
dtypes = [
650-
unit.dtype
651-
for unit in join_units
652-
if unit.block is not None and unit.block.dtype.kind != "V"
653-
]
625+
dtypes = [unit.dtype for unit in join_units if unit.block.dtype.kind != "V"]
654626

655627
dtype = find_common_type(dtypes)
656628
if has_none_blocks:
@@ -666,7 +638,7 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
666638
667639
"""
668640
first = join_units[0].block
669-
if first is None or first.dtype.kind == "V":
641+
if first.dtype.kind == "V":
670642
return False
671643
return (
672644
# exclude cases where a) ju.block is None or b) we have e.g. Int64+int64
@@ -696,7 +668,7 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
696668
def _is_uniform_reindex(join_units) -> bool:
697669
return (
698670
# TODO: should this be ju.block._can_hold_na?
699-
all(ju.block and ju.block.is_extension for ju in join_units)
671+
all(ju.block.is_extension for ju in join_units)
700672
and len({ju.block.dtype.name for ju in join_units}) == 1
701673
)
702674

0 commit comments

Comments
 (0)