Skip to content

Revert "REF: remove JoinUnit.shape (#43651)" #47406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions pandas/core/internals/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,6 @@ def concatenate_managers(
for placement, join_units in concat_plan:
unit = join_units[0]
blk = unit.block
# Assertion disabled for performance
# assert len(join_units) == len(mgrs_indexers)

if len(join_units) == 1:
values = blk.values
Expand Down Expand Up @@ -331,20 +329,27 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
plan : list of (BlockPlacement, JoinUnit) tuples

"""
# Calculate post-reindex shape , save for item axis which will be separate
# for each block anyway.
mgr_shape_list = list(mgr.shape)
mgr_shape = tuple(mgr_shape_list)

if mgr.is_single_block:
blk = mgr.blocks[0]
return [(blk.mgr_locs, JoinUnit(blk))]
return [(blk.mgr_locs, JoinUnit(blk, mgr_shape))]

blknos = mgr.blknos
blklocs = mgr.blklocs

plan = []
for blkno, placements in libinternals.get_blkno_placements(blknos, group=False):

# Assertions disabled for performance; these should always hold
# assert placements.is_slice_like
# assert blkno != -1
assert placements.is_slice_like
assert blkno != -1

shape_list = list(mgr_shape)
shape_list[0] = len(placements)
shape = tuple(shape_list)

blk = mgr.blocks[blkno]
ax0_blk_indexer = blklocs[placements.indexer]
Expand Down Expand Up @@ -374,16 +379,19 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):

# Assertions disabled for performance
# assert blk._mgr_locs.as_slice == placements.as_slice
unit = JoinUnit(blk)
# assert blk.shape[0] == shape[0]
unit = JoinUnit(blk, shape)

plan.append((placements, unit))

return plan


class JoinUnit:
def __init__(self, block: Block) -> None:
def __init__(self, block: Block, shape: Shape) -> None:
# Passing shape explicitly is required for cases when block is None.
self.block = block
self.shape = shape

def __repr__(self) -> str:
return f"{type(self).__name__}({repr(self.block)})"
Expand All @@ -396,11 +404,22 @@ def is_na(self) -> bool:
return False

def get_reindexed_values(self, empty_dtype: DtypeObj) -> ArrayLike:
values: ArrayLike

if self.is_na:
return make_na_array(empty_dtype, self.block.shape)
return make_na_array(empty_dtype, self.shape)

else:
return self.block.values

if not self.block._can_consolidate:
# preserve these for validation in concat_compat
return self.block.values

# No dtype upcasting is done here, it will be performed during
# concatenation itself.
values = self.block.values

return values


def make_na_array(dtype: DtypeObj, shape: Shape) -> ArrayLike:
Expand Down Expand Up @@ -539,9 +558,6 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
first = join_units[0].block
if first.dtype.kind == "V":
return False
elif len(join_units) == 1:
# only use this path when there is something to concatenate
return False
return (
# exclude cases where a) ju.block is None or b) we have e.g. Int64+int64
all(type(ju.block) is type(first) for ju in join_units)
Expand All @@ -554,8 +570,13 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
or ju.block.dtype.kind in ["b", "i", "u"]
for ju in join_units
)
# this also precludes any blocks with dtype.kind == "V", since
# we excluded that case for `first` above.
and
# no blocks that would get missing values (can lead to type upcasts)
# unless we're an extension dtype.
all(not ju.is_na or ju.block.is_extension for ju in join_units)
and
# only use this path when there is something to concatenate
len(join_units) > 1
)


Expand All @@ -577,7 +598,10 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
extra_block = join_unit.block.getitem_block(slice(length, None))
join_unit.block = join_unit.block.getitem_block(slice(length))

return JoinUnit(block=extra_block)
extra_shape = (join_unit.shape[0] - length,) + join_unit.shape[1:]
join_unit.shape = (length,) + join_unit.shape[1:]

return JoinUnit(block=extra_block, shape=extra_shape)


def _combine_concat_plans(plans):
Expand Down