From 134092ae59cea98244506be72f7822200874f70c Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 16 Oct 2020 14:09:39 -0700 Subject: [PATCH] CLN: ensure we pass correct type to DTI/TDI shallow_copy --- pandas/core/algorithms.py | 7 +++++-- pandas/core/indexes/base.py | 11 ++++++++--- pandas/core/indexes/datetimelike.py | 3 --- pandas/tests/indexes/test_common.py | 4 ++++ 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index d2005d46bbbf1..4e07a3e0c6df8 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -696,6 +696,8 @@ def factorize( # return original tenor if isinstance(original, ABCIndexClass): + if original.dtype.kind in ["m", "M"] and isinstance(uniques, np.ndarray): + uniques = type(original._data)._simple_new(uniques, dtype=original.dtype) uniques = original._shallow_copy(uniques, name=None) elif isinstance(original, ABCSeries): from pandas import Index @@ -1650,7 +1652,8 @@ def take_nd( """ mask_info = None - if is_extension_array_dtype(arr): + if isinstance(arr, ABCExtensionArray): + # Check for EA to catch DatetimeArray, TimedeltaArray return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill) arr = extract_array(arr) @@ -2043,7 +2046,7 @@ def safe_sort( "Only list-like objects are allowed to be passed to safe_sort as values" ) - if not isinstance(values, np.ndarray) and not is_extension_array_dtype(values): + if not isinstance(values, (np.ndarray, ABCExtensionArray)): # don't convert to string types dtype, _ = infer_dtype_from_array(values) values = np.asarray(values, dtype=dtype) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 87dd15d5b142b..2ebf2389823e9 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2664,6 +2664,11 @@ def _union(self, other, sort): return self._shallow_copy(result) def _wrap_setop_result(self, other, result): + if isinstance(self, (ABCDatetimeIndex, ABCTimedeltaIndex)) and isinstance( + result, np.ndarray + ): + result = type(self._data)._simple_new(result, dtype=self.dtype) + name = get_op_result_name(self, other) if isinstance(result, Index): if result.name != name: @@ -2740,10 +2745,10 @@ def intersection(self, other, sort=False): indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0]) indexer = indexer[indexer != -1] - result = other.take(indexer) + result = other.take(indexer)._values if sort is None: - result = algos.safe_sort(result.values) + result = algos.safe_sort(result) return self._wrap_setop_result(other, result) @@ -2800,7 +2805,7 @@ def difference(self, other, sort=None): indexer = indexer.take((indexer != -1).nonzero()[0]) label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True) - the_diff = this.values.take(label_diff) + the_diff = this._values.take(label_diff) if sort is None: try: the_diff = algos.safe_sort(the_diff) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 3845a601000a0..017dc6527944a 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -696,9 +696,6 @@ def _shallow_copy(self, values=None, name: Label = lib.no_default): name = self.name if name is lib.no_default else name if values is not None: - # TODO: We would rather not get here - if isinstance(values, np.ndarray): - values = type(self._data)(values, dtype=self.dtype) return self._simple_new(values, name=name) result = self._simple_new(self._data, name=name) diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index 94b10572fb5e1..6a681ede8ff42 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -320,6 +320,10 @@ def test_get_unique_index(self, index): vals[0] = np.nan vals_unique = vals[:2] + if index.dtype.kind in ["m", "M"]: + # i.e. needs_i8_conversion but not period_dtype, as above + vals = type(index._data)._simple_new(vals, dtype=index.dtype) + vals_unique = type(index._data)._simple_new(vals_unique, dtype=index.dtype) idx_nan = index._shallow_copy(vals) idx_unique_nan = index._shallow_copy(vals_unique) assert idx_unique_nan.is_unique is True