diff --git a/pandas/core/internals/concat.py b/pandas/core/internals/concat.py index 214fe0c1de9f4..3dcfa85ed5c08 100644 --- a/pandas/core/internals/concat.py +++ b/pandas/core/internals/concat.py @@ -3,11 +3,11 @@ from collections import defaultdict import copy import itertools -from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Sequence, cast import numpy as np -from pandas._libs import NaT, internals as libinternals +from pandas._libs import internals as libinternals from pandas._typing import ArrayLike, DtypeObj, Manager, Shape from pandas.util._decorators import cache_readonly @@ -338,7 +338,10 @@ def _concatenate_join_units( # Concatenating join units along ax0 is handled in _merge_blocks. raise AssertionError("Concatenating join units along axis0") - empty_dtype, upcasted_na = _get_empty_dtype_and_na(join_units) + empty_dtype = _get_empty_dtype(join_units) + + has_none_blocks = any(unit.block is None for unit in join_units) + upcasted_na = _dtype_to_na_value(empty_dtype, has_none_blocks) to_concat = [ ju.get_reindexed_values(empty_dtype=empty_dtype, upcasted_na=upcasted_na) @@ -375,7 +378,28 @@ def _concatenate_join_units( return concat_values -def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, Any]: +def _dtype_to_na_value(dtype: DtypeObj, has_none_blocks: bool): + """ + Find the NA value to go with this dtype. + """ + if is_extension_array_dtype(dtype): + return dtype.na_value + elif dtype.kind in ["m", "M"]: + return dtype.type("NaT") + elif dtype.kind in ["f", "c"]: + return dtype.type("NaN") + elif dtype.kind == "b": + return None + elif dtype.kind in ["i", "u"]: + if not has_none_blocks: + return None + return np.nan + elif dtype.kind == "O": + return np.nan + raise NotImplementedError + + +def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj: """ Return dtype and N/A values to use when concatenating specified units. @@ -384,30 +408,19 @@ def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, A Returns ------- dtype - na """ if len(join_units) == 1: blk = join_units[0].block if blk is None: - return np.dtype(np.float64), np.nan + return np.dtype(np.float64) if _is_uniform_reindex(join_units): # FIXME: integrate property empty_dtype = join_units[0].block.dtype - if is_extension_array_dtype(empty_dtype): - # for dt64tz we need this to get NaT instead of np.datetime64("NaT") - upcasted_na = empty_dtype.na_value - else: - upcasted_na = join_units[0].block.fill_value - return empty_dtype, upcasted_na - - has_none_blocks = False - dtypes = [None] * len(join_units) - for i, unit in enumerate(join_units): - if unit.block is None: - has_none_blocks = True - else: - dtypes[i] = unit.dtype + return empty_dtype + + has_none_blocks = any(unit.block is None for unit in join_units) + dtypes = [None if unit.block is None else unit.dtype for unit in join_units] filtered_dtypes = [ unit.dtype for unit in join_units if unit.block is not None and not unit.is_na @@ -419,42 +432,42 @@ def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, A upcast_classes = _get_upcast_classes(join_units, dtypes) if is_extension_array_dtype(dtype_alt): - return dtype_alt, dtype_alt.na_value + return dtype_alt elif dtype_alt == object: - return dtype_alt, np.nan + return dtype_alt # TODO: de-duplicate with maybe_promote? # create the result if "extension" in upcast_classes: - return np.dtype("object"), np.nan + return np.dtype("object") elif "bool" in upcast_classes: if has_none_blocks: - return np.dtype(np.object_), np.nan + return np.dtype(np.object_) else: - return np.dtype(np.bool_), None + return np.dtype(np.bool_) elif "datetimetz" in upcast_classes: # GH-25014. We use NaT instead of iNaT, since this eventually # ends up in DatetimeArray.take, which does not allow iNaT. dtype = upcast_classes["datetimetz"] - return dtype[0], NaT + return dtype[0] elif "datetime" in upcast_classes: - return np.dtype("M8[ns]"), np.datetime64("NaT", "ns") + return np.dtype("M8[ns]") elif "timedelta" in upcast_classes: - return np.dtype("m8[ns]"), np.timedelta64("NaT", "ns") + return np.dtype("m8[ns]") else: try: common_dtype = np.find_common_type(upcast_classes, []) except TypeError: # At least one is an ExtensionArray - return np.dtype(np.object_), np.nan + return np.dtype(np.object_) else: if is_float_dtype(common_dtype): - return common_dtype, common_dtype.type(np.nan) + return common_dtype elif is_numeric_dtype(common_dtype): if has_none_blocks: - return np.dtype(np.float64), np.nan + return np.dtype(np.float64) else: - return common_dtype, None + return common_dtype msg = "invalid dtype determination in get_concat_dtype" raise AssertionError(msg)