Skip to content

Commit d93d3a5

Browse files
authored
CLN: maybe_promote doesnt need to support EA dtypes (#39760)
1 parent 40ac77c commit d93d3a5

File tree

3 files changed

+14
-67
lines changed

3 files changed

+14
-67
lines changed

pandas/core/algorithms.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1665,7 +1665,12 @@ def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None)
16651665

16661666

16671667
def _take_preprocess_indexer_and_fill_value(
1668-
arr, indexer, axis, out, fill_value, allow_fill
1668+
arr: np.ndarray,
1669+
indexer: Optional[np.ndarray],
1670+
axis: int,
1671+
out: Optional[np.ndarray],
1672+
fill_value,
1673+
allow_fill: bool,
16691674
):
16701675
mask_info = None
16711676

@@ -1786,7 +1791,9 @@ def take_nd(
17861791
return out
17871792

17881793

1789-
def take_2d_multi(arr, indexer, fill_value=np.nan):
1794+
def take_2d_multi(
1795+
arr: np.ndarray, indexer: np.ndarray, fill_value=np.nan
1796+
) -> np.ndarray:
17901797
"""
17911798
Specialized Cython take which sets NaN values in one pass.
17921799
"""

pandas/core/dtypes/cast.py

+4-20
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
conversion,
3636
iNaT,
3737
ints_to_pydatetime,
38-
tz_compare,
3938
)
4039
from pandas._typing import AnyArrayLike, ArrayLike, Dtype, DtypeObj, Scalar
4140
from pandas.util._exceptions import find_stack_level
@@ -499,13 +498,13 @@ def ensure_dtype_can_hold_na(dtype: DtypeObj) -> DtypeObj:
499498
return dtype
500499

501500

502-
def maybe_promote(dtype: DtypeObj, fill_value=np.nan):
501+
def maybe_promote(dtype: np.dtype, fill_value=np.nan):
503502
"""
504503
Find the minimal dtype that can hold both the given dtype and fill_value.
505504
506505
Parameters
507506
----------
508-
dtype : np.dtype or ExtensionDtype
507+
dtype : np.dtype
509508
fill_value : scalar, default np.nan
510509
511510
Returns
@@ -567,19 +566,6 @@ def maybe_promote(dtype: DtypeObj, fill_value=np.nan):
567566
fill_value = np.timedelta64("NaT", "ns")
568567
else:
569568
fill_value = fv.to_timedelta64()
570-
elif isinstance(dtype, DatetimeTZDtype):
571-
if isna(fill_value):
572-
fill_value = NaT
573-
elif not isinstance(fill_value, datetime):
574-
dtype = np.dtype(np.object_)
575-
elif fill_value.tzinfo is None:
576-
dtype = np.dtype(np.object_)
577-
elif not tz_compare(fill_value.tzinfo, dtype.tz):
578-
# TODO: sure we want to cast here?
579-
dtype = np.dtype(np.object_)
580-
581-
elif is_extension_array_dtype(dtype) and isna(fill_value):
582-
fill_value = dtype.na_value
583569

584570
elif is_float(fill_value):
585571
if issubclass(dtype.type, np.bool_):
@@ -634,7 +620,7 @@ def maybe_promote(dtype: DtypeObj, fill_value=np.nan):
634620
if is_float_dtype(dtype) or is_complex_dtype(dtype):
635621
fill_value = np.nan
636622
elif is_integer_dtype(dtype):
637-
dtype = np.float64
623+
dtype = np.dtype(np.float64)
638624
fill_value = np.nan
639625
else:
640626
dtype = np.dtype(np.object_)
@@ -644,9 +630,7 @@ def maybe_promote(dtype: DtypeObj, fill_value=np.nan):
644630
dtype = np.dtype(np.object_)
645631

646632
# in case we have a string that looked like a number
647-
if is_extension_array_dtype(dtype):
648-
pass
649-
elif issubclass(np.dtype(dtype).type, (bytes, str)):
633+
if issubclass(dtype.type, (bytes, str)):
650634
dtype = np.dtype(np.object_)
651635

652636
fill_value = _ensure_dtype_type(fill_value, dtype)

pandas/tests/dtypes/cast/test_promote.py

+1-45
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import pytest
99

10-
from pandas._libs.tslibs import NaT, tz_compare
10+
from pandas._libs.tslibs import NaT
1111

1212
from pandas.core.dtypes.cast import maybe_promote
1313
from pandas.core.dtypes.common import (
@@ -406,50 +406,6 @@ def test_maybe_promote_any_with_datetime64(
406406
_check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar)
407407

408408

409-
def test_maybe_promote_datetimetz_with_any_numpy_dtype(
410-
tz_aware_fixture, any_numpy_dtype_reduced
411-
):
412-
dtype = DatetimeTZDtype(tz=tz_aware_fixture)
413-
fill_dtype = np.dtype(any_numpy_dtype_reduced)
414-
415-
# create array of given dtype; casts "1" to correct dtype
416-
fill_value = np.array([1], dtype=fill_dtype)[0]
417-
418-
# filling datetimetz with any numpy dtype casts to object
419-
expected_dtype = np.dtype(object)
420-
exp_val_for_scalar = fill_value
421-
422-
_check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar)
423-
424-
425-
def test_maybe_promote_datetimetz_with_datetimetz(tz_aware_fixture, tz_aware_fixture2):
426-
dtype = DatetimeTZDtype(tz=tz_aware_fixture)
427-
fill_dtype = DatetimeTZDtype(tz=tz_aware_fixture2)
428-
429-
# create array of given dtype; casts "1" to correct dtype
430-
fill_value = pd.Series([10 ** 9], dtype=fill_dtype)[0]
431-
432-
# filling datetimetz with datetimetz casts to object, unless tz matches
433-
exp_val_for_scalar = fill_value
434-
if tz_compare(dtype.tz, fill_dtype.tz):
435-
expected_dtype = dtype
436-
else:
437-
expected_dtype = np.dtype(object)
438-
439-
_check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar)
440-
441-
442-
@pytest.mark.parametrize("fill_value", [None, np.nan, NaT])
443-
def test_maybe_promote_datetimetz_with_na(tz_aware_fixture, fill_value):
444-
445-
dtype = DatetimeTZDtype(tz=tz_aware_fixture)
446-
447-
expected_dtype = dtype
448-
exp_val_for_scalar = NaT
449-
450-
_check_promote(dtype, fill_value, expected_dtype, exp_val_for_scalar)
451-
452-
453409
@pytest.mark.parametrize(
454410
"fill_value",
455411
[

0 commit comments

Comments
 (0)