Skip to content

Commit faf6d3f

Browse files
authored
CLN: private funcs in concat.py (#36726)
* REF: extract func _select_upcast_cls_from_dtype
* REF: extract function _get_upcast_classes
* CLN: rename g -> common_dtype
* TYP: type extracted functions
* DOC: add docstrings to extracted methods
* TYP: cast instead of ignoring mypy error
* CLN: import SparseDtype only for type checking
1 parent 3cf09c9 commit faf6d3f

File tree

1 file changed

+62
-45
lines changed

1 file changed

+62
-45
lines changed

pandas/core/internals/concat.py

+62-45
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22
import copy
3-
from typing import Dict, List
3+
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, cast
44

55
import numpy as np
66

@@ -28,6 +28,9 @@
2828
from pandas.core.internals.blocks import make_block
2929
from pandas.core.internals.managers import BlockManager
3030

31+
if TYPE_CHECKING:
32+
from pandas.core.arrays.sparse.dtype import SparseDtype
33+
3134

3235
def concatenate_block_managers(
3336
mgrs_indexers, axes, concat_axis: int, copy: bool
@@ -344,7 +347,7 @@ def _concatenate_join_units(join_units, concat_axis, copy):
344347
return concat_values
345348

346349

347-
def _get_empty_dtype_and_na(join_units):
350+
def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, Any]:
348351
"""
349352
Return dtype and N/A values to use when concatenating specified units.
350353
@@ -374,45 +377,8 @@ def _get_empty_dtype_and_na(join_units):
374377
else:
375378
dtypes[i] = unit.dtype
376379

377-
upcast_classes: Dict[str, List[DtypeObj]] = defaultdict(list)
378-
null_upcast_classes: Dict[str, List[DtypeObj]] = defaultdict(list)
379-
for dtype, unit in zip(dtypes, join_units):
380-
if dtype is None:
381-
continue
382-
383-
if is_categorical_dtype(dtype):
384-
upcast_cls = "category"
385-
elif is_datetime64tz_dtype(dtype):
386-
upcast_cls = "datetimetz"
387-
388-
elif is_extension_array_dtype(dtype):
389-
upcast_cls = "extension"
390-
391-
elif issubclass(dtype.type, np.bool_):
392-
upcast_cls = "bool"
393-
elif issubclass(dtype.type, np.object_):
394-
upcast_cls = "object"
395-
elif is_datetime64_dtype(dtype):
396-
upcast_cls = "datetime"
397-
elif is_timedelta64_dtype(dtype):
398-
upcast_cls = "timedelta"
399-
elif is_sparse(dtype):
400-
upcast_cls = dtype.subtype.name
401-
elif is_float_dtype(dtype) or is_numeric_dtype(dtype):
402-
upcast_cls = dtype.name
403-
else:
404-
upcast_cls = "float"
380+
upcast_classes = _get_upcast_classes(join_units, dtypes)
405381

406-
# Null blocks should not influence upcast class selection, unless there
407-
# are only null blocks, when same upcasting rules must be applied to
408-
# null upcast classes.
409-
if unit.is_na:
410-
null_upcast_classes[upcast_cls].append(dtype)
411-
else:
412-
upcast_classes[upcast_cls].append(dtype)
413-
414-
if not upcast_classes:
415-
upcast_classes = null_upcast_classes
416382
# TODO: de-duplicate with maybe_promote?
417383
# create the result
418384
if "extension" in upcast_classes:
@@ -441,23 +407,74 @@ def _get_empty_dtype_and_na(join_units):
441407
return np.dtype("m8[ns]"), np.timedelta64("NaT", "ns")
442408
else: # pragma
443409
try:
444-
g = np.find_common_type(upcast_classes, [])
410+
common_dtype = np.find_common_type(upcast_classes, [])
445411
except TypeError:
446412
# At least one is an ExtensionArray
447413
return np.dtype(np.object_), np.nan
448414
else:
449-
if is_float_dtype(g):
450-
return g, g.type(np.nan)
451-
elif is_numeric_dtype(g):
415+
if is_float_dtype(common_dtype):
416+
return common_dtype, common_dtype.type(np.nan)
417+
elif is_numeric_dtype(common_dtype):
452418
if has_none_blocks:
453419
return np.dtype(np.float64), np.nan
454420
else:
455-
return g, None
421+
return common_dtype, None
456422

457423
msg = "invalid dtype determination in get_concat_dtype"
458424
raise AssertionError(msg)
459425

460426

427+
def _get_upcast_classes(
    join_units: Sequence[JoinUnit],
    dtypes: Sequence[DtypeObj],
) -> Dict[str, List[DtypeObj]]:
    """
    Group dtypes by the name of their upcast class.

    ``dtypes`` and ``join_units`` are walked in lockstep. Entries whose dtype
    is ``None`` are ignored. Each remaining dtype is filed under the class
    name chosen by ``_select_upcast_cls_from_dtype``.

    Dtypes belonging to units flagged ``is_na`` are collected separately and
    are only returned when no other unit contributed a dtype, so that null
    blocks do not influence the result unless they are all there is.
    """
    regular_groups: Dict[str, List[DtypeObj]] = defaultdict(list)
    null_groups: Dict[str, List[DtypeObj]] = defaultdict(list)

    for dtype, unit in zip(dtypes, join_units):
        if dtype is not None:
            cls_name = _select_upcast_cls_from_dtype(dtype)
            # Keep null-block dtypes apart; they are only a fallback.
            target = null_groups if unit.is_na else regular_groups
            target[cls_name].append(dtype)

    # Fall back to the null-only mapping when nothing else was collected.
    return regular_groups if regular_groups else null_groups
451+
452+
453+
def _select_upcast_cls_from_dtype(dtype: DtypeObj) -> str:
    """
    Return the upcast class name that corresponds to ``dtype``.

    The checks run in a fixed order and the first match wins: categorical,
    tz-aware datetime, any other extension dtype, bool, object, datetime64,
    timedelta64, sparse (named after its subtype), then float/numeric (named
    after the dtype itself). Anything left over is reported as ``"float"``.
    """
    # Pandas-level dtype families come first; order matters because the
    # generic extension check would otherwise shadow the two specific ones.
    leading_checks = (
        (is_categorical_dtype, "category"),
        (is_datetime64tz_dtype, "datetimetz"),
        (is_extension_array_dtype, "extension"),
    )
    for predicate, label in leading_checks:
        if predicate(dtype):
            return label

    if issubclass(dtype.type, np.bool_):
        return "bool"
    if issubclass(dtype.type, np.object_):
        return "object"

    if is_datetime64_dtype(dtype):
        return "datetime"
    if is_timedelta64_dtype(dtype):
        return "timedelta"

    if is_sparse(dtype):
        # The cast only informs the type checker; it has no runtime effect.
        sparse_dtype = cast("SparseDtype", dtype)
        return sparse_dtype.subtype.name

    if is_float_dtype(dtype) or is_numeric_dtype(dtype):
        return dtype.name

    return "float"
476+
477+
461478
def _is_uniform_join_units(join_units: List[JoinUnit]) -> bool:
462479
"""
463480
Check if the join units consist of blocks of uniform type that can

0 commit comments

Comments
 (0)