Skip to content

Commit b62259c

Browse files
mroeschkepmhatre1
authored andcommitted
REF: Clean up concat statefullness and validation (pandas-dev#57933)
* REF: Clean up concat statefullness and validation * Use DataFrame again * Ignore false positive mypy?
1 parent 872c524 commit b62259c

File tree

1 file changed

+110
-119
lines changed

1 file changed

+110
-119
lines changed

pandas/core/reshape/concat.py

+110-119
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717

1818
from pandas.util._decorators import cache_readonly
1919

20-
from pandas.core.dtypes.common import (
21-
is_bool,
22-
is_iterator,
23-
)
20+
from pandas.core.dtypes.common import is_bool
2421
from pandas.core.dtypes.concat import concat_compat
2522
from pandas.core.dtypes.generic import (
2623
ABCDataFrame,
@@ -423,11 +420,12 @@ def __init__(
423420
self.ignore_index = ignore_index
424421
self.verify_integrity = verify_integrity
425422

426-
objs, keys = self._clean_keys_and_objs(objs, keys)
423+
objs, keys, ndims = _clean_keys_and_objs(objs, keys)
427424

428-
# figure out what our result ndim is going to be
429-
ndims = self._get_ndims(objs)
430-
sample, objs = self._get_sample_object(objs, ndims, keys, names, levels)
425+
# select an object to be our result reference
426+
sample, objs = _get_sample_object(
427+
objs, ndims, keys, names, levels, self.intersect
428+
)
431429

432430
# Standardize axis parameter to int
433431
if sample.ndim == 1:
@@ -458,100 +456,6 @@ def __init__(
458456
self.names = names or getattr(keys, "names", None)
459457
self.levels = levels
460458

461-
def _get_ndims(self, objs: list[Series | DataFrame]) -> set[int]:
462-
# figure out what our result ndim is going to be
463-
ndims = set()
464-
for obj in objs:
465-
if not isinstance(obj, (ABCSeries, ABCDataFrame)):
466-
msg = (
467-
f"cannot concatenate object of type '{type(obj)}'; "
468-
"only Series and DataFrame objs are valid"
469-
)
470-
raise TypeError(msg)
471-
472-
ndims.add(obj.ndim)
473-
return ndims
474-
475-
def _clean_keys_and_objs(
476-
self,
477-
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
478-
keys,
479-
) -> tuple[list[Series | DataFrame], Index | None]:
480-
if isinstance(objs, abc.Mapping):
481-
if keys is None:
482-
keys = list(objs.keys())
483-
objs_list = [objs[k] for k in keys]
484-
else:
485-
objs_list = list(objs)
486-
487-
if len(objs_list) == 0:
488-
raise ValueError("No objects to concatenate")
489-
490-
if keys is None:
491-
objs_list = list(com.not_none(*objs_list))
492-
else:
493-
# GH#1649
494-
key_indices = []
495-
clean_objs = []
496-
if is_iterator(keys):
497-
keys = list(keys)
498-
if len(keys) != len(objs_list):
499-
# GH#43485
500-
raise ValueError(
501-
f"The length of the keys ({len(keys)}) must match "
502-
f"the length of the objects to concatenate ({len(objs_list)})"
503-
)
504-
for i, obj in enumerate(objs_list):
505-
if obj is not None:
506-
key_indices.append(i)
507-
clean_objs.append(obj)
508-
objs_list = clean_objs
509-
510-
if not isinstance(keys, Index):
511-
keys = Index(keys)
512-
513-
if len(key_indices) < len(keys):
514-
keys = keys.take(key_indices)
515-
516-
if len(objs_list) == 0:
517-
raise ValueError("All objects passed were None")
518-
519-
return objs_list, keys
520-
521-
def _get_sample_object(
522-
self,
523-
objs: list[Series | DataFrame],
524-
ndims: set[int],
525-
keys,
526-
names,
527-
levels,
528-
) -> tuple[Series | DataFrame, list[Series | DataFrame]]:
529-
# get the sample
530-
# want the highest ndim that we have, and must be non-empty
531-
# unless all objs are empty
532-
sample: Series | DataFrame | None = None
533-
if len(ndims) > 1:
534-
max_ndim = max(ndims)
535-
for obj in objs:
536-
if obj.ndim == max_ndim and np.sum(obj.shape):
537-
sample = obj
538-
break
539-
540-
else:
541-
# filter out the empties if we have not multi-index possibilities
542-
# note to keep empty Series as it affect to result columns / name
543-
non_empties = [obj for obj in objs if sum(obj.shape) > 0 or obj.ndim == 1]
544-
545-
if len(non_empties) and (
546-
keys is None and names is None and levels is None and not self.intersect
547-
):
548-
objs = non_empties
549-
sample = objs[0]
550-
551-
if sample is None:
552-
sample = objs[0]
553-
return sample, objs
554-
555459
def _sanitize_mixed_ndim(
556460
self,
557461
objs: list[Series | DataFrame],
@@ -664,29 +568,24 @@ def get_result(self):
664568
out = sample._constructor_from_mgr(new_data, axes=new_data.axes)
665569
return out.__finalize__(self, method="concat")
666570

667-
def _get_result_dim(self) -> int:
668-
if self._is_series and self.bm_axis == 1:
669-
return 2
670-
else:
671-
return self.objs[0].ndim
672-
673571
@cache_readonly
674572
def new_axes(self) -> list[Index]:
675-
ndim = self._get_result_dim()
573+
if self._is_series and self.bm_axis == 1:
574+
ndim = 2
575+
else:
576+
ndim = self.objs[0].ndim
676577
return [
677-
self._get_concat_axis if i == self.bm_axis else self._get_comb_axis(i)
578+
self._get_concat_axis
579+
if i == self.bm_axis
580+
else get_objs_combined_axis(
581+
self.objs,
582+
axis=self.objs[0]._get_block_manager_axis(i),
583+
intersect=self.intersect,
584+
sort=self.sort,
585+
)
678586
for i in range(ndim)
679587
]
680588

681-
def _get_comb_axis(self, i: AxisInt) -> Index:
682-
data_axis = self.objs[0]._get_block_manager_axis(i)
683-
return get_objs_combined_axis(
684-
self.objs,
685-
axis=data_axis,
686-
intersect=self.intersect,
687-
sort=self.sort,
688-
)
689-
690589
@cache_readonly
691590
def _get_concat_axis(self) -> Index:
692591
"""
@@ -747,6 +646,98 @@ def _maybe_check_integrity(self, concat_index: Index) -> None:
747646
raise ValueError(f"Indexes have overlapping values: {overlap}")
748647

749648

649+
def _clean_keys_and_objs(
650+
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
651+
keys,
652+
) -> tuple[list[Series | DataFrame], Index | None, set[int]]:
653+
"""
654+
Returns
655+
-------
656+
clean_objs : list[Series | DataFrame]
657+
LIst of DataFrame and Series with Nones removed.
658+
keys : Index | None
659+
None if keys was None
660+
Index if objs was a Mapping or keys was not None. Filtered where objs was None.
661+
ndim : set[int]
662+
Unique .ndim attribute of obj encountered.
663+
"""
664+
if isinstance(objs, abc.Mapping):
665+
if keys is None:
666+
keys = objs.keys()
667+
objs_list = [objs[k] for k in keys]
668+
else:
669+
objs_list = list(objs)
670+
671+
if len(objs_list) == 0:
672+
raise ValueError("No objects to concatenate")
673+
674+
if keys is not None:
675+
if not isinstance(keys, Index):
676+
keys = Index(keys)
677+
if len(keys) != len(objs_list):
678+
# GH#43485
679+
raise ValueError(
680+
f"The length of the keys ({len(keys)}) must match "
681+
f"the length of the objects to concatenate ({len(objs_list)})"
682+
)
683+
684+
# GH#1649
685+
key_indices = []
686+
clean_objs = []
687+
ndims = set()
688+
for i, obj in enumerate(objs_list):
689+
if obj is None:
690+
continue
691+
elif isinstance(obj, (ABCSeries, ABCDataFrame)):
692+
key_indices.append(i)
693+
clean_objs.append(obj)
694+
ndims.add(obj.ndim)
695+
else:
696+
msg = (
697+
f"cannot concatenate object of type '{type(obj)}'; "
698+
"only Series and DataFrame objs are valid"
699+
)
700+
raise TypeError(msg)
701+
702+
if keys is not None and len(key_indices) < len(keys):
703+
keys = keys.take(key_indices)
704+
705+
if len(clean_objs) == 0:
706+
raise ValueError("All objects passed were None")
707+
708+
return clean_objs, keys, ndims
709+
710+
711+
def _get_sample_object(
712+
objs: list[Series | DataFrame],
713+
ndims: set[int],
714+
keys,
715+
names,
716+
levels,
717+
intersect: bool,
718+
) -> tuple[Series | DataFrame, list[Series | DataFrame]]:
719+
# get the sample
720+
# want the highest ndim that we have, and must be non-empty
721+
# unless all objs are empty
722+
if len(ndims) > 1:
723+
max_ndim = max(ndims)
724+
for obj in objs:
725+
if obj.ndim == max_ndim and sum(obj.shape): # type: ignore[arg-type]
726+
return obj, objs
727+
elif keys is None and names is None and levels is None and not intersect:
728+
# filter out the empties if we have not multi-index possibilities
729+
# note to keep empty Series as it affect to result columns / name
730+
if ndims.pop() == 2:
731+
non_empties = [obj for obj in objs if sum(obj.shape)]
732+
else:
733+
non_empties = objs
734+
735+
if len(non_empties):
736+
return non_empties[0], non_empties
737+
738+
return objs[0], objs
739+
740+
750741
def _concat_indexes(indexes) -> Index:
751742
return indexes[0].append(indexes[1:])
752743

0 commit comments

Comments
 (0)