Skip to content

Commit 2effcd7

Browse files
committed
2 parents 863ac94 + a2f42ac commit 2effcd7

File tree

7 files changed

+113
-76
lines changed

7 files changed

+113
-76
lines changed

pandas/core/algorithms.py

+37-27
Original file line numberDiff line numberDiff line change
@@ -1661,6 +1661,40 @@ def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None)
16611661
return result
16621662

16631663

1664+
def _take_preprocess_indexer_and_fill_value(
1665+
arr, indexer, axis, out, fill_value, allow_fill
1666+
):
1667+
mask_info = None
1668+
1669+
if indexer is None:
1670+
indexer = np.arange(arr.shape[axis], dtype=np.int64)
1671+
dtype, fill_value = arr.dtype, arr.dtype.type()
1672+
else:
1673+
indexer = ensure_int64(indexer, copy=False)
1674+
if not allow_fill:
1675+
dtype, fill_value = arr.dtype, arr.dtype.type()
1676+
mask_info = None, False
1677+
else:
1678+
# check for promotion based on types only (do this first because
1679+
# it's faster than computing a mask)
1680+
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
1681+
if dtype != arr.dtype and (out is None or out.dtype != dtype):
1682+
# check if promotion is actually required based on indexer
1683+
mask = indexer == -1
1684+
needs_masking = mask.any()
1685+
mask_info = mask, needs_masking
1686+
if needs_masking:
1687+
if out is not None and out.dtype != dtype:
1688+
raise TypeError("Incompatible type for fill_value")
1689+
else:
1690+
# if not, then depromote, set fill_value to dummy
1691+
# (it won't be used but we don't want the cython code
1692+
# to crash when trying to cast it to dtype)
1693+
dtype, fill_value = arr.dtype, arr.dtype.type()
1694+
1695+
return indexer, dtype, fill_value, mask_info
1696+
1697+
16641698
def take_nd(
16651699
arr,
16661700
indexer,
@@ -1700,8 +1734,6 @@ def take_nd(
17001734
subarray : array-like
17011735
May be the same type as the input, or cast to an ndarray.
17021736
"""
1703-
mask_info = None
1704-
17051737
if fill_value is lib.no_default:
17061738
fill_value = na_value_for_dtype(arr.dtype, compat=False)
17071739

@@ -1712,31 +1744,9 @@ def take_nd(
17121744
arr = extract_array(arr)
17131745
arr = np.asarray(arr)
17141746

1715-
if indexer is None:
1716-
indexer = np.arange(arr.shape[axis], dtype=np.int64)
1717-
dtype, fill_value = arr.dtype, arr.dtype.type()
1718-
else:
1719-
indexer = ensure_int64(indexer, copy=False)
1720-
if not allow_fill:
1721-
dtype, fill_value = arr.dtype, arr.dtype.type()
1722-
mask_info = None, False
1723-
else:
1724-
# check for promotion based on types only (do this first because
1725-
# it's faster than computing a mask)
1726-
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
1727-
if dtype != arr.dtype and (out is None or out.dtype != dtype):
1728-
# check if promotion is actually required based on indexer
1729-
mask = indexer == -1
1730-
needs_masking = mask.any()
1731-
mask_info = mask, needs_masking
1732-
if needs_masking:
1733-
if out is not None and out.dtype != dtype:
1734-
raise TypeError("Incompatible type for fill_value")
1735-
else:
1736-
# if not, then depromote, set fill_value to dummy
1737-
# (it won't be used but we don't want the cython code
1738-
# to crash when trying to cast it to dtype)
1739-
dtype, fill_value = arr.dtype, arr.dtype.type()
1747+
indexer, dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value(
1748+
arr, indexer, axis, out, fill_value, allow_fill
1749+
)
17401750

17411751
flip_order = False
17421752
if arr.ndim == 2 and arr.flags.f_contiguous:

pandas/core/dtypes/cast.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,8 @@ def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray) -> np.ndarray:
475475
# upcast (possibly), otherwise we DON'T want to upcast (e.g. if we
476476
# have values, say integers, in the success portion then it's ok to not
477477
# upcast)
478-
new_dtype, _ = maybe_promote(result.dtype, np.nan)
478+
new_dtype = ensure_dtype_can_hold_na(result.dtype)
479+
479480
if new_dtype != result.dtype:
480481
result = result.astype(new_dtype, copy=True)
481482

@@ -484,7 +485,21 @@ def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray) -> np.ndarray:
484485
return result
485486

486487

487-
def maybe_promote(dtype, fill_value=np.nan):
488+
def ensure_dtype_can_hold_na(dtype: DtypeObj) -> DtypeObj:
    """
    Return a dtype able to hold NA values, upcasting ``dtype`` when required.

    Extension dtypes are returned unchanged. For other dtypes, kind ``"b"``
    maps to ``object`` and kinds ``"i"`` / ``"u"`` map to ``float64``; any
    other dtype is returned as-is.
    """
    # TODO: ExtensionDtype.can_hold_na?
    if not isinstance(dtype, ExtensionDtype):
        kind = dtype.kind
        if kind == "b":
            return np.dtype(object)
        if kind in ["i", "u"]:
            return np.dtype(np.float64)
    return dtype
500+
501+
502+
def maybe_promote(dtype: DtypeObj, fill_value=np.nan):
488503
"""
489504
Find the minimal dtype that can hold both the given dtype and fill_value.
490505
@@ -565,7 +580,7 @@ def maybe_promote(dtype, fill_value=np.nan):
565580
fill_value = np.timedelta64("NaT", "ns")
566581
else:
567582
fill_value = fv.to_timedelta64()
568-
elif is_datetime64tz_dtype(dtype):
583+
elif isinstance(dtype, DatetimeTZDtype):
569584
if isna(fill_value):
570585
fill_value = NaT
571586
elif not isinstance(fill_value, datetime):

pandas/core/frame.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4821,7 +4821,10 @@ def set_index(
48214821
names.extend(col.names)
48224822
elif isinstance(col, (Index, Series)):
48234823
# if Index then not MultiIndex (treated above)
4824-
arrays.append(col)
4824+
4825+
# error: Argument 1 to "append" of "list" has incompatible
4826+
# type "Union[Index, Series]"; expected "Index" [arg-type]
4827+
arrays.append(col) # type:ignore[arg-type]
48254828
names.append(col.name)
48264829
elif isinstance(col, (list, np.ndarray)):
48274830
arrays.append(col)

pandas/core/indexes/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5766,7 +5766,7 @@ def insert(self, loc: int, item):
57665766
idx = np.concatenate((arr[:loc], item, arr[loc:]))
57675767
return Index(idx, name=self.name)
57685768

5769-
def drop(self, labels, errors: str_t = "raise"):
5769+
def drop(self: _IndexT, labels, errors: str_t = "raise") -> _IndexT:
57705770
"""
57715771
Make new Index with passed list of labels deleted.
57725772

0 commit comments

Comments
 (0)