Skip to content

Commit 3cc2fae

Browse files
committed
API: fix corner case of lib.infer_dtype (#23422)
1 parent 37feec1 commit 3cc2fae

File tree

4 files changed

+27
-2
lines changed

4 files changed

+27
-2
lines changed

pandas/_libs/lib.pyx

+4-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ from tslibs.conversion cimport convert_to_tsobject
5757
from tslibs.timedeltas cimport convert_to_timedelta64
5858
from tslibs.timezones cimport get_timezone, tz_compare
5959

60-
from missing cimport (checknull,
60+
from missing cimport (checknull, isnaobj,
6161
is_null_datetime64, is_null_timedelta64, is_null_period)
6262

6363

@@ -1177,6 +1177,9 @@ def infer_dtype(object value, bint skipna=False):
11771177
values = construct_1d_object_array_from_listlike(value)
11781178

11791179
values = getattr(values, 'values', values)
1180+
if skipna:
1181+
values = values[~isnaobj(values)]
1182+
11801183
val = _try_infer_map(values)
11811184
if val is not None:
11821185
return val

pandas/_libs/missing.pxd

+6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
# -*- coding: utf-8 -*-
22

3+
from numpy cimport ndarray, uint8_t
4+
5+
from tslibs.nattype cimport is_null_datetimelike
6+
37
cpdef bint checknull(object val)
48
cpdef bint checknull_old(object val)
59

10+
cpdef ndarray[uint8_t] isnaobj(ndarray arr)
11+
612
cdef bint is_null_datetime64(v)
713
cdef bint is_null_timedelta64(v)
814
cdef bint is_null_period(v)

pandas/_libs/missing.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ cdef inline bint _check_none_nan_inf_neginf(object val):
124124

125125
@cython.wraparound(False)
126126
@cython.boundscheck(False)
127-
def isnaobj(ndarray arr):
127+
cpdef ndarray[uint8_t] isnaobj(ndarray arr):
128128
"""
129129
Return boolean mask denoting which elements of a 1-D array are na-like,
130130
according to the criteria defined in `_check_all_nulls`:

pandas/tests/dtypes/test_inference.py

+16
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,22 @@ def test_unicode(self):
591591
expected = 'unicode' if PY2 else 'string'
592592
assert result == expected
593593

594+
@pytest.mark.parametrize('dtype, missing, skipna, expected', [
595+
(float, np.nan, False, 'floating'),
596+
(float, np.nan, True, 'floating'),
597+
(object, np.nan, False, 'floating'),
598+
(object, np.nan, True, 'empty'),
599+
(object, None, False, 'mixed'),
600+
(object, None, True, 'empty')
601+
])
602+
@pytest.mark.parametrize('box', [pd.Series, np.array])
603+
def test_object_empty(self, box, missing, dtype, skipna, expected):
604+
# GH 23421
605+
arr = box([missing, missing], dtype=dtype)
606+
607+
result = lib.infer_dtype(arr, skipna=skipna)
608+
assert result == expected
609+
594610
def test_datetime(self):
595611

596612
dates = [datetime(2012, 1, x) for x in range(1, 20)]

0 commit comments

Comments
 (0)