From 12739ebdb1dc44fd2bb452538145c3f6769a7b3d Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 18 Sep 2020 16:23:47 -0700 Subject: [PATCH 1/2] REF: MultiIndex._validate_insert_value --- pandas/core/indexes/multi.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index a21a54e4a9be3..cd3e384837280 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3596,6 +3596,15 @@ def astype(self, dtype, copy=True): return self._shallow_copy() return self + def _validate_insert_value(self, item): + if not isinstance(item, tuple): + # Pad the key with empty strings if lower levels of the key + # aren't specified: + item = (item,) + ("",) * (self.nlevels - 1) + elif len(item) != self.nlevels: + raise ValueError("Item must have length equal to number of levels.") + return item + def insert(self, loc: int, item): """ Make new MultiIndex inserting new item at location @@ -3610,12 +3619,7 @@ def insert(self, loc: int, item): ------- new_index : Index """ - # Pad the key with empty strings if lower levels of the key - # aren't specified: - if not isinstance(item, tuple): - item = (item,) + ("",) * (self.nlevels - 1) - elif len(item) != self.nlevels: - raise ValueError("Item must have length equal to number of levels.") + item = self._validate_insert_value(item) new_levels = [] new_codes = [] From 6902508db7711409a461d9f0bc42e8dce9299e72 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 18 Sep 2020 16:26:32 -0700 Subject: [PATCH 2/2] REF: IntervalArray._validate_setitem_value --- pandas/core/arrays/interval.py | 68 ++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index ff9dd3f2a85bc..f9f68004bcc23 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -547,38 +547,7 @@ def __getitem__(self, value): return self._shallow_copy(left, right) def __setitem__(self, key, value): - # na value: need special casing to set directly on numpy arrays - needs_float_conversion = False - if is_scalar(value) and isna(value): - if is_integer_dtype(self.dtype.subtype): - # can't set NaN on a numpy integer array - needs_float_conversion = True - elif is_datetime64_any_dtype(self.dtype.subtype): - # need proper NaT to set directly on the numpy array - value = np.datetime64("NaT") - elif is_timedelta64_dtype(self.dtype.subtype): - # need proper NaT to set directly on the numpy array - value = np.timedelta64("NaT") - value_left, value_right = value, value - - # scalar interval - elif is_interval_dtype(value) or isinstance(value, Interval): - self._check_closed_matches(value, name="value") - value_left, value_right = value.left, value.right - - else: - # list-like of intervals - try: - array = IntervalArray(value) - value_left, value_right = array.left, array.right - except TypeError as err: - # wrong type: not interval or NA - msg = f"'value' should be an interval type, got {type(value)} instead." - raise TypeError(msg) from err - - if needs_float_conversion: - raise ValueError("Cannot set float NaN to integer-backed IntervalArray") - + value_left, value_right = self._validate_setitem_value(value) key = check_array_indexer(self, key) # Need to ensure that left and right are updated atomically, so we're @@ -898,6 +867,41 @@ def _validate_insert_value(self, value): ) return left_insert, right_insert + def _validate_setitem_value(self, value): + needs_float_conversion = False + + if is_scalar(value) and isna(value): + # na value: need special casing to set directly on numpy arrays + if is_integer_dtype(self.dtype.subtype): + # can't set NaN on a numpy integer array + needs_float_conversion = True + elif is_datetime64_any_dtype(self.dtype.subtype): + # need proper NaT to set directly on the numpy array + value = np.datetime64("NaT") + elif is_timedelta64_dtype(self.dtype.subtype): + # need proper NaT to set directly on the numpy array + value = np.timedelta64("NaT") + value_left, value_right = value, value + + elif is_interval_dtype(value) or isinstance(value, Interval): + # scalar interval + self._check_closed_matches(value, name="value") + value_left, value_right = value.left, value.right + + else: + try: + # list-like of intervals + array = IntervalArray(value) + value_left, value_right = array.left, array.right + except TypeError as err: + # wrong type: not interval or NA + msg = f"'value' should be an interval type, got {type(value)} instead." + raise TypeError(msg) from err + + if needs_float_conversion: + raise ValueError("Cannot set float NaN to integer-backed IntervalArray") + return value_left, value_right + def value_counts(self, dropna=True): """ Returns a Series containing counts of each interval.