diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index c072bfeff4a72..3d1ab08336be8 100644 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -758,7 +758,7 @@ Interval ^^^^^^^^ - Bug in :meth:`IntervalIndex.get_indexer` where a :class:`Categorical` or :class:`CategoricalIndex` ``target`` would incorrectly raise a ``TypeError`` (:issue:`30063`) -- +- Bug in ``pandas.core.dtypes.cast.infer_dtype_from_scalar`` where passing ``pandas_dtype=True`` did not infer :class:`IntervalDtype` (:issue:`30337`) Indexing ^^^^^^^^ diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index b398a197a4bc0..1ab21f18f3bdc 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -41,7 +41,7 @@ is_unsigned_integer_dtype, pandas_dtype, ) -from .dtypes import DatetimeTZDtype, ExtensionDtype, PeriodDtype +from .dtypes import DatetimeTZDtype, ExtensionDtype, IntervalDtype, PeriodDtype from .generic import ( ABCDataFrame, ABCDatetimeArray, @@ -601,6 +601,9 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False): if lib.is_period(val): dtype = PeriodDtype(freq=val.freq) val = val.ordinal + elif lib.is_interval(val): + subtype = infer_dtype_from_scalar(val.left, pandas_dtype=True)[0] + dtype = IntervalDtype(subtype=subtype) return dtype, val diff --git a/pandas/tests/dtypes/cast/test_infer_dtype.py b/pandas/tests/dtypes/cast/test_infer_dtype.py index bf11b81af6f90..da2ef5260d070 100644 --- a/pandas/tests/dtypes/cast/test_infer_dtype.py +++ b/pandas/tests/dtypes/cast/test_infer_dtype.py @@ -10,7 +10,15 @@ ) from pandas.core.dtypes.common import is_dtype_equal -from pandas import Categorical, Period, Series, Timedelta, Timestamp, date_range +from pandas import ( + Categorical, + Interval, + Period, + Series, + Timedelta, + Timestamp, + date_range, +) import pandas.util.testing as tm @@ -107,6 +115,25 @@ def test_infer_from_scalar_tz(tz, pandas_dtype): assert val == exp_val +@pytest.mark.parametrize( + "left, right, subtype", + [ + (0, 1, "int64"), + (0.0, 1.0, "float64"), + (Timestamp(0), Timestamp(1), "datetime64[ns]"), + (Timestamp(0, tz="UTC"), Timestamp(1, tz="UTC"), "datetime64[ns, UTC]"), + (Timedelta(0), Timedelta(1), "timedelta64[ns]"), + ], +) +def test_infer_from_interval(left, right, subtype, closed, pandas_dtype): + # GH 30337 + interval = Interval(left, right, closed) + result_dtype, result_value = infer_dtype_from_scalar(interval, pandas_dtype) + expected_dtype = f"interval[{subtype}]" if pandas_dtype else np.object_ + assert result_dtype == expected_dtype + assert result_value == interval + + def test_infer_dtype_from_scalar_errors(): msg = "invalid ndarray passed to infer_dtype_from_scalar"