1
1
from __future__ import annotations
2
2
3
- from collections import defaultdict
4
3
import copy
5
4
import itertools
6
- from typing import TYPE_CHECKING , Dict , List , Sequence , cast
5
+ from typing import TYPE_CHECKING , Dict , List , Sequence
7
6
8
7
import numpy as np
9
8
14
13
from pandas .core .dtypes .cast import ensure_dtype_can_hold_na , find_common_type
15
14
from pandas .core .dtypes .common import (
16
15
is_categorical_dtype ,
17
- is_datetime64_dtype ,
18
16
is_datetime64tz_dtype ,
17
+ is_dtype_equal ,
19
18
is_extension_array_dtype ,
20
- is_float_dtype ,
21
- is_numeric_dtype ,
22
19
is_sparse ,
23
- is_timedelta64_dtype ,
24
20
)
25
21
from pandas .core .dtypes .concat import concat_compat
26
- from pandas .core .dtypes .missing import isna_all
22
+ from pandas .core .dtypes .missing import is_valid_na_for_dtype , isna_all
27
23
28
24
import pandas .core .algorithms as algos
29
25
from pandas .core .arrays import DatetimeArray , ExtensionArray
33
29
34
30
if TYPE_CHECKING :
35
31
from pandas import Index
36
- from pandas .core .arrays .sparse .dtype import SparseDtype
37
32
38
33
39
34
def concatenate_block_managers (
@@ -232,6 +227,29 @@ def dtype(self):
232
227
return blk .dtype
233
228
return ensure_dtype_can_hold_na (blk .dtype )
234
229
230
def is_valid_na_for(self, dtype: DtypeObj) -> bool:
    """
    Report whether this unit is all-NA with NA values that suit ``dtype``.

    Extends ``self.is_na`` by also checking which kind of NA is stored.

    Parameters
    ----------
    dtype : DtypeObj
        The dtype the NA values are being checked against.

    Returns
    -------
    bool
        False when ``self.is_na`` is falsy; True when there is no backing
        block; otherwise the outcome of the ``is_valid_na_for_dtype``
        checks performed below.
    """
    if not self.is_na:
        return False

    blk = self.block
    if blk is None:
        return True

    own_dtype = self.dtype
    if own_dtype == object:
        # object-dtype values: each element is checked individually
        for element in blk.values.ravel(order="K"):
            if not is_valid_na_for_dtype(element, dtype):
                return False
        return True

    both_kind_m = own_dtype.kind == dtype.kind == "M"
    if both_kind_m and not is_dtype_equal(own_dtype, dtype):
        # fill_values match but we should not cast self.block.values to dtype
        return False

    return is_valid_na_for_dtype(blk.fill_value, dtype)
252
+
235
253
@cache_readonly
236
254
def is_na (self ) -> bool :
237
255
if self .block is None :
@@ -262,7 +280,7 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
262
280
else :
263
281
fill_value = upcasted_na
264
282
265
- if self .is_na :
283
+ if self .is_valid_na_for ( empty_dtype ) :
266
284
blk_dtype = getattr (self .block , "dtype" , None )
267
285
268
286
if blk_dtype == np .dtype (object ):
@@ -276,10 +294,9 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
276
294
if is_datetime64tz_dtype (blk_dtype ) or is_datetime64tz_dtype (
277
295
empty_dtype
278
296
):
279
- if self .block is None :
280
- # TODO(EA2D): special case unneeded with 2D EAs
281
- i8values = np .full (self .shape [1 ], fill_value .value )
282
- return DatetimeArray (i8values , dtype = empty_dtype )
297
+ # TODO(EA2D): special case unneeded with 2D EAs
298
+ i8values = np .full (self .shape [1 ], fill_value .value )
299
+ return DatetimeArray (i8values , dtype = empty_dtype )
283
300
elif is_categorical_dtype (blk_dtype ):
284
301
pass
285
302
elif is_extension_array_dtype (blk_dtype ):
@@ -295,6 +312,8 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
295
312
empty_arr , allow_fill = True , fill_value = fill_value
296
313
)
297
314
else :
315
+ # NB: we should never get here with empty_dtype integer or bool;
316
+ # if we did, the missing_arr.fill would cast to gibberish
298
317
missing_arr = np .empty (self .shape , dtype = empty_dtype )
299
318
missing_arr .fill (fill_value )
300
319
return missing_arr
@@ -362,14 +381,12 @@ def _concatenate_join_units(
362
381
# concatting with at least one EA means we are concatting a single column
363
382
# the non-EA values are 2D arrays with shape (1, n)
364
383
to_concat = [t if isinstance (t , ExtensionArray ) else t [0 , :] for t in to_concat ]
365
- concat_values = concat_compat (to_concat , axis = 0 )
366
- if not isinstance (concat_values , ExtensionArray ) or (
367
- isinstance (concat_values , DatetimeArray ) and concat_values .tz is None
368
- ):
384
+ concat_values = concat_compat (to_concat , axis = 0 , ea_compat_axis = True )
385
+ if not is_extension_array_dtype (concat_values .dtype ):
369
386
# if the result of concat is not an EA but an ndarray, reshape to
370
387
# 2D to put it a non-EA Block
371
- # special case DatetimeArray, which *is* an EA, but is put in a
372
- # consolidated 2D block
388
+ # special case DatetimeArray/TimedeltaArray , which *is* an EA, but
389
+ # is put in a consolidated 2D block
373
390
concat_values = np .atleast_2d (concat_values )
374
391
else :
375
392
concat_values = concat_compat (to_concat , axis = concat_axis )
@@ -419,108 +436,17 @@ def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj:
419
436
return empty_dtype
420
437
421
438
has_none_blocks = any (unit .block is None for unit in join_units )
422
- dtypes = [None if unit .block is None else unit .dtype for unit in join_units ]
423
439
424
- filtered_dtypes = [
440
+ dtypes = [
425
441
unit .dtype for unit in join_units if unit .block is not None and not unit .is_na
426
442
]
427
- if not len (filtered_dtypes ):
428
- filtered_dtypes = [unit .dtype for unit in join_units if unit .block is not None ]
429
- dtype_alt = find_common_type (filtered_dtypes )
430
-
431
- upcast_classes = _get_upcast_classes (join_units , dtypes )
432
-
433
- if is_extension_array_dtype (dtype_alt ):
434
- return dtype_alt
435
- elif dtype_alt == object :
436
- return dtype_alt
437
-
438
- # TODO: de-duplicate with maybe_promote?
439
- # create the result
440
- if "extension" in upcast_classes :
441
- return np .dtype ("object" )
442
- elif "bool" in upcast_classes :
443
- if has_none_blocks :
444
- return np .dtype (np .object_ )
445
- else :
446
- return np .dtype (np .bool_ )
447
- elif "datetimetz" in upcast_classes :
448
- # GH-25014. We use NaT instead of iNaT, since this eventually
449
- # ends up in DatetimeArray.take, which does not allow iNaT.
450
- dtype = upcast_classes ["datetimetz" ]
451
- return dtype [0 ]
452
- elif "datetime" in upcast_classes :
453
- return np .dtype ("M8[ns]" )
454
- elif "timedelta" in upcast_classes :
455
- return np .dtype ("m8[ns]" )
456
- else :
457
- try :
458
- common_dtype = np .find_common_type (upcast_classes , [])
459
- except TypeError :
460
- # At least one is an ExtensionArray
461
- return np .dtype (np .object_ )
462
- else :
463
- if is_float_dtype (common_dtype ):
464
- return common_dtype
465
- elif is_numeric_dtype (common_dtype ):
466
- if has_none_blocks :
467
- return np .dtype (np .float64 )
468
- else :
469
- return common_dtype
470
-
471
- msg = "invalid dtype determination in get_concat_dtype"
472
- raise AssertionError (msg )
473
-
474
-
475
def _get_upcast_classes(
    join_units: Sequence[JoinUnit],
    dtypes: Sequence[DtypeObj],
) -> Dict[str, List[DtypeObj]]:
    """
    Group dtypes by upcast class name.

    ``dtypes`` is walked in lockstep with ``join_units``; entries whose
    dtype is ``None`` are ignored. Dtypes belonging to units flagged
    ``is_na`` are collected separately and are only returned when no
    non-null unit contributed anything, so null blocks do not influence
    the upcast class selection unless there are only null blocks.
    """
    from_values: Dict[str, List[DtypeObj]] = defaultdict(list)
    from_nulls: Dict[str, List[DtypeObj]] = defaultdict(list)

    for unit, dtype in zip(join_units, dtypes):
        if dtype is not None:
            cls_name = _select_upcast_cls_from_dtype(dtype)
            bucket = from_nulls if unit.is_na else from_values
            bucket[cls_name].append(dtype)

    # Fall back to the null-only mapping when nothing else was collected.
    return from_values if from_values else from_nulls
499
-
500
-
501
def _select_upcast_cls_from_dtype(dtype: DtypeObj) -> str:
    """
    Map a dtype to the name of its upcast class.

    The checks run in a fixed order and the first match wins. The result
    is one of the fixed labels ``"extension"``, ``"datetimetz"``,
    ``"bool"``, ``"object"``, ``"datetime"``, ``"timedelta"``, a dtype
    name (``dtype.subtype.name`` in the sparse branch, ``dtype.name`` in
    the float/numeric branch), or ``"float"`` as the final fallback.
    """
    if is_categorical_dtype(dtype):
        return "extension"
    if is_datetime64tz_dtype(dtype):
        return "datetimetz"
    if is_extension_array_dtype(dtype):
        return "extension"

    if issubclass(dtype.type, np.bool_):
        return "bool"
    if issubclass(dtype.type, np.object_):
        return "object"

    if is_datetime64_dtype(dtype):
        return "datetime"
    if is_timedelta64_dtype(dtype):
        return "timedelta"

    if is_sparse(dtype):
        # NOTE(review): this check sits after the generic extension-dtype
        # check above; confirm whether a sparse dtype can actually reach it.
        sparse_dtype = cast("SparseDtype", dtype)
        return sparse_dtype.subtype.name

    if is_float_dtype(dtype) or is_numeric_dtype(dtype):
        return dtype.name
    return "float"
443
+ if not len (dtypes ):
444
+ dtypes = [unit .dtype for unit in join_units if unit .block is not None ]
445
+
446
+ dtype = find_common_type (dtypes )
447
+ if has_none_blocks :
448
+ dtype = ensure_dtype_can_hold_na (dtype )
449
+ return dtype
524
450
525
451
526
452
def _is_uniform_join_units (join_units : List [JoinUnit ]) -> bool :
0 commit comments