Skip to content

Commit fc14379

Browse files
jbrockmendelKevin D Smith
authored and
Kevin D Smith
committed
REF: MultiIndex._validate_insert_value, IntervaArray._validate_setitem_value (pandas-dev#36461)
1 parent 70f2846 commit fc14379

File tree

2 files changed

+46
-38
lines changed

2 files changed

+46
-38
lines changed

pandas/core/arrays/interval.py

+36-32
Original file line numberDiff line numberDiff line change
@@ -547,38 +547,7 @@ def __getitem__(self, value):
547547
return self._shallow_copy(left, right)
548548

549549
def __setitem__(self, key, value):
550-
# na value: need special casing to set directly on numpy arrays
551-
needs_float_conversion = False
552-
if is_scalar(value) and isna(value):
553-
if is_integer_dtype(self.dtype.subtype):
554-
# can't set NaN on a numpy integer array
555-
needs_float_conversion = True
556-
elif is_datetime64_any_dtype(self.dtype.subtype):
557-
# need proper NaT to set directly on the numpy array
558-
value = np.datetime64("NaT")
559-
elif is_timedelta64_dtype(self.dtype.subtype):
560-
# need proper NaT to set directly on the numpy array
561-
value = np.timedelta64("NaT")
562-
value_left, value_right = value, value
563-
564-
# scalar interval
565-
elif is_interval_dtype(value) or isinstance(value, Interval):
566-
self._check_closed_matches(value, name="value")
567-
value_left, value_right = value.left, value.right
568-
569-
else:
570-
# list-like of intervals
571-
try:
572-
array = IntervalArray(value)
573-
value_left, value_right = array.left, array.right
574-
except TypeError as err:
575-
# wrong type: not interval or NA
576-
msg = f"'value' should be an interval type, got {type(value)} instead."
577-
raise TypeError(msg) from err
578-
579-
if needs_float_conversion:
580-
raise ValueError("Cannot set float NaN to integer-backed IntervalArray")
581-
550+
value_left, value_right = self._validate_setitem_value(value)
582551
key = check_array_indexer(self, key)
583552

584553
# Need to ensure that left and right are updated atomically, so we're
@@ -898,6 +867,41 @@ def _validate_insert_value(self, value):
898867
)
899868
return left_insert, right_insert
900869

870+
def _validate_setitem_value(self, value):
871+
needs_float_conversion = False
872+
873+
if is_scalar(value) and isna(value):
874+
# na value: need special casing to set directly on numpy arrays
875+
if is_integer_dtype(self.dtype.subtype):
876+
# can't set NaN on a numpy integer array
877+
needs_float_conversion = True
878+
elif is_datetime64_any_dtype(self.dtype.subtype):
879+
# need proper NaT to set directly on the numpy array
880+
value = np.datetime64("NaT")
881+
elif is_timedelta64_dtype(self.dtype.subtype):
882+
# need proper NaT to set directly on the numpy array
883+
value = np.timedelta64("NaT")
884+
value_left, value_right = value, value
885+
886+
elif is_interval_dtype(value) or isinstance(value, Interval):
887+
# scalar interval
888+
self._check_closed_matches(value, name="value")
889+
value_left, value_right = value.left, value.right
890+
891+
else:
892+
try:
893+
# list-like of intervals
894+
array = IntervalArray(value)
895+
value_left, value_right = array.left, array.right
896+
except TypeError as err:
897+
# wrong type: not interval or NA
898+
msg = f"'value' should be an interval type, got {type(value)} instead."
899+
raise TypeError(msg) from err
900+
901+
if needs_float_conversion:
902+
raise ValueError("Cannot set float NaN to integer-backed IntervalArray")
903+
return value_left, value_right
904+
901905
def value_counts(self, dropna=True):
902906
"""
903907
Returns a Series containing counts of each interval.

pandas/core/indexes/multi.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -3596,6 +3596,15 @@ def astype(self, dtype, copy=True):
35963596
return self._shallow_copy()
35973597
return self
35983598

3599+
def _validate_insert_value(self, item):
3600+
if not isinstance(item, tuple):
3601+
# Pad the key with empty strings if lower levels of the key
3602+
# aren't specified:
3603+
item = (item,) + ("",) * (self.nlevels - 1)
3604+
elif len(item) != self.nlevels:
3605+
raise ValueError("Item must have length equal to number of levels.")
3606+
return item
3607+
35993608
def insert(self, loc: int, item):
36003609
"""
36013610
Make new MultiIndex inserting new item at location
@@ -3610,12 +3619,7 @@ def insert(self, loc: int, item):
36103619
-------
36113620
new_index : Index
36123621
"""
3613-
# Pad the key with empty strings if lower levels of the key
3614-
# aren't specified:
3615-
if not isinstance(item, tuple):
3616-
item = (item,) + ("",) * (self.nlevels - 1)
3617-
elif len(item) != self.nlevels:
3618-
raise ValueError("Item must have length equal to number of levels.")
3622+
item = self._validate_insert_value(item)
36193623

36203624
new_levels = []
36213625
new_codes = []

0 commit comments

Comments
 (0)