Skip to content

Commit 0cb4d1b

Browse files
authored
REF: slice before constructing JoinUnit (#52542)
PERF: slice before constructing JoinUnit
1 parent 8b9600c commit 0cb4d1b

File tree

1 file changed

+38
-93
lines changed

1 file changed

+38
-93
lines changed

pandas/core/internals/concat.py

+38-93
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import copy as cp
43
import itertools
54
from typing import (
65
TYPE_CHECKING,
@@ -40,7 +39,6 @@
4039
isna_all,
4140
)
4241

43-
import pandas.core.algorithms as algos
4442
from pandas.core.arrays import (
4543
DatetimeArray,
4644
ExtensionArray,
@@ -63,7 +61,6 @@
6361
AxisInt,
6462
DtypeObj,
6563
Manager,
66-
Shape,
6764
)
6865

6966
from pandas import Index
@@ -206,17 +203,15 @@ def concatenate_managers(
206203

207204
mgrs_indexers = _maybe_reindex_columns_na_proxy(axes, mgrs_indexers)
208205

209-
concat_plans = [
210-
_get_mgr_concatenation_plan(mgr, indexers) for mgr, indexers in mgrs_indexers
211-
]
206+
concat_plans = [_get_mgr_concatenation_plan(mgr) for mgr, _ in mgrs_indexers]
212207
concat_plan = _combine_concat_plans(concat_plans)
213208
blocks = []
214209

215210
for placement, join_units in concat_plan:
216211
unit = join_units[0]
217212
blk = unit.block
218213

219-
if len(join_units) == 1 and not join_units[0].indexers:
214+
if len(join_units) == 1:
220215
values = blk.values
221216
if copy:
222217
values = values.copy()
@@ -322,34 +317,21 @@ def _maybe_reindex_columns_na_proxy(
322317
return new_mgrs_indexers
323318

324319

325-
def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarray]):
320+
def _get_mgr_concatenation_plan(mgr: BlockManager):
326321
"""
327-
Construct concatenation plan for given block manager and indexers.
322+
Construct concatenation plan for given block manager.
328323
329324
Parameters
330325
----------
331326
mgr : BlockManager
332-
indexers : dict of {axis: indexer}
333327
334328
Returns
335329
-------
336330
plan : list of (BlockPlacement, JoinUnit) tuples
337-
338331
"""
339-
assert len(indexers) == 0
340-
341-
# Calculate post-reindex shape, save for item axis which will be separate
342-
# for each block anyway.
343-
mgr_shape_list = list(mgr.shape)
344-
for ax, indexer in indexers.items():
345-
mgr_shape_list[ax] = len(indexer)
346-
mgr_shape = tuple(mgr_shape_list)
347-
348-
assert 0 not in indexers
349-
350332
if mgr.is_single_block:
351333
blk = mgr.blocks[0]
352-
return [(blk.mgr_locs, JoinUnit(blk, mgr_shape, indexers))]
334+
return [(blk.mgr_locs, JoinUnit(blk))]
353335

354336
blknos = mgr.blknos
355337
blklocs = mgr.blklocs
@@ -359,12 +341,6 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
359341
assert placements.is_slice_like
360342
assert blkno != -1
361343

362-
join_unit_indexers = indexers.copy()
363-
364-
shape_list = list(mgr_shape)
365-
shape_list[0] = len(placements)
366-
shape = tuple(shape_list)
367-
368344
blk = mgr.blocks[blkno]
369345
ax0_blk_indexer = blklocs[placements.indexer]
370346

@@ -381,37 +357,34 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
381357
# Slow-ish detection: all indexer locs
382358
# are sequential (and length match is
383359
# checked above).
360+
# TODO: check unnecessary? unique_deltas?
361+
# can we shortcut other is_slice_like cases?
384362
(np.diff(ax0_blk_indexer) == 1).all()
385363
)
386364
)
387365

388366
# Omit indexer if no item reindexing is required.
389-
if unit_no_ax0_reindexing:
390-
join_unit_indexers.pop(0, None)
391-
else:
392-
join_unit_indexers[0] = ax0_blk_indexer
367+
if not unit_no_ax0_reindexing:
368+
# TODO: better max_len?
369+
max_len = max(len(ax0_blk_indexer), ax0_blk_indexer.max() + 1)
370+
slc = lib.maybe_indices_to_slice(ax0_blk_indexer, max_len)
371+
# TODO: in all extant test cases 2023-04-08 we have a slice here.
372+
# Will this always be the case?
373+
blk = blk.getitem_block(slc)
393374

394-
unit = JoinUnit(blk, shape, join_unit_indexers)
375+
unit = JoinUnit(blk)
395376

396377
plan.append((placements, unit))
397378

398379
return plan
399380

400381

401382
class JoinUnit:
402-
def __init__(self, block: Block, shape: Shape, indexers=None) -> None:
403-
# Passing shape explicitly is required for cases when block is None.
404-
# Note: block is None implies indexers is None, but not vice-versa
405-
if indexers is None:
406-
indexers = {}
407-
# Otherwise we may have only {0: np.array(...)} and only non-negative
408-
# entries.
383+
def __init__(self, block: Block) -> None:
409384
self.block = block
410-
self.indexers = indexers
411-
self.shape = shape
412385

413386
def __repr__(self) -> str:
414-
return f"{type(self).__name__}({repr(self.block)}, {self.indexers})"
387+
return f"{type(self).__name__}({repr(self.block)})"
415388

