@@ -393,32 +393,29 @@ def maybe_promote(dtype, fill_value=np.nan):
393
393
394
394
elif is_float (fill_value ):
395
395
if issubclass (dtype .type , np .bool_ ):
396
- dtype = np .object_
396
+ dtype = np .dtype (np .object_ )
397
+
397
398
elif issubclass (dtype .type , np .integer ):
398
399
dtype = np .dtype (np .float64 )
399
- if not isna (fill_value ):
400
- fill_value = dtype .type (fill_value )
401
400
402
401
elif dtype .kind == "f" :
403
- if not np .can_cast (fill_value , dtype ):
404
- # e.g. dtype is float32, need float64
405
- dtype = np .min_scalar_type (fill_value )
402
+ mst = np .min_scalar_type (fill_value )
403
+ if mst > dtype :
404
+ # e.g. mst is np.float64 and dtype is np.float32
405
+ dtype = mst
406
406
407
407
elif dtype .kind == "c" :
408
408
mst = np .min_scalar_type (fill_value )
409
409
dtype = np .promote_types (dtype , mst )
410
410
411
- if dtype .kind == "c" and not np .isnan (fill_value ):
412
- fill_value = dtype .type (fill_value )
413
-
414
411
elif is_bool (fill_value ):
415
412
if not issubclass (dtype .type , np .bool_ ):
416
- dtype = np .object_
417
- else :
418
- fill_value = np .bool_ (fill_value )
413
+ dtype = np .dtype (np .object_ )
414
+
419
415
elif is_integer (fill_value ):
420
416
if issubclass (dtype .type , np .bool_ ):
421
417
dtype = np .dtype (np .object_ )
418
+
422
419
elif issubclass (dtype .type , np .integer ):
423
420
if not np .can_cast (fill_value , dtype ):
424
421
# upcast to prevent overflow
@@ -428,35 +425,20 @@ def maybe_promote(dtype, fill_value=np.nan):
428
425
# Case where we disagree with numpy
429
426
dtype = np .dtype (np .object_ )
430
427
431
- fill_value = dtype .type (fill_value )
432
-
433
- elif issubclass (dtype .type , np .floating ):
434
- # check if we can cast
435
- if _check_lossless_cast (fill_value , dtype ):
436
- fill_value = dtype .type (fill_value )
437
-
438
- if dtype .kind in ["c" , "f" ]:
439
- # e.g. if dtype is complex128 and fill_value is 1, we
440
- # want np.complex128(1)
441
- fill_value = dtype .type (fill_value )
442
-
443
428
elif is_complex (fill_value ):
444
429
if issubclass (dtype .type , np .bool_ ):
445
430
dtype = np .dtype (np .object_ )
431
+
446
432
elif issubclass (dtype .type , (np .integer , np .floating )):
447
433
mst = np .min_scalar_type (fill_value )
448
434
dtype = np .promote_types (dtype , mst )
449
435
450
436
elif dtype .kind == "c" :
451
437
mst = np .min_scalar_type (fill_value )
452
- if mst > dtype and mst . kind == "c" :
438
+ if mst > dtype :
453
439
# e.g. mst is np.complex128 and dtype is np.complex64
454
440
dtype = mst
455
441
456
- if dtype .kind == "c" :
457
- # make sure we have a np.complex and not python complex
458
- fill_value = dtype .type (fill_value )
459
-
460
442
elif fill_value is None :
461
443
if is_float_dtype (dtype ) or is_complex_dtype (dtype ):
462
444
fill_value = np .nan
@@ -466,37 +448,48 @@ def maybe_promote(dtype, fill_value=np.nan):
466
448
elif is_datetime_or_timedelta_dtype (dtype ):
467
449
fill_value = dtype .type ("NaT" , "ns" )
468
450
else :
469
- dtype = np .object_
451
+ dtype = np .dtype ( np . object_ )
470
452
fill_value = np .nan
471
453
else :
472
- dtype = np .object_
454
+ dtype = np .dtype ( np . object_ )
473
455
474
456
# in case we have a string that looked like a number
475
457
if is_extension_array_dtype (dtype ):
476
458
pass
477
459
elif issubclass (np .dtype (dtype ).type , (bytes , str )):
478
- dtype = np .object_
460
+ dtype = np .dtype ( np . object_ )
479
461
462
+ fill_value = _ensure_dtype_type (fill_value , dtype )
480
463
return dtype , fill_value
481
464
482
465
483
- def _check_lossless_cast (value , dtype : np . dtype ) -> bool :
466
+ def _ensure_dtype_type (value , dtype ) :
484
467
"""
485
- Check if we can cast the given value to the given dtype _losslesly_.
468
+ Ensure that the given value is an instance of the given dtype.
469
+
470
+ e.g. if our dtype is np.complex64, we should have an instance of that
471
+ as opposed to a python complex object.
486
472
487
473
Parameters
488
474
----------
489
475
value : object
490
- dtype : np.dtype
476
+ dtype : np.dtype or ExtensionDtype
491
477
492
478
Returns
493
479
-------
494
- bool
480
+ object
495
481
"""
496
- casted = dtype .type (value )
497
- if casted == value :
498
- return True
499
- return False
482
+
483
+ # Start with exceptions in which we do _not_ cast to numpy types
484
+ if is_extension_array_dtype (dtype ):
485
+ return value
486
+ elif dtype == np .object_ :
487
+ return value
488
+ elif isna (value ):
489
+ # e.g. keep np.nan rather than try to cast to np.float32(np.nan)
490
+ return value
491
+
492
+ return dtype .type (value )
500
493
501
494
502
495
def infer_dtype_from (val , pandas_dtype = False ):
0 commit comments