3
3
from collections import defaultdict
4
4
import copy
5
5
import itertools
6
- from typing import TYPE_CHECKING , Any , Dict , List , Sequence , Tuple , cast
6
+ from typing import TYPE_CHECKING , Dict , List , Sequence , cast
7
7
8
8
import numpy as np
9
9
10
- from pandas ._libs import NaT , internals as libinternals
10
+ from pandas ._libs import internals as libinternals
11
11
from pandas ._typing import ArrayLike , DtypeObj , Manager , Shape
12
12
from pandas .util ._decorators import cache_readonly
13
13
@@ -338,7 +338,10 @@ def _concatenate_join_units(
338
338
# Concatenating join units along ax0 is handled in _merge_blocks.
339
339
raise AssertionError ("Concatenating join units along axis0" )
340
340
341
- empty_dtype , upcasted_na = _get_empty_dtype_and_na (join_units )
341
+ empty_dtype = _get_empty_dtype (join_units )
342
+
343
+ has_none_blocks = any (unit .block is None for unit in join_units )
344
+ upcasted_na = _dtype_to_na_value (empty_dtype , has_none_blocks )
342
345
343
346
to_concat = [
344
347
ju .get_reindexed_values (empty_dtype = empty_dtype , upcasted_na = upcasted_na )
@@ -375,7 +378,28 @@ def _concatenate_join_units(
375
378
return concat_values
376
379
377
380
378
- def _get_empty_dtype_and_na (join_units : Sequence [JoinUnit ]) -> Tuple [DtypeObj , Any ]:
381
def _dtype_to_na_value(dtype: DtypeObj, has_none_blocks: bool):
    """
    Find the NA value to go with this dtype.

    Parameters
    ----------
    dtype : np.dtype or ExtensionDtype
        The dtype selected for the concatenated result.
    has_none_blocks : bool
        Whether any of the join units being concatenated has ``block is None``.
        Only consulted for integer dtypes.

    Returns
    -------
    object
        * ``dtype.na_value`` for extension dtypes.
        * ``dtype.type("NaT")`` for datetime64 / timedelta64 dtypes.
        * ``dtype.type("NaN")`` for float / complex dtypes.
        * ``np.nan`` for object dtype, and for integer dtypes when
          ``has_none_blocks`` is True.
        * ``None`` for bool dtype, and for integer dtypes when
          ``has_none_blocks`` is False.

    Raises
    ------
    NotImplementedError
        If ``dtype.kind`` is not one of the kinds handled above.
    """
    # Extension dtypes carry their own NA sentinel; check them first because
    # the remaining branches dispatch on the numpy ``kind`` code.
    if is_extension_array_dtype(dtype):
        return dtype.na_value

    kind = dtype.kind
    if kind in ("m", "M"):
        return dtype.type("NaT")
    if kind in ("f", "c"):
        return dtype.type("NaN")
    if kind == "b":
        return None
    if kind in ("i", "u"):
        # Integers only get a NaN fill value when a unit has no block at all.
        return np.nan if has_none_blocks else None
    if kind == "O":
        return np.nan

    # Name the offending dtype so the failure is diagnosable from the traceback.
    raise NotImplementedError(
        f"no NA value defined for dtype {dtype!r} (kind={kind!r})"
    )
400
+
401
+
402
+ def _get_empty_dtype (join_units : Sequence [JoinUnit ]) -> DtypeObj :
379
403
"""
380
404
Return dtype to use when concatenating specified units.
381
405
@@ -384,30 +408,19 @@ def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, A
384
408
Returns
385
409
-------
386
410
dtype
387
- na
388
411
"""
389
412
if len (join_units ) == 1 :
390
413
blk = join_units [0 ].block
391
414
if blk is None :
392
- return np .dtype (np .float64 ), np . nan
415
+ return np .dtype (np .float64 )
393
416
394
417
if _is_uniform_reindex (join_units ):
395
418
# FIXME: integrate property
396
419
empty_dtype = join_units [0 ].block .dtype
397
- if is_extension_array_dtype (empty_dtype ):
398
- # for dt64tz we need this to get NaT instead of np.datetime64("NaT")
399
- upcasted_na = empty_dtype .na_value
400
- else :
401
- upcasted_na = join_units [0 ].block .fill_value
402
- return empty_dtype , upcasted_na
403
-
404
- has_none_blocks = False
405
- dtypes = [None ] * len (join_units )
406
- for i , unit in enumerate (join_units ):
407
- if unit .block is None :
408
- has_none_blocks = True
409
- else :
410
- dtypes [i ] = unit .dtype
420
+ return empty_dtype
421
+
422
+ has_none_blocks = any (unit .block is None for unit in join_units )
423
+ dtypes = [None if unit .block is None else unit .dtype for unit in join_units ]
411
424
412
425
filtered_dtypes = [
413
426
unit .dtype for unit in join_units if unit .block is not None and not unit .is_na
@@ -419,42 +432,42 @@ def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, A
419
432
upcast_classes = _get_upcast_classes (join_units , dtypes )
420
433
421
434
if is_extension_array_dtype (dtype_alt ):
422
- return dtype_alt , dtype_alt . na_value
435
+ return dtype_alt
423
436
elif dtype_alt == object :
424
- return dtype_alt , np . nan
437
+ return dtype_alt
425
438
426
439
# TODO: de-duplicate with maybe_promote?
427
440
# create the result
428
441
if "extension" in upcast_classes :
429
- return np .dtype ("object" ), np . nan
442
+ return np .dtype ("object" )
430
443
elif "bool" in upcast_classes :
431
444
if has_none_blocks :
432
- return np .dtype (np .object_ ), np . nan
445
+ return np .dtype (np .object_ )
433
446
else :
434
- return np .dtype (np .bool_ ), None
447
+ return np .dtype (np .bool_ )
435
448
elif "datetimetz" in upcast_classes :
436
449
# GH-25014. We use NaT instead of iNaT, since this eventually
437
450
# ends up in DatetimeArray.take, which does not allow iNaT.
438
451
dtype = upcast_classes ["datetimetz" ]
439
- return dtype [0 ], NaT
452
+ return dtype [0 ]
440
453
elif "datetime" in upcast_classes :
441
- return np .dtype ("M8[ns]" ), np . datetime64 ( "NaT" , "ns" )
454
+ return np .dtype ("M8[ns]" )
442
455
elif "timedelta" in upcast_classes :
443
- return np .dtype ("m8[ns]" ), np . timedelta64 ( "NaT" , "ns" )
456
+ return np .dtype ("m8[ns]" )
444
457
else :
445
458
try :
446
459
common_dtype = np .find_common_type (upcast_classes , [])
447
460
except TypeError :
448
461
# At least one is an ExtensionArray
449
- return np .dtype (np .object_ ), np . nan
462
+ return np .dtype (np .object_ )
450
463
else :
451
464
if is_float_dtype (common_dtype ):
452
- return common_dtype , common_dtype . type ( np . nan )
465
+ return common_dtype
453
466
elif is_numeric_dtype (common_dtype ):
454
467
if has_none_blocks :
455
- return np .dtype (np .float64 ), np . nan
468
+ return np .dtype (np .float64 )
456
469
else :
457
- return common_dtype , None
470
+ return common_dtype
458
471
459
472
msg = "invalid dtype determination in get_concat_dtype"
460
473
raise AssertionError (msg )
0 commit comments