Skip to content

Commit 839c1bd

Browse files
authored
ENH: make "closed" part of IntervalDtype (#38394)
1 parent 2d08672 commit 839c1bd

22 files changed

+252
-102
lines changed

pandas/conftest.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -703,8 +703,8 @@ def float_frame():
703703
# ----------------------------------------------------------------
704704
@pytest.fixture(
705705
params=[
706-
(Interval(left=0, right=5), IntervalDtype("int64")),
707-
(Interval(left=0.1, right=0.5), IntervalDtype("float64")),
706+
(Interval(left=0, right=5), IntervalDtype("int64", "right")),
707+
(Interval(left=0.1, right=0.5), IntervalDtype("float64", "right")),
708708
(Period("2012-01", freq="M"), "period[M]"),
709709
(Period("2012-02-01", freq="D"), "period[D]"),
710710
(

pandas/core/arrays/_arrow_utils.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def __hash__(self):
127127
def to_pandas_dtype(self):
128128
import pandas as pd
129129

130-
return pd.IntervalDtype(self.subtype.to_pandas_dtype())
130+
return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.closed)
131131

132132
# register the type with a dummy instance
133133
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")

pandas/core/arrays/interval.py

+23 -12
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@
149149
>>> pd.arrays.IntervalArray([pd.Interval(0, 1), pd.Interval(1, 5)])
150150
<IntervalArray>
151151
[(0, 1], (1, 5]]
152-
Length: 2, closed: right, dtype: interval[int64]
152+
Length: 2, closed: right, dtype: interval[int64, right]
153153
154154
It may also be constructed using one of the constructor
155155
methods: :meth:`IntervalArray.from_arrays`,
@@ -222,6 +222,9 @@ def _simple_new(
222222
):
223223
result = IntervalMixin.__new__(cls)
224224

225+
if closed is None and isinstance(dtype, IntervalDtype):
226+
closed = dtype.closed
227+
225228
closed = closed or "right"
226229
left = ensure_index(left, copy=copy)
227230
right = ensure_index(right, copy=copy)
@@ -238,6 +241,12 @@ def _simple_new(
238241
msg = f"dtype must be an IntervalDtype, got {dtype}"
239242
raise TypeError(msg)
240243

244+
if dtype.closed is None:
245+
# possibly loading an old pickle
246+
dtype = IntervalDtype(dtype.subtype, closed)
247+
elif closed != dtype.closed:
248+
raise ValueError("closed keyword does not match dtype.closed")
249+
241250
# coerce dtypes to match if needed
242251
if is_float_dtype(left) and is_integer_dtype(right):
243252
right = right.astype(left.dtype)
@@ -279,9 +288,11 @@ def _simple_new(
279288
# If these share data, then setitem could corrupt our IA
280289
right = right.copy()
281290

291+
dtype = IntervalDtype(left.dtype, closed=closed)
292+
result._dtype = dtype
293+
282294
result._left = left
283295
result._right = right
284-
result._closed = closed
285296
if verify_integrity:
286297
result._validate()
287298
return result
@@ -343,7 +354,7 @@ def _from_factorized(cls, values, original):
343354
>>> pd.arrays.IntervalArray.from_breaks([0, 1, 2, 3])
344355
<IntervalArray>
345356
[(0, 1], (1, 2], (2, 3]]
346-
Length: 3, closed: right, dtype: interval[int64]
357+
Length: 3, closed: right, dtype: interval[int64, right]
347358
"""
348359
),
349360
}
@@ -414,7 +425,7 @@ def from_breaks(
414425
>>> pd.arrays.IntervalArray.from_arrays([0, 1, 2], [1, 2, 3])
415426
<IntervalArray>
416427
[(0, 1], (1, 2], (2, 3]]
417-
Length: 3, closed: right, dtype: interval[int64]
428+
Length: 3, closed: right, dtype: interval[int64, right]
418429
"""
419430
),
420431
}
@@ -473,7 +484,7 @@ def from_arrays(
473484
>>> pd.arrays.IntervalArray.from_tuples([(0, 1), (1, 2)])
474485
<IntervalArray>
475486
[(0, 1], (1, 2]]
476-
Length: 2, closed: right, dtype: interval[int64]
487+
Length: 2, closed: right, dtype: interval[int64, right]
477488
"""
478489
),
479490
}
@@ -553,7 +564,7 @@ def _shallow_copy(self, left, right):
553564

554565
@property
555566
def dtype(self):
556-
return IntervalDtype(self.left.dtype)
567+
return self._dtype
557568

558569
@property
559570
def nbytes(self) -> int:
@@ -1174,7 +1185,7 @@ def mid(self):
11741185
>>> intervals
11751186
<IntervalArray>
11761187
[(0, 1], (1, 3], (2, 4]]
1177-
Length: 3, closed: right, dtype: interval[int64]
1188+
Length: 3, closed: right, dtype: interval[int64, right]
11781189
"""
11791190
),
11801191
}
@@ -1203,7 +1214,7 @@ def closed(self):
12031214
Whether the intervals are closed on the left-side, right-side, both or
12041215
neither.
12051216
"""
1206-
return self._closed
1217+
return self.dtype.closed
12071218

12081219
_interval_shared_docs["set_closed"] = textwrap.dedent(
12091220
"""
@@ -1238,11 +1249,11 @@ def closed(self):
12381249
>>> index
12391250
<IntervalArray>
12401251
[(0, 1], (1, 2], (2, 3]]
1241-
Length: 3, closed: right, dtype: interval[int64]
1252+
Length: 3, closed: right, dtype: interval[int64, right]
12421253
>>> index.set_closed('both')
12431254
<IntervalArray>
12441255
[[0, 1], [1, 2], [2, 3]]
1245-
Length: 3, closed: both, dtype: interval[int64]
1256+
Length: 3, closed: both, dtype: interval[int64, both]
12461257
"""
12471258
),
12481259
}
@@ -1301,7 +1312,7 @@ def __array__(self, dtype: Optional[NpDtype] = None) -> np.ndarray:
13011312
left = self._left
13021313
right = self._right
13031314
mask = self.isna()
1304-
closed = self._closed
1315+
closed = self.closed
13051316

13061317
result = np.empty(len(left), dtype=object)
13071318
for i in range(len(left)):
@@ -1441,7 +1452,7 @@ def repeat(self, repeats, axis=None):
14411452
>>> intervals
14421453
<IntervalArray>
14431454
[(0, 1], (1, 3], (2, 4]]
1444-
Length: 3, closed: right, dtype: interval[int64]
1455+
Length: 3, closed: right, dtype: interval[int64, right]
14451456
"""
14461457
),
14471458
}

