Skip to content

Commit 4ade26b

Browse files
jbrockmendeljreback
authored andcommitted
REF: maybe_promote refactor/cleanup (#28897)
1 parent fad037e commit 4ade26b

File tree

1 file changed

+33
-40
lines changed

1 file changed

+33
-40
lines changed

pandas/core/dtypes/cast.py

+33-40
Original file line numberDiff line numberDiff line change
@@ -393,32 +393,29 @@ def maybe_promote(dtype, fill_value=np.nan):
393393

394394
elif is_float(fill_value):
395395
if issubclass(dtype.type, np.bool_):
396-
dtype = np.object_
396+
dtype = np.dtype(np.object_)
397+
397398
elif issubclass(dtype.type, np.integer):
398399
dtype = np.dtype(np.float64)
399-
if not isna(fill_value):
400-
fill_value = dtype.type(fill_value)
401400

402401
elif dtype.kind == "f":
403-
if not np.can_cast(fill_value, dtype):
404-
# e.g. dtype is float32, need float64
405-
dtype = np.min_scalar_type(fill_value)
402+
mst = np.min_scalar_type(fill_value)
403+
if mst > dtype:
404+
# e.g. mst is np.float64 and dtype is np.float32
405+
dtype = mst
406406

407407
elif dtype.kind == "c":
408408
mst = np.min_scalar_type(fill_value)
409409
dtype = np.promote_types(dtype, mst)
410410

411-
if dtype.kind == "c" and not np.isnan(fill_value):
412-
fill_value = dtype.type(fill_value)
413-
414411
elif is_bool(fill_value):
415412
if not issubclass(dtype.type, np.bool_):
416-
dtype = np.object_
417-
else:
418-
fill_value = np.bool_(fill_value)
413+
dtype = np.dtype(np.object_)
414+
419415
elif is_integer(fill_value):
420416
if issubclass(dtype.type, np.bool_):
421417
dtype = np.dtype(np.object_)
418+
422419
elif issubclass(dtype.type, np.integer):
423420
if not np.can_cast(fill_value, dtype):
424421
# upcast to prevent overflow
@@ -428,35 +425,20 @@ def maybe_promote(dtype, fill_value=np.nan):
428425
# Case where we disagree with numpy
429426
dtype = np.dtype(np.object_)
430427

431-
fill_value = dtype.type(fill_value)
432-
433-
elif issubclass(dtype.type, np.floating):
434-
# check if we can cast
435-
if _check_lossless_cast(fill_value, dtype):
436-
fill_value = dtype.type(fill_value)
437-
438-
if dtype.kind in ["c", "f"]:
439-
# e.g. if dtype is complex128 and fill_value is 1, we
440-
# want np.complex128(1)
441-
fill_value = dtype.type(fill_value)
442-
443428
elif is_complex(fill_value):
444429
if issubclass(dtype.type, np.bool_):
445430
dtype = np.dtype(np.object_)
431+
446432
elif issubclass(dtype.type, (np.integer, np.floating)):
447433
mst = np.min_scalar_type(fill_value)
448434
dtype = np.promote_types(dtype, mst)
449435

450436
elif dtype.kind == "c":
451437
mst = np.min_scalar_type(fill_value)
452-
if mst > dtype and mst.kind == "c":
438+
if mst > dtype:
453439
# e.g. mst is np.complex128 and dtype is np.complex64
454440
dtype = mst
455441

456-
if dtype.kind == "c":
457-
# make sure we have a np.complex and not python complex
458-
fill_value = dtype.type(fill_value)
459-
460442
elif fill_value is None:
461443
if is_float_dtype(dtype) or is_complex_dtype(dtype):
462444
fill_value = np.nan
@@ -466,37 +448,48 @@ def maybe_promote(dtype, fill_value=np.nan):
466448
elif is_datetime_or_timedelta_dtype(dtype):
467449
fill_value = dtype.type("NaT", "ns")
468450
else:
469-
dtype = np.object_
451+
dtype = np.dtype(np.object_)
470452
fill_value = np.nan
471453
else:
472-
dtype = np.object_
454+
dtype = np.dtype(np.object_)
473455

474456
# in case we have a string that looked like a number
475457
if is_extension_array_dtype(dtype):
476458
pass
477459
elif issubclass(np.dtype(dtype).type, (bytes, str)):
478-
dtype = np.object_
460+
dtype = np.dtype(np.object_)
479461

462+
fill_value = _ensure_dtype_type(fill_value, dtype)
480463
return dtype, fill_value
481464

482465

483-
def _check_lossless_cast(value, dtype: np.dtype) -> bool:
466+
def _ensure_dtype_type(value, dtype):
484467
"""
485-
Check if we can cast the given value to the given dtype _losslesly_.
468+
Ensure that the given value is an instance of the given dtype.
469+
470+
e.g. if out dtype is np.complex64, we should have an instance of that
471+
as opposed to a python complex object.
486472
487473
Parameters
488474
----------
489475
value : object
490-
dtype : np.dtype
476+
dtype : np.dtype or ExtensionDtype
491477
492478
Returns
493479
-------
494-
bool
480+
object
495481
"""
496-
casted = dtype.type(value)
497-
if casted == value:
498-
return True
499-
return False
482+
483+
# Start with exceptions in which we do _not_ cast to numpy types
484+
if is_extension_array_dtype(dtype):
485+
return value
486+
elif dtype == np.object_:
487+
return value
488+
elif isna(value):
489+
# e.g. keep np.nan rather than try to cast to np.float32(np.nan)
490+
return value
491+
492+
return dtype.type(value)
500493

501494

502495
def infer_dtype_from(val, pandas_dtype=False):

0 commit comments

Comments
 (0)