From 9396388cd35b07e6aea621414ae193fab36c6ae4 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 19 Mar 2024 12:40:52 -0700 Subject: [PATCH 1/3] REF: Clean up concat statefullness and validation --- pandas/core/reshape/concat.py | 233 ++++++++++++++++------------------ 1 file changed, 111 insertions(+), 122 deletions(-) diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index 1f0fe0542a0c0..4acbad4661c36 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -17,10 +17,7 @@ from pandas.util._decorators import cache_readonly -from pandas.core.dtypes.common import ( - is_bool, - is_iterator, -) +from pandas.core.dtypes.common import is_bool from pandas.core.dtypes.concat import concat_compat from pandas.core.dtypes.generic import ( ABCDataFrame, @@ -423,17 +420,16 @@ def __init__( self.ignore_index = ignore_index self.verify_integrity = verify_integrity - objs, keys = self._clean_keys_and_objs(objs, keys) + objs, keys, ndims = _clean_keys_and_objs(objs, keys) - # figure out what our result ndim is going to be - ndims = self._get_ndims(objs) - sample, objs = self._get_sample_object(objs, ndims, keys, names, levels) + # select an object to be our result reference + sample, objs = _get_sample_object( + objs, ndims, keys, names, levels, self.intersect + ) # Standardize axis parameter to int if sample.ndim == 1: - from pandas import DataFrame - - axis = DataFrame._get_axis_number(axis) + axis = sample._constructor_expanddim._get_axis_number(axis) self._is_frame = False self._is_series = True else: @@ -458,100 +454,6 @@ def __init__( self.names = names or getattr(keys, "names", None) self.levels = levels - def _get_ndims(self, objs: list[Series | DataFrame]) -> set[int]: - # figure out what our result ndim is going to be - ndims = set() - for obj in objs: - if not isinstance(obj, (ABCSeries, ABCDataFrame)): - msg = ( - f"cannot concatenate object of type '{type(obj)}'; " - "only Series and DataFrame objs are valid" - ) - raise TypeError(msg) - - ndims.add(obj.ndim) - return ndims - - def _clean_keys_and_objs( - self, - objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame], - keys, - ) -> tuple[list[Series | DataFrame], Index | None]: - if isinstance(objs, abc.Mapping): - if keys is None: - keys = list(objs.keys()) - objs_list = [objs[k] for k in keys] - else: - objs_list = list(objs) - - if len(objs_list) == 0: - raise ValueError("No objects to concatenate") - - if keys is None: - objs_list = list(com.not_none(*objs_list)) - else: - # GH#1649 - key_indices = [] - clean_objs = [] - if is_iterator(keys): - keys = list(keys) - if len(keys) != len(objs_list): - # GH#43485 - raise ValueError( - f"The length of the keys ({len(keys)}) must match " - f"the length of the objects to concatenate ({len(objs_list)})" - ) - for i, obj in enumerate(objs_list): - if obj is not None: - key_indices.append(i) - clean_objs.append(obj) - objs_list = clean_objs - - if not isinstance(keys, Index): - keys = Index(keys) - - if len(key_indices) < len(keys): - keys = keys.take(key_indices) - - if len(objs_list) == 0: - raise ValueError("All objects passed were None") - - return objs_list, keys - - def _get_sample_object( - self, - objs: list[Series | DataFrame], - ndims: set[int], - keys, - names, - levels, - ) -> tuple[Series | DataFrame, list[Series | DataFrame]]: - # get the sample - # want the highest ndim that we have, and must be non-empty - # unless all objs are empty - sample: Series | DataFrame | None = None - if len(ndims) > 1: - max_ndim = max(ndims) - for obj in objs: - if obj.ndim == max_ndim and np.sum(obj.shape): - sample = obj - break - - else: - # filter out the empties if we have not multi-index possibilities - # note to keep empty Series as it affect to result columns / name - non_empties = [obj for obj in objs if sum(obj.shape) > 0 or obj.ndim == 1] - - if len(non_empties) and ( - keys is None and names is None and levels is None and not self.intersect - ): - objs = non_empties - sample = objs[0] - - if sample is None: - sample = objs[0] - return sample, objs - def _sanitize_mixed_ndim( self, objs: list[Series | DataFrame], @@ -664,29 +566,24 @@ def get_result(self): out = sample._constructor_from_mgr(new_data, axes=new_data.axes) return out.__finalize__(self, method="concat") - def _get_result_dim(self) -> int: - if self._is_series and self.bm_axis == 1: - return 2 - else: - return self.objs[0].ndim - @cache_readonly def new_axes(self) -> list[Index]: - ndim = self._get_result_dim() + if self._is_series and self.bm_axis == 1: + ndim = 2 + else: + ndim = self.objs[0].ndim return [ - self._get_concat_axis if i == self.bm_axis else self._get_comb_axis(i) + self._get_concat_axis + if i == self.bm_axis + else get_objs_combined_axis( + self.objs, + axis=self.objs[0]._get_block_manager_axis(i), + intersect=self.intersect, + sort=self.sort, + ) for i in range(ndim) ] - def _get_comb_axis(self, i: AxisInt) -> Index: - data_axis = self.objs[0]._get_block_manager_axis(i) - return get_objs_combined_axis( - self.objs, - axis=data_axis, - intersect=self.intersect, - sort=self.sort, - ) - @cache_readonly def _get_concat_axis(self) -> Index: """ @@ -747,6 +644,98 @@ def _maybe_check_integrity(self, concat_index: Index) -> None: raise ValueError(f"Indexes have overlapping values: {overlap}") +def _clean_keys_and_objs( + objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame], + keys, +) -> tuple[list[Series | DataFrame], Index | None, set[int]]: + """ + Returns + ------- + clean_objs : list[Series | DataFrame] + LIst of DataFrame and Series with Nones removed. + keys : Index | None + None if keys was None + Index if objs was a Mapping or keys was not None. Filtered where objs was None. + ndim : set[int] + Unique .ndim attribute of obj encountered. + """ + if isinstance(objs, abc.Mapping): + if keys is None: + keys = objs.keys() + objs_list = [objs[k] for k in keys] + else: + objs_list = list(objs) + + if len(objs_list) == 0: + raise ValueError("No objects to concatenate") + + if keys is not None: + if not isinstance(keys, Index): + keys = Index(keys) + if len(keys) != len(objs_list): + # GH#43485 + raise ValueError( + f"The length of the keys ({len(keys)}) must match " + f"the length of the objects to concatenate ({len(objs_list)})" + ) + + # GH#1649 + key_indices = [] + clean_objs = [] + ndims = set() + for i, obj in enumerate(objs_list): + if obj is None: + continue + elif isinstance(obj, (ABCSeries, ABCDataFrame)): + key_indices.append(i) + clean_objs.append(obj) + ndims.add(obj.ndim) + else: + msg = ( + f"cannot concatenate object of type '{type(obj)}'; " + "only Series and DataFrame objs are valid" + ) + raise TypeError(msg) + + if keys is not None and len(key_indices) < len(keys): + keys = keys.take(key_indices) + + if len(clean_objs) == 0: + raise ValueError("All objects passed were None") + + return clean_objs, keys, ndims + + +def _get_sample_object( + objs: list[Series | DataFrame], + ndims: set[int], + keys, + names, + levels, + intersect: bool, +) -> tuple[Series | DataFrame, list[Series | DataFrame]]: + # get the sample + # want the highest ndim that we have, and must be non-empty + # unless all objs are empty + if len(ndims) > 1: + max_ndim = max(ndims) + for obj in objs: + if obj.ndim == max_ndim and sum(obj.shape): + return obj, objs + elif keys is None and names is None and levels is None and not intersect: + # filter out the empties if we have not multi-index possibilities + # note to keep empty Series as it affect to result columns / name + if ndims.pop() == 2: + non_empties = [obj for obj in objs if sum(obj.shape)] + else: + non_empties = objs + + if len(non_empties): + return non_empties[0], non_empties + + return objs[0], objs + + def _concat_indexes(indexes) -> Index: return indexes[0].append(indexes[1:]) From 3b0c9fbf64a36b0338a111de0cd71a3fb504582c Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 20 Mar 2024 10:56:47 -0700 Subject: [PATCH 2/3] Use DataFrame again --- pandas/core/reshape/concat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index 4acbad4661c36..147670ed788da 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -429,7 +429,9 @@ def __init__( # Standardize axis parameter to int if sample.ndim == 1: - axis = sample._constructor_expanddim._get_axis_number(axis) + from pandas import DataFrame + + axis = DataFrame._get_axis_number(axis) self._is_frame = False self._is_series = True else: From 33610e92f24c8800588092ef92d3c4a04a5d3b7a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 20 Mar 2024 11:58:01 -0700 Subject: [PATCH 3/3] Ignore false positive mypy? --- pandas/core/reshape/concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index 147670ed788da..35a08e0167924 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -722,7 +722,7 @@ def _get_sample_object( if len(ndims) > 1: max_ndim = max(ndims) for obj in objs: - if obj.ndim == max_ndim and sum(obj.shape): + if obj.ndim == max_ndim and sum(obj.shape): # type: ignore[arg-type] return obj, objs elif keys is None and names is None and levels is None and not intersect: # filter out the empties if we have not multi-index possibilities