Skip to content

Commit f9b9bd7

Browse files
jbrockmendelJulianWgs
authored andcommitted
BUG: lib.infer_dtype with incompatible intervals (pandas-dev#41749)
1 parent 136b20c commit f9b9bd7

File tree

4 files changed

+93
-22
lines changed

4 files changed

+93
-22
lines changed

pandas/_libs/lib.pyx

+51-8
Original file line numberDiff line numberDiff line change
@@ -2031,16 +2031,59 @@ cdef bint is_period_array(ndarray[object] values):
20312031
return True
20322032

20332033

2034-
cdef class IntervalValidator(Validator):
2035-
cdef inline bint is_value_typed(self, object value) except -1:
2036-
return is_interval(value)
2037-
2038-
20392034
cpdef bint is_interval_array(ndarray values):
2035+
"""
2036+
Is this an ndarray of Interval (or np.nan) with a single dtype?
2037+
"""
2038+
20402039
cdef:
2041-
IntervalValidator validator = IntervalValidator(len(values),
2042-
skipna=True)
2043-
return validator.validate(values)
2040+
Py_ssize_t i, n = len(values)
2041+
str closed = None
2042+
bint numeric = False
2043+
bint dt64 = False
2044+
bint td64 = False
2045+
object val
2046+
2047+
if len(values) == 0:
2048+
return False
2049+
2050+
for val in values:
2051+
if is_interval(val):
2052+
if closed is None:
2053+
closed = val.closed
2054+
numeric = (
2055+
util.is_float_object(val.left)
2056+
or util.is_integer_object(val.left)
2057+
)
2058+
td64 = is_timedelta(val.left)
2059+
dt64 = PyDateTime_Check(val.left)
2060+
elif val.closed != closed:
2061+
# mismatched closedness
2062+
return False
2063+
elif numeric:
2064+
if not (
2065+
util.is_float_object(val.left)
2066+
or util.is_integer_object(val.left)
2067+
):
2068+
# i.e. datetime64 or timedelta64
2069+
return False
2070+
elif td64:
2071+
if not is_timedelta(val.left):
2072+
return False
2073+
elif dt64:
2074+
if not PyDateTime_Check(val.left):
2075+
return False
2076+
else:
2077+
raise ValueError(val)
2078+
elif util.is_nan(val) or val is None:
2079+
pass
2080+
else:
2081+
return False
2082+
2083+
if closed is None:
2084+
# we saw all-NAs, no actual Intervals
2085+
return False
2086+
return True
20442087

20452088

20462089
@cython.boundscheck(False)

pandas/core/construction.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -318,12 +318,7 @@ def array(
318318
return PeriodArray._from_sequence(data, copy=copy)
319319

320320
elif inferred_dtype == "interval":
321-
try:
322-
return IntervalArray(data, copy=copy)
323-
except ValueError:
324-
# We may have a mixture of `closed` here.
325-
# We choose to return an ndarray, rather than raising.
326-
pass
321+
return IntervalArray(data, copy=copy)
327322

328323
elif inferred_dtype.startswith("datetime"):
329324
# datetime, datetime64

pandas/core/indexes/base.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -6443,12 +6443,8 @@ def _maybe_cast_data_without_dtype(subarr: np.ndarray) -> ArrayLike:
64436443
return data
64446444

64456445
elif inferred == "interval":
6446-
try:
6447-
ia_data = IntervalArray._from_sequence(subarr, copy=False)
6448-
return ia_data
6449-
except (ValueError, TypeError):
6450-
# GH27172: mixed closed Intervals --> object dtype
6451-
pass
6446+
ia_data = IntervalArray._from_sequence(subarr, copy=False)
6447+
return ia_data
64526448
elif inferred == "boolean":
64536449
# don't support boolean explicitly ATM
64546450
pass

pandas/tests/dtypes/test_inference.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -1458,17 +1458,54 @@ def test_categorical(self):
14581458
result = lib.infer_dtype(Series(arr), skipna=True)
14591459
assert result == "categorical"
14601460

1461-
def test_interval(self):
1461+
@pytest.mark.parametrize("asobject", [True, False])
1462+
def test_interval(self, asobject):
14621463
idx = pd.IntervalIndex.from_breaks(range(5), closed="both")
1464+
if asobject:
1465+
idx = idx.astype(object)
1466+
14631467
inferred = lib.infer_dtype(idx, skipna=False)
14641468
assert inferred == "interval"
14651469

14661470
inferred = lib.infer_dtype(idx._data, skipna=False)
14671471
assert inferred == "interval"
14681472

1469-
inferred = lib.infer_dtype(Series(idx), skipna=False)
1473+
inferred = lib.infer_dtype(Series(idx, dtype=idx.dtype), skipna=False)
14701474
assert inferred == "interval"
14711475

1476+
@pytest.mark.parametrize("value", [Timestamp(0), Timedelta(0), 0, 0.0])
1477+
def test_interval_mismatched_closed(self, value):
1478+
1479+
first = Interval(value, value, closed="left")
1480+
second = Interval(value, value, closed="right")
1481+
1482+
# if closed match, we should infer "interval"
1483+
arr = np.array([first, first], dtype=object)
1484+
assert lib.infer_dtype(arr, skipna=False) == "interval"
1485+
1486+
# if closed dont match, we should _not_ get "interval"
1487+
arr2 = np.array([first, second], dtype=object)
1488+
assert lib.infer_dtype(arr2, skipna=False) == "mixed"
1489+
1490+
def test_interval_mismatched_subtype(self):
1491+
first = Interval(0, 1, closed="left")
1492+
second = Interval(Timestamp(0), Timestamp(1), closed="left")
1493+
third = Interval(Timedelta(0), Timedelta(1), closed="left")
1494+
1495+
arr = np.array([first, second])
1496+
assert lib.infer_dtype(arr, skipna=False) == "mixed"
1497+
1498+
arr = np.array([second, third])
1499+
assert lib.infer_dtype(arr, skipna=False) == "mixed"
1500+
1501+
arr = np.array([first, third])
1502+
assert lib.infer_dtype(arr, skipna=False) == "mixed"
1503+
1504+
# float vs int subdtype are compatible
1505+
flt_interval = Interval(1.5, 2.5, closed="left")
1506+
arr = np.array([first, flt_interval], dtype=object)
1507+
assert lib.infer_dtype(arr, skipna=False) == "interval"
1508+
14721509
@pytest.mark.parametrize("klass", [pd.array, Series])
14731510
@pytest.mark.parametrize("skipna", [True, False])
14741511
@pytest.mark.parametrize("data", [["a", "b", "c"], ["a", "b", pd.NA]])

0 commit comments

Comments
 (0)