From 848b155eb18cc75e59266cd6c2541b7b64b0101e Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 16 Oct 2020 14:56:51 -0700 Subject: [PATCH 1/2] BUG: infer_dtype with decimal/complex --- pandas/_libs/lib.pyx | 25 +++++++++++++++++++++---- pandas/tests/dtypes/test_inference.py | 6 ++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 922dcd7e74aa0..531e9b6920e0c 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -1414,10 +1414,12 @@ def infer_dtype(value: object, skipna: bool = True) -> str: return "time" elif is_decimal(val): - return "decimal" + if all(is_decimal(x) for x in values): + return "decimal" elif is_complex(val): - return "complex" + if is_complex_array(values): + return "complex" elif util.is_float_object(val): if is_float_array(values): @@ -1702,6 +1704,23 @@ cpdef bint is_float_array(ndarray values): return validator.validate(values) +cdef class ComplexValidator(Validator): + cdef inline bint is_value_typed(self, object value) except -1: + return ( + util.is_complex_object(value) + or (util.is_float_object(value) and is_nan(value)) + ) + + cdef inline bint is_array_typed(self) except -1: + return issubclass(self.dtype.type, np.complexfloating) + + +cdef bint is_complex_array(ndarray values): + cdef: + ComplexValidator validator = ComplexValidator(len(values), values.dtype) + return validator.validate(values) + + cdef class StringValidator(Validator): cdef inline bint is_value_typed(self, object value) except -1: return isinstance(value, str) @@ -2546,8 +2565,6 @@ def fast_multiget(dict mapping, ndarray keys, default=np.nan): # kludge, for Series return np.empty(0, dtype='f8') - keys = getattr(keys, 'values', keys) - for i in range(n): val = keys[i] if val in mapping: diff --git a/pandas/tests/dtypes/test_inference.py b/pandas/tests/dtypes/test_inference.py index c6c54ccb357d5..7fa83eeac8400 100644 --- a/pandas/tests/dtypes/test_inference.py +++ b/pandas/tests/dtypes/test_inference.py @@ -709,6 +709,9 @@ def test_decimals(self): result = lib.infer_dtype(arr, skipna=True) assert result == "mixed" + result = lib.infer_dtype(arr[::-1], skipna=True) + assert result == "mixed" + arr = np.array([Decimal(1), Decimal("NaN"), Decimal(3)]) result = lib.infer_dtype(arr, skipna=True) assert result == "decimal" @@ -729,6 +732,9 @@ def test_complex(self, skipna): result = lib.infer_dtype(arr, skipna=skipna) assert result == "mixed" + result = lib.infer_dtype(arr[::-1], skipna=skipna) + assert result == "mixed" + # gets cast to complex on array construction arr = np.array([1, np.nan, 1 + 1j]) result = lib.infer_dtype(arr, skipna=skipna) From 4597993da76d5d4d386fc0385bc914ca0c64b0ca Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 16 Oct 2020 18:30:46 -0700 Subject: [PATCH 2/2] implement is_decimal_array --- pandas/_libs/lib.pyx | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 531e9b6920e0c..f4caafb3a9fe7 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -1414,7 +1414,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str: return "time" elif is_decimal(val): - if all(is_decimal(x) for x in values): + if is_decimal_array(values): return "decimal" elif is_complex(val): @@ -1721,6 +1721,17 @@ cdef bint is_complex_array(ndarray values): return validator.validate(values) +cdef class DecimalValidator(Validator): + cdef inline bint is_value_typed(self, object value) except -1: + return is_decimal(value) + + +cdef bint is_decimal_array(ndarray values): + cdef: + DecimalValidator validator = DecimalValidator(len(values), values.dtype) + return validator.validate(values) + + cdef class StringValidator(Validator): cdef inline bint is_value_typed(self, object value) except -1: return isinstance(value, str)