Skip to content

Commit 9c0f6a8

Browse files
h-vetinarijreback
authored andcommitted
CLN: (re-)enable infer_dtype to catch complex (pandas-dev#25382)
1 parent 5449279 commit 9c0f6a8

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

pandas/_libs/lib.pyx

+4
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,7 @@ _TYPE_MAP = {
939939
'float32': 'floating',
940940
'float64': 'floating',
941941
'f': 'floating',
942+
'complex64': 'complex',
942943
'complex128': 'complex',
943944
'c': 'complex',
944945
'string': 'string' if PY2 else 'bytes',
@@ -1305,6 +1306,9 @@ def infer_dtype(value: object, skipna: object=None) -> str:
13051306
elif is_decimal(val):
13061307
return 'decimal'
13071308

1309+
elif is_complex(val):
1310+
return 'complex'
1311+
13081312
elif util.is_float_object(val):
13091313
if is_float_array(values):
13101314
return 'floating'

pandas/tests/dtypes/test_inference.py

+31
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,37 @@ def test_decimals(self):
618618
result = lib.infer_dtype(arr, skipna=True)
619619
assert result == 'decimal'
620620

621+
# complex is compatible with nan, so skipna has no effect
622+
@pytest.mark.parametrize('skipna', [True, False])
623+
def test_complex(self, skipna):
624+
# gets cast to complex on array construction
625+
arr = np.array([1.0, 2.0, 1 + 1j])
626+
result = lib.infer_dtype(arr, skipna=skipna)
627+
assert result == 'complex'
628+
629+
arr = np.array([1.0, 2.0, 1 + 1j], dtype='O')
630+
result = lib.infer_dtype(arr, skipna=skipna)
631+
assert result == 'mixed'
632+
633+
# gets cast to complex on array construction
634+
arr = np.array([1, np.nan, 1 + 1j])
635+
result = lib.infer_dtype(arr, skipna=skipna)
636+
assert result == 'complex'
637+
638+
arr = np.array([1.0, np.nan, 1 + 1j], dtype='O')
639+
result = lib.infer_dtype(arr, skipna=skipna)
640+
assert result == 'mixed'
641+
642+
# complex with nans stays complex
643+
arr = np.array([1 + 1j, np.nan, 3 + 3j], dtype='O')
644+
result = lib.infer_dtype(arr, skipna=skipna)
645+
assert result == 'complex'
646+
647+
# test smaller complex dtype; will pass through _try_infer_map fastpath
648+
arr = np.array([1 + 1j, np.nan, 3 + 3j], dtype=np.complex64)
649+
result = lib.infer_dtype(arr, skipna=skipna)
650+
assert result == 'complex'
651+
621652
def test_string(self):
622653
pass
623654

0 commit comments

Comments
 (0)