Skip to content

Commit 61693b5

Browse files
authored
BUG: infer_dtype with decimal/complex (#37176)
1 parent 3a2a190 commit 61693b5

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

pandas/_libs/lib.pyx

+32-4
Original file line numberDiff line numberDiff line change
@@ -1414,10 +1414,12 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
14141414
return "time"
14151415

14161416
elif is_decimal(val):
1417-
return "decimal"
1417+
if is_decimal_array(values):
1418+
return "decimal"
14181419

14191420
elif is_complex(val):
1420-
return "complex"
1421+
if is_complex_array(values):
1422+
return "complex"
14211423

14221424
elif util.is_float_object(val):
14231425
if is_float_array(values):
@@ -1702,6 +1704,34 @@ cpdef bint is_float_array(ndarray values):
17021704
return validator.validate(values)
17031705

17041706

1707+
cdef class ComplexValidator(Validator):
1708+
cdef inline bint is_value_typed(self, object value) except -1:
1709+
return (
1710+
util.is_complex_object(value)
1711+
or (util.is_float_object(value) and is_nan(value))
1712+
)
1713+
1714+
cdef inline bint is_array_typed(self) except -1:
1715+
return issubclass(self.dtype.type, np.complexfloating)
1716+
1717+
1718+
cdef bint is_complex_array(ndarray values):
1719+
cdef:
1720+
ComplexValidator validator = ComplexValidator(len(values), values.dtype)
1721+
return validator.validate(values)
1722+
1723+
1724+
cdef class DecimalValidator(Validator):
1725+
cdef inline bint is_value_typed(self, object value) except -1:
1726+
return is_decimal(value)
1727+
1728+
1729+
cdef bint is_decimal_array(ndarray values):
1730+
cdef:
1731+
DecimalValidator validator = DecimalValidator(len(values), values.dtype)
1732+
return validator.validate(values)
1733+
1734+
17051735
cdef class StringValidator(Validator):
17061736
cdef inline bint is_value_typed(self, object value) except -1:
17071737
return isinstance(value, str)
@@ -2546,8 +2576,6 @@ def fast_multiget(dict mapping, ndarray keys, default=np.nan):
25462576
# kludge, for Series
25472577
return np.empty(0, dtype='f8')
25482578

2549-
keys = getattr(keys, 'values', keys)
2550-
25512579
for i in range(n):
25522580
val = keys[i]
25532581
if val in mapping:

pandas/tests/dtypes/test_inference.py

+6
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,9 @@ def test_decimals(self):
709709
result = lib.infer_dtype(arr, skipna=True)
710710
assert result == "mixed"
711711

712+
result = lib.infer_dtype(arr[::-1], skipna=True)
713+
assert result == "mixed"
714+
712715
arr = np.array([Decimal(1), Decimal("NaN"), Decimal(3)])
713716
result = lib.infer_dtype(arr, skipna=True)
714717
assert result == "decimal"
@@ -729,6 +732,9 @@ def test_complex(self, skipna):
729732
result = lib.infer_dtype(arr, skipna=skipna)
730733
assert result == "mixed"
731734

735+
result = lib.infer_dtype(arr[::-1], skipna=skipna)
736+
assert result == "mixed"
737+
732738
# gets cast to complex on array construction
733739
arr = np.array([1, np.nan, 1 + 1j])
734740
result = lib.infer_dtype(arr, skipna=skipna)

0 commit comments

Comments
 (0)