|
1 | 1 | from collections import defaultdict
|
2 | 2 | import copy
|
3 |
| -from typing import Dict, List |
| 3 | +from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, cast |
4 | 4 |
|
5 | 5 | import numpy as np
|
6 | 6 |
|
|
28 | 28 | from pandas.core.internals.blocks import make_block
|
29 | 29 | from pandas.core.internals.managers import BlockManager
|
30 | 30 |
|
| 31 | +if TYPE_CHECKING: |
| 32 | + from pandas.core.arrays.sparse.dtype import SparseDtype |
| 33 | + |
31 | 34 |
|
32 | 35 | def concatenate_block_managers(
|
33 | 36 | mgrs_indexers, axes, concat_axis: int, copy: bool
|
@@ -344,7 +347,7 @@ def _concatenate_join_units(join_units, concat_axis, copy):
|
344 | 347 | return concat_values
|
345 | 348 |
|
346 | 349 |
|
347 |
| -def _get_empty_dtype_and_na(join_units): |
| 350 | +def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, Any]: |
348 | 351 | """
|
349 | 352 | Return dtype and N/A values to use when concatenating specified units.
|
350 | 353 |
|
@@ -374,45 +377,8 @@ def _get_empty_dtype_and_na(join_units):
|
374 | 377 | else:
|
375 | 378 | dtypes[i] = unit.dtype
|
376 | 379 |
|
377 |
| - upcast_classes: Dict[str, List[DtypeObj]] = defaultdict(list) |
378 |
| - null_upcast_classes: Dict[str, List[DtypeObj]] = defaultdict(list) |
379 |
| - for dtype, unit in zip(dtypes, join_units): |
380 |
| - if dtype is None: |
381 |
| - continue |
382 |
| - |
383 |
| - if is_categorical_dtype(dtype): |
384 |
| - upcast_cls = "category" |
385 |
| - elif is_datetime64tz_dtype(dtype): |
386 |
| - upcast_cls = "datetimetz" |
387 |
| - |
388 |
| - elif is_extension_array_dtype(dtype): |
389 |
| - upcast_cls = "extension" |
390 |
| - |
391 |
| - elif issubclass(dtype.type, np.bool_): |
392 |
| - upcast_cls = "bool" |
393 |
| - elif issubclass(dtype.type, np.object_): |
394 |
| - upcast_cls = "object" |
395 |
| - elif is_datetime64_dtype(dtype): |
396 |
| - upcast_cls = "datetime" |
397 |
| - elif is_timedelta64_dtype(dtype): |
398 |
| - upcast_cls = "timedelta" |
399 |
| - elif is_sparse(dtype): |
400 |
| - upcast_cls = dtype.subtype.name |
401 |
| - elif is_float_dtype(dtype) or is_numeric_dtype(dtype): |
402 |
| - upcast_cls = dtype.name |
403 |
| - else: |
404 |
| - upcast_cls = "float" |
| 380 | + upcast_classes = _get_upcast_classes(join_units, dtypes) |
405 | 381 |
|
406 |
| - # Null blocks should not influence upcast class selection, unless there |
407 |
| - # are only null blocks, when same upcasting rules must be applied to |
408 |
| - # null upcast classes. |
409 |
| - if unit.is_na: |
410 |
| - null_upcast_classes[upcast_cls].append(dtype) |
411 |
| - else: |
412 |
| - upcast_classes[upcast_cls].append(dtype) |
413 |
| - |
414 |
| - if not upcast_classes: |
415 |
| - upcast_classes = null_upcast_classes |
416 | 382 | # TODO: de-duplicate with maybe_promote?
|
417 | 383 | # create the result
|
418 | 384 | if "extension" in upcast_classes:
|
@@ -441,23 +407,74 @@ def _get_empty_dtype_and_na(join_units):
|
441 | 407 | return np.dtype("m8[ns]"), np.timedelta64("NaT", "ns")
|
442 | 408 | else: # pragma
|
443 | 409 | try:
|
444 |
| - g = np.find_common_type(upcast_classes, []) |
| 410 | + common_dtype = np.find_common_type(upcast_classes, []) |
445 | 411 | except TypeError:
|
446 | 412 | # At least one is an ExtensionArray
|
447 | 413 | return np.dtype(np.object_), np.nan
|
448 | 414 | else:
|
449 |
| - if is_float_dtype(g): |
450 |
| - return g, g.type(np.nan) |
451 |
| - elif is_numeric_dtype(g): |
| 415 | + if is_float_dtype(common_dtype): |
| 416 | + return common_dtype, common_dtype.type(np.nan) |
| 417 | + elif is_numeric_dtype(common_dtype): |
452 | 418 | if has_none_blocks:
|
453 | 419 | return np.dtype(np.float64), np.nan
|
454 | 420 | else:
|
455 |
| - return g, None |
| 421 | + return common_dtype, None |
456 | 422 |
|
457 | 423 | msg = "invalid dtype determination in get_concat_dtype"
|
458 | 424 | raise AssertionError(msg)
|
459 | 425 |
|
460 | 426 |
|
| 427 | +def _get_upcast_classes( |
| 428 | + join_units: Sequence[JoinUnit], |
| 429 | + dtypes: Sequence[DtypeObj], |
| 430 | +) -> Dict[str, List[DtypeObj]]: |
| 431 | + """Create mapping between upcast class names and lists of dtypes.""" |
| 432 | + upcast_classes: Dict[str, List[DtypeObj]] = defaultdict(list) |
| 433 | + null_upcast_classes: Dict[str, List[DtypeObj]] = defaultdict(list) |
| 434 | + for dtype, unit in zip(dtypes, join_units): |
| 435 | + if dtype is None: |
| 436 | + continue |
| 437 | + |
| 438 | + upcast_cls = _select_upcast_cls_from_dtype(dtype) |
| 439 | + # Null blocks should not influence upcast class selection, unless there |
| 440 | + # are only null blocks, when same upcasting rules must be applied to |
| 441 | + # null upcast classes. |
| 442 | + if unit.is_na: |
| 443 | + null_upcast_classes[upcast_cls].append(dtype) |
| 444 | + else: |
| 445 | + upcast_classes[upcast_cls].append(dtype) |
| 446 | + |
| 447 | + if not upcast_classes: |
| 448 | + upcast_classes = null_upcast_classes |
| 449 | + |
| 450 | + return upcast_classes |
| 451 | + |
| 452 | + |
| 453 | +def _select_upcast_cls_from_dtype(dtype: DtypeObj) -> str: |
| 454 | + """Select upcast class name based on dtype.""" |
| 455 | + if is_categorical_dtype(dtype): |
| 456 | + return "category" |
| 457 | + elif is_datetime64tz_dtype(dtype): |
| 458 | + return "datetimetz" |
| 459 | + elif is_extension_array_dtype(dtype): |
| 460 | + return "extension" |
| 461 | + elif issubclass(dtype.type, np.bool_): |
| 462 | + return "bool" |
| 463 | + elif issubclass(dtype.type, np.object_): |
| 464 | + return "object" |
| 465 | + elif is_datetime64_dtype(dtype): |
| 466 | + return "datetime" |
| 467 | + elif is_timedelta64_dtype(dtype): |
| 468 | + return "timedelta" |
| 469 | + elif is_sparse(dtype): |
| 470 | + dtype = cast("SparseDtype", dtype) |
| 471 | + return dtype.subtype.name |
| 472 | + elif is_float_dtype(dtype) or is_numeric_dtype(dtype): |
| 473 | + return dtype.name |
| 474 | + else: |
| 475 | + return "float" |
| 476 | + |
| 477 | + |
461 | 478 | def _is_uniform_join_units(join_units: List[JoinUnit]) -> bool:
|
462 | 479 | """
|
463 | 480 | Check if the join units consist of blocks of uniform type that can
|
|
0 commit comments