Skip to content

ENH: make closed part of IntervalDtype #37933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/arrays/_arrow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __hash__(self):
def to_pandas_dtype(self):
import pandas as pd

return pd.IntervalDtype(self.subtype.to_pandas_dtype())
return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.closed)

# register the type with a dummy instance
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
Expand Down
10 changes: 6 additions & 4 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,12 @@ def __new__(
def _simple_new(cls, data, closed="right"):
result = IntervalMixin.__new__(cls)

dtype = IntervalDtype(data.dtype, closed=closed)
result._dtype = dtype

result._combined = data
result._left = data[:, 0]
result._right = data[:, 1]
result._closed = closed
return result

@classmethod
Expand Down Expand Up @@ -480,7 +482,7 @@ def _validate(self):

@property
def dtype(self):
return IntervalDtype(self.left.dtype)
return self._dtype

@property
def nbytes(self) -> int:
Expand Down Expand Up @@ -1117,7 +1119,7 @@ def closed(self):
Whether the intervals are closed on the left-side, right-side, both or
neither.
"""
return self._closed
return self.dtype.closed

_interval_shared_docs["set_closed"] = textwrap.dedent(
"""
Expand Down Expand Up @@ -1212,7 +1214,7 @@ def __array__(self, dtype=None) -> np.ndarray:
left = self._left
right = self._right
mask = self.isna()
closed = self._closed
closed = self.closed

result = np.empty(len(left), dtype=object)
for i in range(len(left)):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> Tuple[DtypeObj,
dtype = PeriodDtype(freq=val.freq)
elif lib.is_interval(val):
subtype = infer_dtype_from_scalar(val.left, pandas_dtype=True)[0]
dtype = IntervalDtype(subtype=subtype)
dtype = IntervalDtype(subtype=subtype, closed=val.closed)

return dtype, val

Expand Down
28 changes: 22 additions & 6 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,27 +1008,33 @@ class IntervalDtype(PandasExtensionDtype):
base = np.dtype("O")
num = 103
_metadata = ("subtype",)
_match = re.compile(r"(I|i)nterval\[(?P<subtype>.+)\]")
_match = re.compile(
r"(I|i)nterval\[(?P<subtype>[^,]+)(, (?P<closed>(right|left|both|neither)))?\]"
)
_cache: Dict[str_type, PandasExtensionDtype] = {}

def __new__(cls, subtype=None):
def __new__(cls, subtype=None, closed=None):
from pandas.core.dtypes.common import is_string_dtype, pandas_dtype

if isinstance(subtype, IntervalDtype):
return subtype
# TODO: what if closed is also passed?
elif subtype is None:
# we are called as an empty constructor
# generally for pickle compat
u = object.__new__(cls)
u._subtype = None
u._closed = closed
return u
elif isinstance(subtype, str) and subtype.lower() == "interval":
subtype = None
else:
if isinstance(subtype, str):
m = cls._match.search(subtype)
if m is not None:
subtype = m.group("subtype")
gd = m.groupdict()
subtype = gd["subtype"]
closed = gd.get("closed", closed)

try:
subtype = pandas_dtype(subtype)
Expand All @@ -1043,14 +1049,20 @@ def __new__(cls, subtype=None):
)
raise TypeError(msg)

key = str(subtype) + str(closed)
try:
return cls._cache[str(subtype)]
return cls._cache[key]
except KeyError:
u = object.__new__(cls)
u._subtype = subtype
cls._cache[str(subtype)] = u
u._closed = closed
cls._cache[key] = u
return u

@property
def closed(self):
return self._closed

@property
def subtype(self):
"""
Expand Down Expand Up @@ -1100,7 +1112,7 @@ def type(self):
def __str__(self) -> str_type:
if self.subtype is None:
return "interval"
return f"interval[{self.subtype}]"
return f"interval[{self.subtype}, {self.closed}]"

def __hash__(self) -> int:
# make myself hashable
Expand All @@ -1114,6 +1126,8 @@ def __eq__(self, other: Any) -> bool:
elif self.subtype is None or other.subtype is None:
# None should match any subtype
return True
elif self.closed != other.closed:
return False
else:
from pandas.core.dtypes.common import is_dtype_equal

Expand All @@ -1124,6 +1138,8 @@ def __setstate__(self, state):
# PandasExtensionDtype superclass and uses the public properties to
# pickle -> need to set the settable private ones here (see GH26067)
self._subtype = state["subtype"]
# backward-compat older pickles won't have "closed" key
self._closed = state.pop("closed", None)

@classmethod
def is_dtype(cls, dtype: object) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arithmetic/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_compare_scalar_na(self, op, array, nulls_fixture, request):
result = op(array, nulls_fixture)
expected = self.elementwise_comparison(op, array, nulls_fixture)

if nulls_fixture is pd.NA and array.dtype != pd.IntervalDtype("int64"):
if nulls_fixture is pd.NA and array.dtype.subtype != "int64":
mark = pytest.mark.xfail(
reason="broken for non-integer IntervalArray; see GH 31882"
)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_repr():
expected = (
"<IntervalArray>\n"
"[(0, 1], (1, 2]]\n"
"Length: 2, closed: right, dtype: interval[int64]"
"Length: 2, closed: right, dtype: interval[int64, right]"
)
assert result == expected

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/dtypes/cast/test_infer_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_infer_from_interval(left, right, subtype, closed, pandas_dtype):
# GH 30337
interval = Interval(left, right, closed)
result_dtype, result_value = infer_dtype_from_scalar(interval, pandas_dtype)
expected_dtype = f"interval[{subtype}]" if pandas_dtype else np.object_
expected_dtype = f"interval[{subtype}, {closed}]" if pandas_dtype else np.object_
assert result_dtype == expected_dtype
assert result_value == interval

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def test_equality_generic(self, subtype):
def test_name_repr(self, subtype):
# GH 18980
dtype = IntervalDtype(subtype)
expected = f"interval[{subtype}]"
expected = f"interval[{subtype}, None]"
assert str(dtype) == expected
assert dtype.name == "interval"

Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/frame/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,8 @@ def test_constructor_period_dict(self):
[
(pd.Period("2012-01", freq="M"), "period[M]"),
(pd.Period("2012-02-01", freq="D"), "period[D]"),
(Interval(left=0, right=5), IntervalDtype("int64")),
(Interval(left=0.1, right=0.5), IntervalDtype("float64")),
(Interval(left=0, right=5), IntervalDtype("int64", "right")),
(Interval(left=0.1, right=0.5), IntervalDtype("float64", "right")),
],
)
def test_constructor_period_dict_scalar(self, data, dtype):
Expand All @@ -739,7 +739,7 @@ def test_constructor_period_dict_scalar(self, data, dtype):
"data,dtype",
[
(Period("2020-01"), PeriodDtype("M")),
(Interval(left=0, right=5), IntervalDtype("int64")),
(Interval(left=0, right=5), IntervalDtype("int64", "right")),
(
Timestamp("2011-01-01", tz="US/Eastern"),
DatetimeTZDtype(tz="US/Eastern"),
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/indexes/interval/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_constructor_dtype(self, constructor, breaks, subtype):
expected = constructor(**expected_kwargs)

result_kwargs = self.get_kwargs_from_breaks(breaks)
iv_dtype = IntervalDtype(subtype)
iv_dtype = IntervalDtype(subtype, "right")
for dtype in (iv_dtype, str(iv_dtype)):
result = constructor(dtype=dtype, **result_kwargs)
tm.assert_index_equal(result, expected)
Expand Down
8 changes: 4 additions & 4 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def test_comparison(self):
msg = "|".join(
[
"not supported between instances of 'int' and '.*.Interval'",
r"Invalid comparison between dtype=interval\[int64\] and ",
r"Invalid comparison between dtype=interval\[int64, right\] and ",
]
)
with pytest.raises(TypeError, match=msg):
Expand Down Expand Up @@ -691,13 +691,13 @@ def test_append(self, closed):
)
tm.assert_index_equal(result, expected)

msg = "Intervals must all be closed on the same side"
for other_closed in {"left", "right", "both", "neither"} - {closed}:
index_other_closed = IntervalIndex.from_arrays(
[0, 1], [1, 2], closed=other_closed
)
with pytest.raises(ValueError, match=msg):
index1.append(index_other_closed)
result = index1.append(index_other_closed)
expected = index1.astype(object).append(index_other_closed.astype(object))
tm.assert_index_equal(result, expected)

def test_is_non_overlapping_monotonic(self, closed):
# Should be True in all cases
Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/series/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ def test_construction_interval(self, interval_constructor):
# construction from interval & array of intervals
intervals = interval_constructor.from_breaks(np.arange(3), closed="right")
result = Series(intervals)
assert result.dtype == "interval[int64]"
assert result.dtype == "interval[int64, right]"
tm.assert_index_equal(Index(result.values), Index(intervals))

@pytest.mark.parametrize(
Expand All @@ -1008,7 +1008,7 @@ def test_constructor_infer_interval(self, data_constructor):
data = [Interval(0, 1), Interval(0, 2), None]
result = Series(data_constructor(data))
expected = Series(IntervalArray(data))
assert result.dtype == "interval[float64]"
assert result.dtype == "interval[float64, right]"
tm.assert_series_equal(result, expected)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -1111,7 +1111,7 @@ def test_constructor_dict_order(self):
"data,dtype",
[
(Period("2020-01"), PeriodDtype("M")),
(Interval(left=0, right=5), IntervalDtype("int64")),
(Interval(left=0, right=5), IntervalDtype("int64", "right")),
(
Timestamp("2011-01-01", tz="US/Eastern"),
DatetimeTZDtype(tz="US/Eastern"),
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/util/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_assert_frame_equal_interval_dtype_mismatch():
"Attributes of DataFrame\\.iloc\\[:, 0\\] "
'\\(column name="a"\\) are different\n\n'
'Attribute "dtype" are different\n'
"\\[left\\]: interval\\[int64\\]\n"
"\\[left\\]: interval\\[int64, right\\]\n"
"\\[right\\]: object"
)

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/util/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def test_assert_series_equal_interval_dtype_mismatch():
msg = """Attributes of Series are different

Attribute "dtype" are different
\\[left\\]: interval\\[int64\\]
\\[left\\]: interval\\[int64, right\\]
\\[right\\]: object"""

tm.assert_series_equal(left, right, check_dtype=False)
Expand Down