Skip to content

REF: simplify _get_empty_dtype_and_na #39453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 2, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 45 additions & 32 deletions pandas/core/internals/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from collections import defaultdict
import copy
import itertools
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, cast
from typing import TYPE_CHECKING, Dict, List, Sequence, cast

import numpy as np

from pandas._libs import NaT, internals as libinternals
from pandas._libs import internals as libinternals
from pandas._typing import ArrayLike, DtypeObj, Manager, Shape
from pandas.util._decorators import cache_readonly

Expand Down Expand Up @@ -338,7 +338,10 @@ def _concatenate_join_units(
# Concatenating join units along ax0 is handled in _merge_blocks.
raise AssertionError("Concatenating join units along axis0")

empty_dtype, upcasted_na = _get_empty_dtype_and_na(join_units)
empty_dtype = _get_empty_dtype(join_units)

has_none_blocks = any(unit.block is None for unit in join_units)
upcasted_na = _dtype_to_na_value(empty_dtype, has_none_blocks)

to_concat = [
ju.get_reindexed_values(empty_dtype=empty_dtype, upcasted_na=upcasted_na)
Expand Down Expand Up @@ -375,7 +378,28 @@ def _concatenate_join_units(
return concat_values


def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, Any]:
def _dtype_to_na_value(dtype: DtypeObj, has_none_blocks: bool):
"""
Find the NA value to go with this dtype.
"""
if is_extension_array_dtype(dtype):
return dtype.na_value
elif dtype.kind in ["m", "M"]:
return dtype.type("NaT")
elif dtype.kind in ["f", "c"]:
return dtype.type("NaN")
elif dtype.kind == "b":
return None
elif dtype.kind in ["i", "u"]:
if not has_none_blocks:
return None
return np.nan
elif dtype.kind == "O":
return np.nan
raise NotImplementedError


def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj:
"""
Return dtype and N/A values to use when concatenating specified units.

Expand All @@ -384,30 +408,19 @@ def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, A
Returns
-------
dtype
na
"""
if len(join_units) == 1:
blk = join_units[0].block
if blk is None:
return np.dtype(np.float64), np.nan
return np.dtype(np.float64)

if _is_uniform_reindex(join_units):
# FIXME: integrate property
empty_dtype = join_units[0].block.dtype
if is_extension_array_dtype(empty_dtype):
# for dt64tz we need this to get NaT instead of np.datetime64("NaT")
upcasted_na = empty_dtype.na_value
else:
upcasted_na = join_units[0].block.fill_value
return empty_dtype, upcasted_na

has_none_blocks = False
dtypes = [None] * len(join_units)
for i, unit in enumerate(join_units):
if unit.block is None:
has_none_blocks = True
else:
dtypes[i] = unit.dtype
return empty_dtype

has_none_blocks = any(unit.block is None for unit in join_units)
dtypes = [None if unit.block is None else unit.dtype for unit in join_units]

filtered_dtypes = [
unit.dtype for unit in join_units if unit.block is not None and not unit.is_na
Expand All @@ -419,42 +432,42 @@ def _get_empty_dtype_and_na(join_units: Sequence[JoinUnit]) -> Tuple[DtypeObj, A
upcast_classes = _get_upcast_classes(join_units, dtypes)

if is_extension_array_dtype(dtype_alt):
return dtype_alt, dtype_alt.na_value
return dtype_alt
elif dtype_alt == object:
return dtype_alt, np.nan
return dtype_alt

# TODO: de-duplicate with maybe_promote?
# create the result
if "extension" in upcast_classes:
return np.dtype("object"), np.nan
return np.dtype("object")
elif "bool" in upcast_classes:
if has_none_blocks:
return np.dtype(np.object_), np.nan
return np.dtype(np.object_)
else:
return np.dtype(np.bool_), None
return np.dtype(np.bool_)
elif "datetimetz" in upcast_classes:
# GH-25014. We use NaT instead of iNaT, since this eventually
# ends up in DatetimeArray.take, which does not allow iNaT.
dtype = upcast_classes["datetimetz"]
return dtype[0], NaT
return dtype[0]
elif "datetime" in upcast_classes:
return np.dtype("M8[ns]"), np.datetime64("NaT", "ns")
return np.dtype("M8[ns]")
elif "timedelta" in upcast_classes:
return np.dtype("m8[ns]"), np.timedelta64("NaT", "ns")
return np.dtype("m8[ns]")
else:
try:
common_dtype = np.find_common_type(upcast_classes, [])
except TypeError:
# At least one is an ExtensionArray
return np.dtype(np.object_), np.nan
return np.dtype(np.object_)
else:
if is_float_dtype(common_dtype):
return common_dtype, common_dtype.type(np.nan)
return common_dtype
elif is_numeric_dtype(common_dtype):
if has_none_blocks:
return np.dtype(np.float64), np.nan
return np.dtype(np.float64)
else:
return common_dtype, None
return common_dtype

msg = "invalid dtype determination in get_concat_dtype"
raise AssertionError(msg)
Expand Down