Skip to content

Commit 53a0dfd

Browse files
jschendelTomAugspurger
authored andcommitted
BUG: Fix infer_dtype_from_scalar to infer IntervalDtype (#30339)
1 parent 20e4c18 commit 53a0dfd

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

doc/source/whatsnew/v1.0.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,7 @@ Interval
758758
^^^^^^^^
759759

760760
- Bug in :meth:`IntervalIndex.get_indexer` where a :class:`Categorical` or :class:`CategoricalIndex` ``target`` would incorrectly raise a ``TypeError`` (:issue:`30063`)
761-
-
761+
- Bug in ``pandas.core.dtypes.cast.infer_dtype_from_scalar`` where passing ``pandas_dtype=True`` did not infer :class:`IntervalDtype` (:issue:`30337`)
762762

763763
Indexing
764764
^^^^^^^^

pandas/core/dtypes/cast.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
is_unsigned_integer_dtype,
4242
pandas_dtype,
4343
)
44-
from .dtypes import DatetimeTZDtype, ExtensionDtype, PeriodDtype
44+
from .dtypes import DatetimeTZDtype, ExtensionDtype, IntervalDtype, PeriodDtype
4545
from .generic import (
4646
ABCDataFrame,
4747
ABCDatetimeArray,
@@ -601,6 +601,9 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False):
601601
if lib.is_period(val):
602602
dtype = PeriodDtype(freq=val.freq)
603603
val = val.ordinal
604+
elif lib.is_interval(val):
605+
subtype = infer_dtype_from_scalar(val.left, pandas_dtype=True)[0]
606+
dtype = IntervalDtype(subtype=subtype)
604607

605608
return dtype, val
606609

pandas/tests/dtypes/cast/test_infer_dtype.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
)
1111
from pandas.core.dtypes.common import is_dtype_equal
1212

13-
from pandas import Categorical, Period, Series, Timedelta, Timestamp, date_range
13+
from pandas import (
14+
Categorical,
15+
Interval,
16+
Period,
17+
Series,
18+
Timedelta,
19+
Timestamp,
20+
date_range,
21+
)
1422
import pandas.util.testing as tm
1523

1624

@@ -107,6 +115,25 @@ def test_infer_from_scalar_tz(tz, pandas_dtype):
107115
assert val == exp_val
108116

109117

118+
@pytest.mark.parametrize(
119+
"left, right, subtype",
120+
[
121+
(0, 1, "int64"),
122+
(0.0, 1.0, "float64"),
123+
(Timestamp(0), Timestamp(1), "datetime64[ns]"),
124+
(Timestamp(0, tz="UTC"), Timestamp(1, tz="UTC"), "datetime64[ns, UTC]"),
125+
(Timedelta(0), Timedelta(1), "timedelta64[ns]"),
126+
],
127+
)
128+
def test_infer_from_interval(left, right, subtype, closed, pandas_dtype):
129+
# GH 30337
130+
interval = Interval(left, right, closed)
131+
result_dtype, result_value = infer_dtype_from_scalar(interval, pandas_dtype)
132+
expected_dtype = f"interval[{subtype}]" if pandas_dtype else np.object_
133+
assert result_dtype == expected_dtype
134+
assert result_value == interval
135+
136+
110137
def test_infer_dtype_from_scalar_errors():
111138
msg = "invalid ndarray passed to infer_dtype_from_scalar"
112139

0 commit comments

Comments
 (0)