@@ -398,14 +398,30 @@ def maybe_promote(dtype, fill_value=np.nan):
398
398
dtype = np .dtype (np .float64 )
399
399
if not isna (fill_value ):
400
400
fill_value = dtype .type (fill_value )
401
+
402
+ elif dtype .kind == "f" :
403
+ if not np .can_cast (fill_value , dtype ):
404
+ # e.g. dtype is float32, need float64
405
+ dtype = np .min_scalar_type (fill_value )
406
+
407
+ elif dtype .kind == "c" :
408
+ if not np .can_cast (fill_value , dtype ):
409
+ if np .can_cast (fill_value , np .dtype ("c16" )):
410
+ dtype = np .dtype (np .complex128 )
411
+ else :
412
+ dtype = np .dtype (np .object_ )
413
+
414
+ if dtype .kind == "c" and not np .isnan (fill_value ):
415
+ fill_value = dtype .type (fill_value )
416
+
401
417
elif is_bool (fill_value ):
402
418
if not issubclass (dtype .type , np .bool_ ):
403
419
dtype = np .object_
404
420
else :
405
421
fill_value = np .bool_ (fill_value )
406
422
elif is_integer (fill_value ):
407
423
if issubclass (dtype .type , np .bool_ ):
408
- dtype = np .object_
424
+ dtype = np .dtype ( np . object_ )
409
425
elif issubclass (dtype .type , np .integer ):
410
426
# upcast to prevent overflow
411
427
arr = np .asarray (fill_value )
@@ -415,11 +431,37 @@ def maybe_promote(dtype, fill_value=np.nan):
415
431
# check if we can cast
416
432
if _check_lossless_cast (fill_value , dtype ):
417
433
fill_value = dtype .type (fill_value )
434
+
435
+ if dtype .kind in ["c" , "f" ]:
436
+ # e.g. if dtype is complex128 and fill_value is 1, we
437
+ # want np.complex128(1)
438
+ fill_value = dtype .type (fill_value )
439
+
418
440
elif is_complex (fill_value ):
419
441
if issubclass (dtype .type , np .bool_ ):
420
- dtype = np .object_
442
+ dtype = np .dtype ( np . object_ )
421
443
elif issubclass (dtype .type , (np .integer , np .floating )):
422
- dtype = np .complex128
444
+ c8 = np .dtype (np .complex64 )
445
+ info = np .finfo (dtype ) if dtype .kind == "f" else np .iinfo (dtype )
446
+ if (
447
+ np .can_cast (fill_value , c8 )
448
+ and np .can_cast (info .min , c8 )
449
+ and np .can_cast (info .max , c8 )
450
+ ):
451
+ dtype = np .dtype (np .complex64 )
452
+ else :
453
+ dtype = np .dtype (np .complex128 )
454
+
455
+ elif dtype .kind == "c" :
456
+ mst = np .min_scalar_type (fill_value )
457
+ if mst > dtype and mst .kind == "c" :
458
+ # e.g. mst is np.complex128 and dtype is np.complex64
459
+ dtype = mst
460
+
461
+ if dtype .kind == "c" :
462
+ # make sure we have a np.complex and not python complex
463
+ fill_value = dtype .type (fill_value )
464
+
423
465
elif fill_value is None :
424
466
if is_float_dtype (dtype ) or is_complex_dtype (dtype ):
425
467
fill_value = np .nan
0 commit comments