Skip to content

Commit b0231a6

Browse files
Revert "REF: concat on bm_axis==0 (pandas-dev#43626)"
This reverts commit 0de6f8b.
1 parent 90e966e commit b0231a6

File tree

1 file changed

+100
-77
lines changed

1 file changed

+100
-77
lines changed

pandas/core/internals/concat.py

+100-77
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
)
3737
from pandas.core.dtypes.dtypes import ExtensionDtype
3838

39+
import pandas.core.algorithms as algos
3940
from pandas.core.arrays import (
4041
DatetimeArray,
4142
ExtensionArray,
@@ -191,29 +192,19 @@ def concatenate_managers(
191192
if isinstance(mgrs_indexers[0][0], ArrayManager):
192193
return _concatenate_array_managers(mgrs_indexers, axes, concat_axis, copy)
193194

194-
# Assertions disabled for performance
195-
# for tup in mgrs_indexers:
196-
# # caller is responsible for ensuring this
197-
# indexers = tup[1]
198-
# assert concat_axis not in indexers
199-
200-
if concat_axis == 0:
201-
return _concat_managers_axis0(mgrs_indexers, axes, copy)
202-
203195
mgrs_indexers = _maybe_reindex_columns_na_proxy(axes, mgrs_indexers)
204196

205-
# Assertion disabled for performance
206-
# assert all(not x[1] for x in mgrs_indexers)
207-
208-
concat_plans = [_get_mgr_concatenation_plan(mgr) for mgr, _ in mgrs_indexers]
209-
concat_plan = _combine_concat_plans(concat_plans)
197+
concat_plans = [
198+
_get_mgr_concatenation_plan(mgr, indexers) for mgr, indexers in mgrs_indexers
199+
]
200+
concat_plan = _combine_concat_plans(concat_plans, concat_axis)
210201
blocks = []
211202

212203
for placement, join_units in concat_plan:
213204
unit = join_units[0]
214205
blk = unit.block
215206

216-
if len(join_units) == 1:
207+
if len(join_units) == 1 and not join_units[0].indexers:
217208
values = blk.values
218209
if copy:
219210
values = values.copy()
@@ -237,7 +228,7 @@ def concatenate_managers(
237228

238229
fastpath = blk.values.dtype == values.dtype
239230
else:
240-
values = _concatenate_join_units(join_units, copy=copy)
231+
values = _concatenate_join_units(join_units, concat_axis, copy=copy)
241232
fastpath = False
242233

243234
if fastpath:
@@ -250,42 +241,6 @@ def concatenate_managers(
250241
return BlockManager(tuple(blocks), axes)
251242

252243

253-
def _concat_managers_axis0(
254-
mgrs_indexers, axes: list[Index], copy: bool
255-
) -> BlockManager:
256-
"""
257-
concat_managers specialized to concat_axis=0, with reindexing already
258-
having been done in _maybe_reindex_columns_na_proxy.
259-
"""
260-
had_reindexers = {
261-
i: len(mgrs_indexers[i][1]) > 0 for i in range(len(mgrs_indexers))
262-
}
263-
mgrs_indexers = _maybe_reindex_columns_na_proxy(axes, mgrs_indexers)
264-
265-
mgrs = [x[0] for x in mgrs_indexers]
266-
267-
offset = 0
268-
blocks = []
269-
for i, mgr in enumerate(mgrs):
270-
# If we already reindexed, then we definitely don't need another copy
271-
made_copy = had_reindexers[i]
272-
273-
for blk in mgr.blocks:
274-
if made_copy:
275-
nb = blk.copy(deep=False)
276-
elif copy:
277-
nb = blk.copy()
278-
else:
279-
# by slicing instead of copy(deep=False), we get a new array
280-
# object, see test_concat_copy
281-
nb = blk.getitem_block(slice(None))
282-
nb._mgr_locs = nb._mgr_locs.add(offset)
283-
blocks.append(nb)
284-
285-
offset += len(mgr.items)
286-
return BlockManager(tuple(blocks), axes)
287-
288-
289244
def _maybe_reindex_columns_na_proxy(
290245
axes: list[Index], mgrs_indexers: list[tuple[BlockManager, dict[int, np.ndarray]]]
291246
) -> list[tuple[BlockManager, dict[int, np.ndarray]]]:
@@ -296,33 +251,36 @@ def _maybe_reindex_columns_na_proxy(
296251
Columns added in this reindexing have dtype=np.void, indicating they
297252
should be ignored when choosing a column's final dtype.
298253
"""
299-
new_mgrs_indexers: list[tuple[BlockManager, dict[int, np.ndarray]]] = []
300-
254+
new_mgrs_indexers = []
301255
for mgr, indexers in mgrs_indexers:
302-
# For axis=0 (i.e. columns) we use_na_proxy and only_slice, so this
303-
# is a cheap reindexing.
304-
for i, indexer in indexers.items():
305-
mgr = mgr.reindex_indexer(
306-
axes[i],
307-
indexers[i],
308-
axis=i,
256+
# We only reindex for axis=0 (i.e. columns), as this can be done cheaply
257+
if 0 in indexers:
258+
new_mgr = mgr.reindex_indexer(
259+
axes[0],
260+
indexers[0],
261+
axis=0,
309262
copy=False,
310-
only_slice=True, # only relevant for i==0
263+
only_slice=True,
311264
allow_dups=True,
312-
use_na_proxy=True, # only relevant for i==0
265+
use_na_proxy=True,
313266
)
314-
new_mgrs_indexers.append((mgr, {}))
267+
new_indexers = indexers.copy()
268+
del new_indexers[0]
269+
new_mgrs_indexers.append((new_mgr, new_indexers))
270+
else:
271+
new_mgrs_indexers.append((mgr, indexers))
315272

316273
return new_mgrs_indexers
317274

318275

319-
def _get_mgr_concatenation_plan(mgr: BlockManager):
276+
def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarray]):
320277
"""
321-
Construct concatenation plan for given block manager.
278+
Construct concatenation plan for given block manager and indexers.
322279
323280
Parameters
324281
----------
325282
mgr : BlockManager
283+
indexers : dict of {axis: indexer}
326284
327285
Returns
328286
-------
@@ -332,11 +290,27 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
332290
# Calculate post-reindex shape, save for item axis which will be separate
333291
# for each block anyway.
334292
mgr_shape_list = list(mgr.shape)
293+
for ax, indexer in indexers.items():
294+
mgr_shape_list[ax] = len(indexer)
335295
mgr_shape = tuple(mgr_shape_list)
336296

297+
assert 0 not in indexers
298+
299+
needs_filling = False
300+
if 1 in indexers:
301+
# indexers[1] is shared by all the JoinUnits, so we can save time
302+
# by only doing this check once
303+
if (indexers[1] == -1).any():
304+
needs_filling = True
305+
337306
if mgr.is_single_block:
338307
blk = mgr.blocks[0]
339-
return [(blk.mgr_locs, JoinUnit(blk, mgr_shape))]
308+
return [
309+
(
310+
blk.mgr_locs,
311+
JoinUnit(blk, mgr_shape, indexers, needs_filling=needs_filling),
312+
)
313+
]
340314

341315
blknos = mgr.blknos
342316
blklocs = mgr.blklocs
@@ -347,6 +321,8 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
347321
assert placements.is_slice_like
348322
assert blkno != -1
349323

324+
join_unit_indexers = indexers.copy()
325+
350326
shape_list = list(mgr_shape)
351327
shape_list[0] = len(placements)
352328
shape = tuple(shape_list)
@@ -380,21 +356,30 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
380356
# Assertions disabled for performance
381357
# assert blk._mgr_locs.as_slice == placements.as_slice
382358
# assert blk.shape[0] == shape[0]
383-
unit = JoinUnit(blk, shape)
359+
unit = JoinUnit(blk, shape, join_unit_indexers, needs_filling=needs_filling)
384360

385361
plan.append((placements, unit))
386362

387363
return plan
388364

389365

390366
class JoinUnit:
391-
def __init__(self, block: Block, shape: Shape):
367+
def __init__(
368+
self, block: Block, shape: Shape, indexers=None, *, needs_filling: bool = False
369+
):
392370
# Passing shape explicitly is required for cases when block is None.
371+
# Note: block is None implies indexers is None, but not vice-versa
372+
if indexers is None:
373+
indexers = {}
374+
# we should *never* have `0 in indexers`
393375
self.block = block
376+
self.indexers = indexers
394377
self.shape = shape
395378

379+
self.needs_filling = needs_filling
380+
396381
def __repr__(self) -> str:
397-
return f"{type(self).__name__}({repr(self.block)})"
382+
return f"{type(self).__name__}({repr(self.block)}, {self.indexers})"
398383

399384
@cache_readonly
400385
def is_na(self) -> bool:
@@ -411,14 +396,24 @@ def get_reindexed_values(self, empty_dtype: DtypeObj) -> ArrayLike:
411396

412397
else:
413398

414-
if not self.block._can_consolidate:
399+
if (not self.indexers) and (not self.block._can_consolidate):
415400
# preserve these for validation in concat_compat
416401
return self.block.values
417402

418403
# No dtype upcasting is done here, it will be performed during
419404
# concatenation itself.
420405
values = self.block.values
421406

407+
if not self.indexers:
408+
# If there's no indexing to be done, we want to signal outside
409+
# code that this array must be copied explicitly. This is done
410+
# by returning a view and checking `retval.base`.
411+
values = values.view()
412+
413+
else:
414+
for ax, indexer in self.indexers.items():
415+
values = algos.take_nd(values, indexer, axis=ax)
416+
422417
return values
423418

424419

@@ -456,10 +451,15 @@ def make_na_array(dtype: DtypeObj, shape: Shape) -> ArrayLike:
456451
return missing_arr
457452

458453

459-
def _concatenate_join_units(join_units: list[JoinUnit], copy: bool) -> ArrayLike:
454+
def _concatenate_join_units(
455+
join_units: list[JoinUnit], concat_axis: int, copy: bool
456+
) -> ArrayLike:
460457
"""
461-
Concatenate values from several join units along axis=1.
458+
Concatenate values from several join units along selected axis.
462459
"""
460+
if concat_axis == 0 and len(join_units) > 1:
461+
# Concatenating join units along ax0 is handled in _merge_blocks.
462+
raise AssertionError("Concatenating join units along axis0")
463463

464464
empty_dtype = _get_empty_dtype(join_units)
465465

@@ -495,7 +495,7 @@ def _concatenate_join_units(join_units: list[JoinUnit], copy: bool) -> ArrayLike
495495
concat_values = ensure_block_shape(concat_values, 2)
496496

497497
else:
498-
concat_values = concat_compat(to_concat, axis=1)
498+
concat_values = concat_compat(to_concat, axis=concat_axis)
499499

500500
return concat_values
501501

@@ -538,7 +538,7 @@ def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj:
538538
empty_dtype = join_units[0].block.dtype
539539
return empty_dtype
540540

541-
needs_can_hold_na = any(unit.is_na for unit in join_units)
541+
needs_can_hold_na = any(unit.is_na or unit.needs_filling for unit in join_units)
542542

543543
dtypes = [unit.block.dtype for unit in join_units if not unit.is_na]
544544

@@ -575,6 +575,9 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
575575
# unless we're an extension dtype.
576576
all(not ju.is_na or ju.block.is_extension for ju in join_units)
577577
and
578+
# no blocks with indexers (as then the dimensions do not fit)
579+
all(not ju.indexers for ju in join_units)
580+
and
578581
# only use this path when there is something to concatenate
579582
len(join_units) > 1
580583
)
@@ -594,17 +597,25 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
594597
595598
Extra items that didn't fit are returned as a separate block.
596599
"""
600+
assert 0 not in join_unit.indexers
601+
extra_indexers = join_unit.indexers
597602

598603
extra_block = join_unit.block.getitem_block(slice(length, None))
599604
join_unit.block = join_unit.block.getitem_block(slice(length))
600605

601606
extra_shape = (join_unit.shape[0] - length,) + join_unit.shape[1:]
602607
join_unit.shape = (length,) + join_unit.shape[1:]
603608

604-
return JoinUnit(block=extra_block, shape=extra_shape)
609+
# extra_indexers does not introduce any -1s, so we can inherit needs_filling
610+
return JoinUnit(
611+
block=extra_block,
612+
indexers=extra_indexers,
613+
shape=extra_shape,
614+
needs_filling=join_unit.needs_filling,
615+
)
605616

606617

607-
def _combine_concat_plans(plans):
618+
def _combine_concat_plans(plans, concat_axis: int):
608619
"""
609620
Combine multiple concatenation plans into one.
610621
@@ -614,6 +625,18 @@ def _combine_concat_plans(plans):
614625
for p in plans[0]:
615626
yield p[0], [p[1]]
616627

628+
elif concat_axis == 0:
629+
offset = 0
630+
for plan in plans:
631+
last_plc = None
632+
633+
for plc, unit in plan:
634+
yield plc.add(offset), [unit]
635+
last_plc = plc
636+
637+
if last_plc is not None:
638+
offset += last_plc.as_slice.stop
639+
617640
else:
618641
# singleton list so we can modify it as a side-effect within _next_or_none
619642
num_ended = [0]

0 commit comments

Comments
 (0)