416389
def _is_valid_na_for(self, dtype: DtypeObj) -> bool:
417390
"""
@@ -498,43 +471,38 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
498471

499472
if isinstance(empty_dtype, DatetimeTZDtype):
500473
# NB: exclude e.g. pyarrow[dt64tz] dtypes
501-
i8values = np.full(self.shape, fill_value._value)
474+
i8values = np.full(self.block.shape, fill_value._value)
502475
return DatetimeArray(i8values, dtype=empty_dtype)
503476

504477
elif is_1d_only_ea_dtype(empty_dtype):
505-
if is_dtype_equal(blk_dtype, empty_dtype) and self.indexers:
506-
# avoid creating new empty array if we already have an array
507-
# with correct dtype that can be reindexed
508-
pass
509-
else:
510-
empty_dtype = cast(ExtensionDtype, empty_dtype)
511-
cls = empty_dtype.construct_array_type()
512-
513-
missing_arr = cls._from_sequence([], dtype=empty_dtype)
514-
ncols, nrows = self.shape
515-
assert ncols == 1, ncols
516-
empty_arr = -1 * np.ones((nrows,), dtype=np.intp)
517-
return missing_arr.take(
518-
empty_arr, allow_fill=True, fill_value=fill_value
519-
)
478+
empty_dtype = cast(ExtensionDtype, empty_dtype)
479+
cls = empty_dtype.construct_array_type()
480+
481+
missing_arr = cls._from_sequence([], dtype=empty_dtype)
482+
ncols, nrows = self.block.shape
483+
assert ncols == 1, ncols
484+
empty_arr = -1 * np.ones((nrows,), dtype=np.intp)
485+
return missing_arr.take(
486+
empty_arr, allow_fill=True, fill_value=fill_value
487+
)
520488
elif isinstance(empty_dtype, ExtensionDtype):
521489
# TODO: no tests get here, a handful would if we disabled
522490
# the dt64tz special-case above (which is faster)
523491
cls = empty_dtype.construct_array_type()
524-
missing_arr = cls._empty(shape=self.shape, dtype=empty_dtype)
492+
missing_arr = cls._empty(shape=self.block.shape, dtype=empty_dtype)
525493
missing_arr[:] = fill_value
526494
return missing_arr
527495
else:
528496
# NB: we should never get here with empty_dtype integer or bool;
529497
# if we did, the missing_arr.fill would cast to gibberish
530-
missing_arr = np.empty(self.shape, dtype=empty_dtype)
498+
missing_arr = np.empty(self.block.shape, dtype=empty_dtype)
531499
missing_arr.fill(fill_value)
532500

533501
if empty_dtype.kind in "mM":
534502
missing_arr = ensure_wrapped_if_datetimelike(missing_arr)
535503
return missing_arr
536504

537-
if (not self.indexers) and (not self.block._can_consolidate):
505+
if not self.block._can_consolidate:
538506
# preserve these for validation in concat_compat
539507
return self.block.values
540508

@@ -547,16 +515,10 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
547515
# concatenation itself.
548516
values = self.block.values
549517

550-
if not self.indexers:
551-
# If there's no indexing to be done, we want to signal outside
552-
# code that this array must be copied explicitly. This is done
553-
# by returning a view and checking `retval.base`.
554-
values = values.view()
555-
556-
else:
557-
for ax, indexer in self.indexers.items():
558-
values = algos.take_nd(values, indexer, axis=ax)
559-
518+
# If there's no indexing to be done, we want to signal outside
519+
# code that this array must be copied explicitly. This is done
520+
# by returning a view and checking `retval.base`.
521+
values = values.view()
560522
return values
561523

562524

@@ -688,9 +650,6 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
688650
# unless we're an extension dtype.
689651
all(not ju.is_na or ju.block.is_extension for ju in join_units)
690652
and
691-
# no blocks with indexers (as then the dimensions do not fit)
692-
all(not ju.indexers for ju in join_units)
693-
and
694653
# only use this path when there is something to concatenate
695654
len(join_units) > 1
696655
)
@@ -710,25 +669,11 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
710669
711670
Extra items that didn't fit are returned as a separate block.
712671
"""
713-
if 0 not in join_unit.indexers:
714-
extra_indexers = join_unit.indexers
715-
716-
if join_unit.block is None:
717-
extra_block = None
718-
else:
719-
extra_block = join_unit.block.getitem_block(slice(length, None))
720-
join_unit.block = join_unit.block.getitem_block(slice(length))
721-
else:
722-
extra_block = join_unit.block
723-
724-
extra_indexers = cp.copy(join_unit.indexers)
725-
extra_indexers[0] = extra_indexers[0][length:]
726-
join_unit.indexers[0] = join_unit.indexers[0][:length]
727672

728-
extra_shape = (join_unit.shape[0] - length,) + join_unit.shape[1:]
729-
join_unit.shape = (length,) + join_unit.shape[1:]
673+
extra_block = join_unit.block.getitem_block(slice(length, None))
674+
join_unit.block = join_unit.block.getitem_block(slice(length))
730675

731-
return JoinUnit(block=extra_block, indexers=extra_indexers, shape=extra_shape)
676+
return JoinUnit(block=extra_block)
732677

733678

734679
def _combine_concat_plans(plans):

0 commit comments

Comments (0)