REF: slice before constructing JoinUnit #52542


Merged · 1 commit · Apr 10, 2023
pandas/core/internals/concat.py · 131 changes: 38 additions & 93 deletions
```diff
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import copy as cp
 import itertools
 from typing import (
     TYPE_CHECKING,
@@ -40,7 +39,6 @@
     isna_all,
 )
 
-import pandas.core.algorithms as algos
 from pandas.core.arrays import (
     DatetimeArray,
     ExtensionArray,
@@ -63,7 +61,6 @@
     AxisInt,
     DtypeObj,
     Manager,
-    Shape,
 )
 
 from pandas import Index
```
```diff
@@ -206,17 +203,15 @@ def concatenate_managers(
 
     mgrs_indexers = _maybe_reindex_columns_na_proxy(axes, mgrs_indexers)
 
-    concat_plans = [
-        _get_mgr_concatenation_plan(mgr, indexers) for mgr, indexers in mgrs_indexers
-    ]
+    concat_plans = [_get_mgr_concatenation_plan(mgr) for mgr, _ in mgrs_indexers]
     concat_plan = _combine_concat_plans(concat_plans)
     blocks = []
 
     for placement, join_units in concat_plan:
         unit = join_units[0]
         blk = unit.block
 
-        if len(join_units) == 1 and not join_units[0].indexers:
+        if len(join_units) == 1:
             values = blk.values
             if copy:
                 values = values.copy()
```
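The hunk above works because all column reindexing now happens up front in _maybe_reindex_columns_na_proxy, before any plan is built, so neither the plan nor the single-JoinUnit fast path needs an `.indexers` check anymore. A minimal sketch of the situation this path serves, using only public API:

```python
import pandas as pd

left = pd.DataFrame({"a": [1, 2]})
right = pd.DataFrame({"a": [3, 4], "b": [5.0, 6.0]})

# "b" is missing from `left`; it is materialized as an all-NA proxy
# during the reindexing step that runs before the concatenation plan,
# so the plan itself can assume every manager already has every column.
print(pd.concat([left, right]))
```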
```diff
@@ -322,34 +317,21 @@ def _maybe_reindex_columns_na_proxy(
     return new_mgrs_indexers
 
 
-def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarray]):
+def _get_mgr_concatenation_plan(mgr: BlockManager):
     """
-    Construct concatenation plan for given block manager and indexers.
+    Construct concatenation plan for given block manager.
 
     Parameters
     ----------
     mgr : BlockManager
-    indexers : dict of {axis: indexer}
 
     Returns
     -------
     plan : list of (BlockPlacement, JoinUnit) tuples
 
     """
-    assert len(indexers) == 0
-
-    # Calculate post-reindex shape, save for item axis which will be separate
-    # for each block anyway.
-    mgr_shape_list = list(mgr.shape)
-    for ax, indexer in indexers.items():
-        mgr_shape_list[ax] = len(indexer)
-    mgr_shape = tuple(mgr_shape_list)
-
-    assert 0 not in indexers
-
     if mgr.is_single_block:
         blk = mgr.blocks[0]
-        return [(blk.mgr_locs, JoinUnit(blk, mgr_shape, indexers))]
+        return [(blk.mgr_locs, JoinUnit(blk))]
 
     blknos = mgr.blknos
     blklocs = mgr.blklocs
```
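For orientation, `blknos` and `blklocs` are the manager's two lookup tables: for each column position, which block holds it and where inside that block it sits. A quick inspection sketch (private internals, so the exact layout is not guaranteed):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": [5, 6]})
mgr = df._mgr  # BlockManager; private API, subject to change

# blknos[i]:  index of the block holding column i
# blklocs[i]: position of column i within that block
print(mgr.blknos)   # e.g. [0 1 0] if "a" and "c" share one int64 block
print(mgr.blklocs)  # e.g. [0 0 1]
```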
```diff
@@ -359,12 +341,6 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
         assert placements.is_slice_like
         assert blkno != -1
 
-        join_unit_indexers = indexers.copy()
-
-        shape_list = list(mgr_shape)
-        shape_list[0] = len(placements)
-        shape = tuple(shape_list)
-
         blk = mgr.blocks[blkno]
         ax0_blk_indexer = blklocs[placements.indexer]
@@ -381,37 +357,34 @@
                 # Slow-ish detection: all indexer locs
                 # are sequential (and length match is
                 # checked above).
+                # TODO: check unnecessary? unique_deltas?
+                # can we shortcut other is_slice_like cases?
                 (np.diff(ax0_blk_indexer) == 1).all()
             )
         )
 
-        # Omit indexer if no item reindexing is required.
-        if unit_no_ax0_reindexing:
-            join_unit_indexers.pop(0, None)
-        else:
-            join_unit_indexers[0] = ax0_blk_indexer
+        if not unit_no_ax0_reindexing:
+            # TODO: better max_len?
+            max_len = max(len(ax0_blk_indexer), ax0_blk_indexer.max() + 1)
+            slc = lib.maybe_indices_to_slice(ax0_blk_indexer, max_len)
+            # TODO: in all extant test cases 2023-04-08 we have a slice here.
+            # Will this always be the case?
+            blk = blk.getitem_block(slc)
 
-        unit = JoinUnit(blk, shape, join_unit_indexers)
+        unit = JoinUnit(blk)
 
         plan.append((placements, unit))
 
     return plan
 
 
 class JoinUnit:
-    def __init__(self, block: Block, shape: Shape, indexers=None) -> None:
-        # Passing shape explicitly is required for cases when block is None.
-        # Note: block is None implies indexers is None, but not vice-versa
-        if indexers is None:
-            indexers = {}
-        # Otherwise we may have only {0: np.array(...)} and only non-negative
-        # entries.
+    def __init__(self, block: Block) -> None:
         self.block = block
-        self.indexers = indexers
-        self.shape = shape
 
     def __repr__(self) -> str:
-        return f"{type(self).__name__}({repr(self.block)}, {self.indexers})"
+        return f"{type(self).__name__}({repr(self.block)})"
 
     def _is_valid_na_for(self, dtype: DtypeObj) -> bool:
         """
```
```diff
@@ -498,43 +471,38 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
 
                 if isinstance(empty_dtype, DatetimeTZDtype):
                     # NB: exclude e.g. pyarrow[dt64tz] dtypes
-                    i8values = np.full(self.shape, fill_value._value)
+                    i8values = np.full(self.block.shape, fill_value._value)
                     return DatetimeArray(i8values, dtype=empty_dtype)
 
                 elif is_1d_only_ea_dtype(empty_dtype):
-                    if is_dtype_equal(blk_dtype, empty_dtype) and self.indexers:
-                        # avoid creating new empty array if we already have an array
-                        # with correct dtype that can be reindexed
-                        pass
-                    else:
-                        empty_dtype = cast(ExtensionDtype, empty_dtype)
-                        cls = empty_dtype.construct_array_type()
-
-                        missing_arr = cls._from_sequence([], dtype=empty_dtype)
-                        ncols, nrows = self.shape
-                        assert ncols == 1, ncols
-                        empty_arr = -1 * np.ones((nrows,), dtype=np.intp)
-                        return missing_arr.take(
-                            empty_arr, allow_fill=True, fill_value=fill_value
-                        )
+                    empty_dtype = cast(ExtensionDtype, empty_dtype)
+                    cls = empty_dtype.construct_array_type()
+
+                    missing_arr = cls._from_sequence([], dtype=empty_dtype)
+                    ncols, nrows = self.block.shape
+                    assert ncols == 1, ncols
+                    empty_arr = -1 * np.ones((nrows,), dtype=np.intp)
+                    return missing_arr.take(
+                        empty_arr, allow_fill=True, fill_value=fill_value
+                    )
                 elif isinstance(empty_dtype, ExtensionDtype):
                     # TODO: no tests get here, a handful would if we disabled
                     #  the dt64tz special-case above (which is faster)
                     cls = empty_dtype.construct_array_type()
-                    missing_arr = cls._empty(shape=self.shape, dtype=empty_dtype)
+                    missing_arr = cls._empty(shape=self.block.shape, dtype=empty_dtype)
                     missing_arr[:] = fill_value
                     return missing_arr
                 else:
                     # NB: we should never get here with empty_dtype integer or bool;
                     #  if we did, the missing_arr.fill would cast to gibberish
-                    missing_arr = np.empty(self.shape, dtype=empty_dtype)
+                    missing_arr = np.empty(self.block.shape, dtype=empty_dtype)
                     missing_arr.fill(fill_value)
 
                     if empty_dtype.kind in "mM":
                         missing_arr = ensure_wrapped_if_datetimelike(missing_arr)
                     return missing_arr
 
-            if (not self.indexers) and (not self.block._can_consolidate):
+            if not self.block._can_consolidate:
                 # preserve these for validation in concat_compat
                 return self.block.values
```
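The surviving 1d-only extension-dtype branch builds its all-NA array by taking from an empty array with a -1 indexer and allow_fill=True. A sketch of the same trick through the public ExtensionArray.take:

```python
import numpy as np
import pandas as pd

# Position -1 with allow_fill=True means "use the fill value", so an
# all -1 indexer over an empty array yields an all-NA array of the
# requested length -- the construction used by missing_arr.take above.
arr = pd.array([], dtype="Int64")
indexer = -1 * np.ones(3, dtype=np.intp)
print(arr.take(indexer, allow_fill=True))  # [<NA>, <NA>, <NA>]
```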

```diff
@@ -547,16 +515,10 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
                 # concatenation itself.
                 values = self.block.values
 
-            if not self.indexers:
-                # If there's no indexing to be done, we want to signal outside
-                # code that this array must be copied explicitly. This is done
-                # by returning a view and checking `retval.base`.
-                values = values.view()
-
-            else:
-                for ax, indexer in self.indexers.items():
-                    values = algos.take_nd(values, indexer, axis=ax)
-
+        # If there's no indexing to be done, we want to signal outside
+        # code that this array must be copied explicitly. This is done
+        # by returning a view and checking `retval.base`.
+        values = values.view()
         return values
```
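The comment block kept above relies on a numpy convention: a view records its parent array in `.base`, while a freshly allocated array has `.base is None`, so downstream code can tell "safe to keep" from "copy before mutating". A minimal illustration:

```python
import numpy as np

owned = np.arange(4)  # freshly allocated -> .base is None
view = owned.view()   # shares memory     -> .base is the parent

assert owned.base is None
assert view.base is owned
```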


```diff
@@ -688,9 +650,6 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
         # unless we're an extension dtype.
         all(not ju.is_na or ju.block.is_extension for ju in join_units)
         and
-        # no blocks with indexers (as then the dimensions do not fit)
-        all(not ju.indexers for ju in join_units)
-        and
         # only use this path when there is something to concatenate
         len(join_units) > 1
     )
@@ -710,25 +669,11 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
 
     Extra items that didn't fit are returned as a separate block.
     """
-    if 0 not in join_unit.indexers:
-        extra_indexers = join_unit.indexers
-
-        if join_unit.block is None:
-            extra_block = None
-        else:
-            extra_block = join_unit.block.getitem_block(slice(length, None))
-            join_unit.block = join_unit.block.getitem_block(slice(length))
-    else:
-        extra_block = join_unit.block
-
-        extra_indexers = cp.copy(join_unit.indexers)
-        extra_indexers[0] = extra_indexers[0][length:]
-        join_unit.indexers[0] = join_unit.indexers[0][:length]
-
-    extra_shape = (join_unit.shape[0] - length,) + join_unit.shape[1:]
-    join_unit.shape = (length,) + join_unit.shape[1:]
+    extra_block = join_unit.block.getitem_block(slice(length, None))
+    join_unit.block = join_unit.block.getitem_block(slice(length))
 
-    return JoinUnit(block=extra_block, indexers=extra_indexers, shape=extra_shape)
+    return JoinUnit(block=extra_block)
 
 
 def _combine_concat_plans(plans):
```
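With indexers and shapes gone from JoinUnit, `_trim_join_unit` reduces to the two block slices above. A toy analogue of the split using plain numpy in place of Block.getitem_block (the `split_rows` helper is hypothetical, not pandas API):

```python
import numpy as np

def split_rows(values: np.ndarray, length: int):
    """Keep the first `length` rows; return the remainder separately."""
    # Mirrors _trim_join_unit: slice(None, length) stays in the unit and
    # slice(length, None) becomes the "extra" unit. Both are views.
    return values[:length], values[length:]

head, tail = split_rows(np.arange(5), 3)
print(head, tail)  # [0 1 2] [3 4]
```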