Skip to content

Commit 996bd11

Browse files
authored
REF/PERF: pd.concat (#52312)
* PERF: concat
* REF/PERF: pd.concat
* mypy fixup
1 parent c894b3f commit 996bd11

File tree

2 files changed

+132
-87
lines changed

2 files changed

+132
-87
lines changed

pandas/core/generic.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -9362,7 +9362,13 @@ def compare(
93629362
else:
93639363
axis = self._get_axis_number(align_axis)
93649364

9365-
diff = concat([self, other], axis=axis, keys=result_names)
9365+
# error: List item 0 has incompatible type "NDFrame"; expected
9366+
# "Union[Series, DataFrame]"
9367+
diff = concat(
9368+
[self, other], # type: ignore[list-item]
9369+
axis=axis,
9370+
keys=result_names,
9371+
)
93669372

93679373
if axis >= self.ndim:
93689374
# No need to reorganize data if stacking on new axis

pandas/core/reshape/concat.py

+125 -86
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,6 @@
5959
DataFrame,
6060
Series,
6161
)
62-
from pandas.core.generic import NDFrame
6362

6463
# ---------------------------------------------------------------------
6564
# Concatenate DataFrame objects
@@ -101,7 +100,7 @@ def concat(
101100

102101
@overload
103102
def concat(
104-
objs: Iterable[NDFrame] | Mapping[HashableT, NDFrame],
103+
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
105104
*,
106105
axis: Literal[0, "index"] = ...,
107106
join: str = ...,
@@ -118,7 +117,7 @@ def concat(
118117

119118
@overload
120119
def concat(
121-
objs: Iterable[NDFrame] | Mapping[HashableT, NDFrame],
120+
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
122121
*,
123122
axis: Literal[1, "columns"],
124123
join: str = ...,
@@ -135,7 +134,7 @@ def concat(
135134

136135
@overload
137136
def concat(
138-
objs: Iterable[NDFrame] | Mapping[HashableT, NDFrame],
137+
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
139138
*,
140139
axis: Axis = ...,
141140
join: str = ...,
@@ -151,7 +150,7 @@ def concat(
151150

152151

153152
def concat(
154-
objs: Iterable[NDFrame] | Mapping[HashableT, NDFrame],
153+
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
155154
*,
156155
axis: Axis = 0,
157156
join: str = "outer",
@@ -398,7 +397,7 @@ class _Concatenator:
398397

399398
def __init__(
400399
self,
401-
objs: Iterable[NDFrame] | Mapping[HashableT, NDFrame],
400+
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
402401
axis: Axis = 0,
403402
join: str = "outer",
404403
keys=None,
@@ -424,6 +423,72 @@ def __init__(
424423
"Only can inner (intersect) or outer (union) join the other axis"
425424
)
426425

426+
if not is_bool(sort):
427+
raise ValueError(
428+
f"The 'sort' keyword only accepts boolean values; {sort} was passed."
429+
)
430+
# Incompatible types in assignment (expression has type "Union[bool, bool_]",
431+
# variable has type "bool")
432+
self.sort = sort # type: ignore[assignment]
433+
434+
self.ignore_index = ignore_index
435+
self.verify_integrity = verify_integrity
436+
self.copy = copy
437+
438+
objs, keys = self._clean_keys_and_objs(objs, keys)
439+
440+
# figure out what our result ndim is going to be
441+
ndims = self._get_ndims(objs)
442+
sample, objs = self._get_sample_object(objs, ndims, keys, names, levels)
443+
444+
# Standardize axis parameter to int
445+
if sample.ndim == 1:
446+
from pandas import DataFrame
447+
448+
axis = DataFrame._get_axis_number(axis)
449+
self._is_frame = False
450+
self._is_series = True
451+
else:
452+
axis = sample._get_axis_number(axis)
453+
self._is_frame = True
454+
self._is_series = False
455+
456+
# Need to flip BlockManager axis in the DataFrame special case
457+
axis = sample._get_block_manager_axis(axis)
458+
459+
# if we have mixed ndims, then convert to highest ndim
460+
# creating column numbers as needed
461+
if len(ndims) > 1:
462+
objs, sample = self._sanitize_mixed_ndim(objs, sample, ignore_index, axis)
463+
464+
self.objs = objs
465+
466+
# note: this is the BlockManager axis (since DataFrame is transposed)
467+
self.bm_axis = axis
468+
self.axis = 1 - self.bm_axis if self._is_frame else 0
469+
self.keys = keys
470+
self.names = names or getattr(keys, "names", None)
471+
self.levels = levels
472+
473+
def _get_ndims(self, objs: list[Series | DataFrame]) -> set[int]:
474+
# figure out what our result ndim is going to be
475+
ndims = set()
476+
for obj in objs:
477+
if not isinstance(obj, (ABCSeries, ABCDataFrame)):
478+
msg = (
479+
f"cannot concatenate object of type '{type(obj)}'; "
480+
"only Series and DataFrame objs are valid"
481+
)
482+
raise TypeError(msg)
483+
484+
ndims.add(obj.ndim)
485+
return ndims
486+
487+
def _clean_keys_and_objs(
488+
self,
489+
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
490+
keys,
491+
) -> tuple[list[Series | DataFrame], Index | None]:
427492
if isinstance(objs, abc.Mapping):
428493
if keys is None:
429494
keys = list(objs.keys())
@@ -437,7 +502,7 @@ def __init__(
437502
if keys is None:
438503
objs = list(com.not_none(*objs))
439504
else:
440-
# #1649
505+
# GH#1649
441506
clean_keys = []
442507
clean_objs = []
443508
if is_iterator(keys):
@@ -470,22 +535,20 @@ def __init__(
470535
if len(objs) == 0:
471536
raise ValueError("All objects passed were None")
472537

473-
# figure out what our result ndim is going to be
474-
ndims = set()
475-
for obj in objs:
476-
if not isinstance(obj, (ABCSeries, ABCDataFrame)):
477-
msg = (
478-
f"cannot concatenate object of type '{type(obj)}'; "
479-
"only Series and DataFrame objs are valid"
480-
)
481-
raise TypeError(msg)
482-
483-
ndims.add(obj.ndim)
538+
return objs, keys
484539

540+
def _get_sample_object(
541+
self,
542+
objs: list[Series | DataFrame],
543+
ndims: set[int],
544+
keys,
545+
names,
546+
levels,
547+
) -> tuple[Series | DataFrame, list[Series | DataFrame]]:
485548
# get the sample
486549
# want the highest ndim that we have, and must be non-empty
487550
# unless all objs are empty
488-
sample: NDFrame | None = None
551+
sample: Series | DataFrame | None = None
489552
if len(ndims) > 1:
490553
max_ndim = max(ndims)
491554
for obj in objs:
@@ -506,82 +569,48 @@ def __init__(
506569

507570
if sample is None:
508571
sample = objs[0]
509-
self.objs = objs
510-
511-
# Standardize axis parameter to int
512-
if sample.ndim == 1:
513-
from pandas import DataFrame
514-
515-
axis = DataFrame._get_axis_number(axis)
516-
self._is_frame = False
517-
self._is_series = True
518-
else:
519-
axis = sample._get_axis_number(axis)
520-
self._is_frame = True
521-
self._is_series = False
522-
523-
# Need to flip BlockManager axis in the DataFrame special case
524-
if self._is_frame:
525-
axis = sample._get_block_manager_axis(axis)
526-
527-
if not 0 <= axis <= sample.ndim:
528-
raise AssertionError(
529-
f"axis must be between 0 and {sample.ndim}, input was {axis}"
530-
)
572+
return sample, objs
531573

574+
def _sanitize_mixed_ndim(
575+
self,
576+
objs: list[Series | DataFrame],
577+
sample: Series | DataFrame,
578+
ignore_index: bool,
579+
axis: AxisInt,
580+
) -> tuple[list[Series | DataFrame], Series | DataFrame]:
532581
# if we have mixed ndims, then convert to highest ndim
533582
# creating column numbers as needed
534-
if len(ndims) > 1:
535-
current_column = 0
536-
max_ndim = sample.ndim
537-
self.objs, objs = [], self.objs
538-
for obj in objs:
539-
ndim = obj.ndim
540-
if ndim == max_ndim:
541-
pass
542583

543-
elif ndim != max_ndim - 1:
544-
raise ValueError(
545-
"cannot concatenate unaligned mixed "
546-
"dimensional NDFrame objects"
547-
)
584+
new_objs = []
548585

549-
else:
550-
name = getattr(obj, "name", None)
551-
if ignore_index or name is None:
552-
name = current_column
553-
current_column += 1
586+
current_column = 0
587+
max_ndim = sample.ndim
588+
for obj in objs:
589+
ndim = obj.ndim
590+
if ndim == max_ndim:
591+
pass
554592

555-
# doing a row-wise concatenation so need everything
556-
# to line up
557-
if self._is_frame and axis == 1:
558-
name = 0
559-
# mypy needs to know sample is not an NDFrame
560-
sample = cast("DataFrame | Series", sample)
561-
obj = sample._constructor({name: obj}, copy=False)
593+
elif ndim != max_ndim - 1:
594+
raise ValueError(
595+
"cannot concatenate unaligned mixed dimensional NDFrame objects"
596+
)
562597

563-
self.objs.append(obj)
598+
else:
599+
name = getattr(obj, "name", None)
600+
if ignore_index or name is None:
601+
name = current_column
602+
current_column += 1
564603

565-
# note: this is the BlockManager axis (since DataFrame is transposed)
566-
self.bm_axis = axis
567-
self.axis = 1 - self.bm_axis if self._is_frame else 0
568-
self.keys = keys
569-
self.names = names or getattr(keys, "names", None)
570-
self.levels = levels
604+
# doing a row-wise concatenation so need everything
605+
# to line up
606+
if self._is_frame and axis == 1:
607+
name = 0
571608

572-
if not is_bool(sort):
573-
raise ValueError(
574-
f"The 'sort' keyword only accepts boolean values; {sort} was passed."
575-
)
576-
# Incompatible types in assignment (expression has type "Union[bool, bool_]",
577-
# variable has type "bool")
578-
self.sort = sort # type: ignore[assignment]
609+
obj = sample._constructor({name: obj}, copy=False)
579610

580-
self.ignore_index = ignore_index
581-
self.verify_integrity = verify_integrity
582-
self.copy = copy
611+
new_objs.append(obj)
583612

584-
self.new_axes = self._get_new_axes()
613+
return new_objs, sample
585614

586615
def get_result(self):
587616
cons: Callable[..., DataFrame | Series]
@@ -599,7 +628,16 @@ def get_result(self):
599628
arrs = [ser._values for ser in self.objs]
600629

601630
res = concat_compat(arrs, axis=0)
602-
mgr = type(sample._mgr).from_array(res, index=self.new_axes[0])
631+
632+
new_index: Index
633+
if self.ignore_index:
634+
# We can avoid surprisingly-expensive _get_concat_axis
635+
new_index = default_index(len(res))
636+
else:
637+
new_index = self.new_axes[0]
638+
639+
mgr = type(sample._mgr).from_array(res, index=new_index)
640+
603641
result = cons(mgr, name=name, fastpath=True)
604642
return result.__finalize__(self, method="concat")
605643

@@ -650,7 +688,8 @@ def _get_result_dim(self) -> int:
650688
else:
651689
return self.objs[0].ndim
652690

653-
def _get_new_axes(self) -> list[Index]:
691+
@cache_readonly
692+
def new_axes(self) -> list[Index]:
654693
ndim = self._get_result_dim()
655694
return [
656695
self._get_concat_axis if i == self.bm_axis else self._get_comb_axis(i)

0 commit comments

Comments
 (0)