pandas/core/dtypes/cast.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -779,7 +779,7 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> Tuple[DtypeObj,
779779
dtype = PeriodDtype(freq=val.freq)
780780
elif lib.is_interval(val):
781781
subtype = infer_dtype_from_scalar(val.left, pandas_dtype=True)[0]
782-
dtype = IntervalDtype(subtype=subtype)
782+
dtype = IntervalDtype(subtype=subtype, closed=val.closed)
783783

784784
return dtype, val
785785

pandas/core/dtypes/dtypes.py

+51 -10
Original file line number | Diff line number | Diff line change
@@ -999,37 +999,60 @@ class IntervalDtype(PandasExtensionDtype):
999999
10001000
Examples
10011001
--------
1002-
>>> pd.IntervalDtype(subtype='int64')
1003-
interval[int64]
1002+
>>> pd.IntervalDtype(subtype='int64', closed='both')
1003+
interval[int64, both]
10041004
"""
10051005

10061006
name = "interval"
10071007
kind: str_type = "O"
10081008
str = "|O08"
10091009
base = np.dtype("O")
10101010
num = 103
1011-
_metadata = ("subtype",)
1012-
_match = re.compile(r"(I|i)nterval\[(?P<subtype>.+)\]")
1011+
_metadata = (
1012+
"subtype",
1013+
"closed",
1014+
)
1015+
_match = re.compile(
1016+
r"(I|i)nterval\[(?P<subtype>[^,]+)(, (?P<closed>(right|left|both|neither)))?\]"
1017+
)
10131018
_cache: Dict[str_type, PandasExtensionDtype] = {}
10141019

1015-
def __new__(cls, subtype=None):
1020+
def __new__(cls, subtype=None, closed: Optional[str_type] = None):
10161021
from pandas.core.dtypes.common import is_string_dtype, pandas_dtype
10171022

1023+
if closed is not None and closed not in {"right", "left", "both", "neither"}:
1024+
raise ValueError("closed must be one of 'right', 'left', 'both', 'neither'")
1025+
10181026
if isinstance(subtype, IntervalDtype):
1027+
if closed is not None and closed != subtype.closed:
1028+
raise ValueError(
1029+
"dtype.closed and 'closed' do not match. "
1030+
"Try IntervalDtype(dtype.subtype, closed) instead."
1031+
)
10191032
return subtype
10201033
elif subtype is None:
10211034
# we are called as an empty constructor
10221035
# generally for pickle compat
10231036
u = object.__new__(cls)
10241037
u._subtype = None
1038+
u._closed = closed
10251039
return u
10261040
elif isinstance(subtype, str) and subtype.lower() == "interval":
10271041
subtype = None
10281042
else:
10291043
if isinstance(subtype, str):
10301044
m = cls._match.search(subtype)
10311045
if m is not None:
1032-
subtype = m.group("subtype")
1046+
gd = m.groupdict()
1047+
subtype = gd["subtype"]
1048+
if gd.get("closed", None) is not None:
1049+
if closed is not None:
1050+
if closed != gd["closed"]:
1051+
raise ValueError(
1052+
"'closed' keyword does not match value "
1053+
"specified in dtype string"
1054+
)
1055+
closed = gd["closed"]
10331056

10341057
try:
10351058
subtype = pandas_dtype(subtype)
@@ -1044,14 +1067,20 @@ def __new__(cls, subtype=None):
10441067
)
10451068
raise TypeError(msg)
10461069

1070+
key = str(subtype) + str(closed)
10471071
try:
1048-
return cls._cache[str(subtype)]
1072+
return cls._cache[key]
10491073
except KeyError:
10501074
u = object.__new__(cls)
10511075
u._subtype = subtype
1052-
cls._cache[str(subtype)] = u
1076+
u._closed = closed
1077+
cls._cache[key] = u
10531078
return u
10541079

1080+
@property
1081+
def closed(self):
1082+
return self._closed
1083+
10551084
@property
10561085
def subtype(self):
10571086
"""
@@ -1101,7 +1130,10 @@ def type(self):
11011130
def __str__(self) -> str_type:
11021131
if self.subtype is None:
11031132
return "interval"
1104-
return f"interval[{self.subtype}]"
1133+
if self.closed is None:
1134+
# Only partially initialized GH#38394
1135+
return f"interval[{self.subtype}]"
1136+
return f"interval[{self.subtype}, {self.closed}]"
11051137

11061138
def __hash__(self) -> int:
11071139
# make myself hashable
@@ -1115,6 +1147,8 @@ def __eq__(self, other: Any) -> bool:
11151147
elif self.subtype is None or other.subtype is None:
11161148
# None should match any subtype
11171149
return True
1150+
elif self.closed != other.closed:
1151+
return False
11181152
else:
11191153
from pandas.core.dtypes.common import is_dtype_equal
11201154

@@ -1126,6 +1160,9 @@ def __setstate__(self, state):
11261160
# pickle -> need to set the settable private ones here (see GH26067)
11271161
self._subtype = state["subtype"]
11281162

1163+
# backward-compat older pickles won't have "closed" key
1164+
self._closed = state.pop("closed", None)
1165+
11291166
@classmethod
11301167
def is_dtype(cls, dtype: object) -> bool:
11311168
"""
@@ -1174,9 +1211,13 @@ def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
11741211
if not all(isinstance(x, IntervalDtype) for x in dtypes):
11751212
return None
11761213

1214+
closed = cast("IntervalDtype", dtypes[0]).closed
1215+
if not all(cast("IntervalDtype", x).closed == closed for x in dtypes):
1216+
return np.dtype(object)
1217+
11771218
from pandas.core.dtypes.cast import find_common_type
11781219

11791220
common = find_common_type([cast("IntervalDtype", x).subtype for x in dtypes])
11801221
if common == object:
11811222
return np.dtype(object)
1182-
return IntervalDtype(common)
1223+
return IntervalDtype(common, closed=closed)

0 commit comments

Comments
 (0)