Skip to content

Commit 5e02681

Browse files
authored
REF: simplify _get_empty_dtype_and_na (#39453)
1 parent 8f14836 commit 5e02681

File tree

1 file changed

+45
-32
lines changed

1 file changed

+45
-32
lines changed

pandas/core/internals/concat.py

+45 -32
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,11 @@
33
from collections import defaultdict
44
import copy
55
import itertools
6-
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, cast
6+
from typing import TYPE_CHECKING, Dict, List, Sequence, cast
77

88
import numpy as np
99

10-
from pandas._libs import NaT, internals as libinternals
10+
from pandas._libs import internals as libinternals
1111
from pandas._typing import ArrayLike, DtypeObj, Manager, Shape
1212
from pandas.util._decorators import cache_readonly
1313

@@ -338,7 +338,10 @@ def _concatenate_join_units(
338338
# Concatenating join units along ax0 is handled in _merge_blocks.
339339
raise AssertionError("Concatenating join units along axis0")
340340

341-
empty_dtype, upcasted_na = _get_empty_dtype_and_na(join_units)
341+
empty_dtype = _get_empty_dtype(join_units)
342+
343+
has_none_blocks = any(unit.block is None for unit in join_units)
344+
upcasted_na = _dtype_to_na_value(empty_dtype, has_none_blocks)
342345

343346
to_concat = [
344347
ju.get_reindexed_values(empty_dtype=empty_dtype, upcasted_na=upcasted_na)
@@ -375,7 +378,28 @@ def _concatenate_join_units(
375378
return concat_values
376379

377380

378-
def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, Any]:
381+
def _dtype_to_na_value(dtype: DtypeObj, has_none_blocks: bool):
382+
"""
383+
Find the NA value to go with this dtype.
384+
"""
385+
if is_extension_array_dtype(dtype):
386+
return dtype.na_value
387+
elif dtype.kind in ["m", "M"]:
388+
return dtype.type("NaT")
389+
elif dtype.kind in ["f", "c"]:
390+
return dtype.type("NaN")
391+
elif dtype.kind == "b":
392+
return None
393+
elif dtype.kind in ["i", "u"]:
394+
if not has_none_blocks:
395+
return None
396+
return np.nan
397+
elif dtype.kind == "O":
398+
return np.nan
399+
raise NotImplementedError
400+
401+
402+
def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj:
379403
"""
380404
Return dtype and N/A values to use when concatenating specified units.
381405
@@ -384,30 +408,19 @@ def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, A
384408
Returns
385409
-------
386410
dtype
387-
na
388411
"""
389412
if len(join_units) == 1:
390413
blk = join_units[0].block
391414
if blk is None:
392-
return np.dtype(np.float64), np.nan
415+
return np.dtype(np.float64)
393416

394417
if _is_uniform_reindex(join_units):
395418
# FIXME: integrate property
396419
empty_dtype = join_units[0].block.dtype
397-
if is_extension_array_dtype(empty_dtype):
398-
# for dt64tz we need this to get NaT instead of np.datetime64("NaT")
399-
upcasted_na = empty_dtype.na_value
400-
else:
401-
upcasted_na = join_units[0].block.fill_value
402-
return empty_dtype, upcasted_na
403-
404-
has_none_blocks = False
405-
dtypes = [None] * len(join_units)
406-
for i, unit in enumerate(join_units):
407-
if unit.block is None:
408-
has_none_blocks = True
409-
else:
410-
dtypes[i] = unit.dtype
420+
return empty_dtype
421+
422+
has_none_blocks = any(unit.block is None for unit in join_units)
423+
dtypes = [None if unit.block is None else unit.dtype for unit in join_units]
411424

412425
filtered_dtypes = [
413426
unit.dtype for unit in join_units if unit.block is not None and not unit.is_na
@@ -419,42 +432,42 @@ def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, A
419432
upcast_classes = _get_upcast_classes(join_units, dtypes)
420433

421434
if is_extension_array_dtype(dtype_alt):
422-
return dtype_alt, dtype_alt.na_value
435+
return dtype_alt
423436
elif dtype_alt == object:
424-
return dtype_alt, np.nan
437+
return dtype_alt
425438

426439
# TODO: de-duplicate with maybe_promote?
427440
# create the result
428441
if "extension" in upcast_classes:
429-
return np.dtype("object"), np.nan
442+
return np.dtype("object")
430443
elif "bool" in upcast_classes:
431444
if has_none_blocks:
432-
return np.dtype(np.object_), np.nan
445+
return np.dtype(np.object_)
433446
else:
434-
return np.dtype(np.bool_), None
447+
return np.dtype(np.bool_)
435448
elif "datetimetz" in upcast_classes:
436449
# GH-25014. We use NaT instead of iNaT, since this eventually
437450
# ends up in DatetimeArray.take, which does not allow iNaT.
438451
dtype = upcast_classes["datetimetz"]
439-
return dtype[0], NaT
452+
return dtype[0]
440453
elif "datetime" in upcast_classes:
441-
return np.dtype("M8[ns]"), np.datetime64("NaT", "ns")
454+
return np.dtype("M8[ns]")
442455
elif "timedelta" in upcast_classes:
443-
return np.dtype("m8[ns]"), np.timedelta64("NaT", "ns")
456+
return np.dtype("m8[ns]")
444457
else:
445458
try:
446459
common_dtype = np.find_common_type(upcast_classes, [])
447460
except TypeError:
448461
# At least one is an ExtensionArray
449-
return np.dtype(np.object_), np.nan
462+
return np.dtype(np.object_)
450463
else:
451464
if is_float_dtype(common_dtype):
452-
return common_dtype, common_dtype.type(np.nan)
465+
return common_dtype
453466
elif is_numeric_dtype(common_dtype):
454467
if has_none_blocks:
455-
return np.dtype(np.float64), np.nan
468+
return np.dtype(np.float64)
456469
else:
457-
return common_dtype, None
470+
return common_dtype
458471

459472
msg = "invalid dtype determination in get_concat_dtype"
460473
raise AssertionError(msg)

0 commit comments

Comments
 (0)