diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py index dd005752a4832..07904339b93df 100644 --- a/pandas/core/dtypes/concat.py +++ b/pandas/core/dtypes/concat.py @@ -1,7 +1,7 @@ """ Utility functions related to concat. """ -from typing import cast +from typing import Set, cast import numpy as np @@ -9,15 +9,10 @@ from pandas.core.dtypes.cast import find_common_type from pandas.core.dtypes.common import ( - is_bool_dtype, is_categorical_dtype, - is_datetime64_dtype, - is_datetime64tz_dtype, is_dtype_equal, is_extension_array_dtype, - is_object_dtype, is_sparse, - is_timedelta64_dtype, ) from pandas.core.dtypes.generic import ABCCategoricalIndex, ABCRangeIndex, ABCSeries @@ -26,7 +21,7 @@ from pandas.core.construction import array -def get_dtype_kinds(l): +def _get_dtype_kinds(l) -> Set[str]: """ Parameters ---------- @@ -34,34 +29,30 @@ def get_dtype_kinds(l): Returns ------- - a set of kinds that exist in this list of arrays + set[str] + A set of kinds that exist in this list of arrays. """ - typs = set() + typs: Set[str] = set() for arr in l: + # Note: we use dtype.kind checks because they are much more performant + # than is_foo_dtype dtype = arr.dtype - if is_categorical_dtype(dtype): - typ = "category" - elif is_sparse(dtype): - typ = "sparse" + if not isinstance(dtype, np.dtype): + # ExtensionDtype so we get + # e.g. "categorical", "datetime64[ns, US/Central]", "Sparse[itn64, 0]" + typ = str(dtype) elif isinstance(arr, ABCRangeIndex): typ = "range" - elif is_datetime64tz_dtype(dtype): - # if to_concat contains different tz, - # the result must be object dtype - typ = str(dtype) - elif is_datetime64_dtype(dtype): + elif dtype.kind == "M": typ = "datetime" - elif is_timedelta64_dtype(dtype): + elif dtype.kind == "m": typ = "timedelta" - elif is_object_dtype(dtype): - typ = "object" - elif is_bool_dtype(dtype): - typ = "bool" - elif is_extension_array_dtype(dtype): - typ = str(dtype) + elif dtype.kind in ["O", "b"]: + typ = str(dtype) # i.e. "object", "bool" else: typ = dtype.kind + typs.add(typ) return typs @@ -140,7 +131,7 @@ def is_nonempty(x) -> bool: if non_empties and axis == 0: to_concat = non_empties - typs = get_dtype_kinds(to_concat) + typs = _get_dtype_kinds(to_concat) _contains_datetime = any(typ.startswith("datetime") for typ in typs) all_empty = not len(non_empties) @@ -161,13 +152,13 @@ def is_nonempty(x) -> bool: return np.concatenate(to_concat) elif _contains_datetime or "timedelta" in typs: - return concat_datetime(to_concat, axis=axis, typs=typs) + return _concat_datetime(to_concat, axis=axis, typs=typs) elif all_empty: # we have all empties, but may need to coerce the result dtype to # object if we have non-numeric type operands (numpy would otherwise # cast this to float) - typs = get_dtype_kinds(to_concat) + typs = _get_dtype_kinds(to_concat) if len(typs) != 1: if not len(typs - {"i", "u", "f"}) or not len(typs - {"bool", "i", "u"}): @@ -361,7 +352,7 @@ def _concatenate_2d(to_concat, axis: int): return np.concatenate(to_concat, axis=axis) -def concat_datetime(to_concat, axis=0, typs=None): +def _concat_datetime(to_concat, axis=0, typs=None): """ provide concatenation of an datetimelike array of arrays each of which is a single M8[ns], datetime64[ns, tz] or m8[ns] dtype @@ -377,7 +368,7 @@ def concat_datetime(to_concat, axis=0, typs=None): a single array, preserving the combined dtypes """ if typs is None: - typs = get_dtype_kinds(to_concat) + typs = _get_dtype_kinds(to_concat) to_concat = [_wrap_datetimelike(x) for x in to_concat] single_dtype = len({x.dtype for x in to_concat}) == 1 diff --git a/pandas/tests/dtypes/test_concat.py b/pandas/tests/dtypes/test_concat.py index 5a9ad732792ea..53d53e35c6eb5 100644 --- a/pandas/tests/dtypes/test_concat.py +++ b/pandas/tests/dtypes/test_concat.py @@ -44,7 +44,7 @@ ) def test_get_dtype_kinds(index_or_series, to_concat, expected): to_concat_klass = [index_or_series(c) for c in to_concat] - result = _concat.get_dtype_kinds(to_concat_klass) + result = _concat._get_dtype_kinds(to_concat_klass) assert result == set(expected) @@ -76,7 +76,7 @@ def test_get_dtype_kinds(index_or_series, to_concat, expected): ], ) def test_get_dtype_kinds_period(to_concat, expected): - result = _concat.get_dtype_kinds(to_concat) + result = _concat._get_dtype_kinds(to_concat) assert result == set(expected)