Skip to content

CLN: ensure we pass correct type to DTI/TDI shallow_copy #37171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,8 @@ def factorize(

# return original tenor
if isinstance(original, ABCIndexClass):
if original.dtype.kind in ["m", "M"] and isinstance(uniques, np.ndarray):
uniques = type(original._data)._simple_new(uniques, dtype=original.dtype)
uniques = original._shallow_copy(uniques, name=None)
elif isinstance(original, ABCSeries):
from pandas import Index
Expand Down Expand Up @@ -1650,7 +1652,8 @@ def take_nd(
"""
mask_info = None

if is_extension_array_dtype(arr):
if isinstance(arr, ABCExtensionArray):
# Check for EA to catch DatetimeArray, TimedeltaArray
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)

arr = extract_array(arr)
Expand Down Expand Up @@ -2043,7 +2046,7 @@ def safe_sort(
"Only list-like objects are allowed to be passed to safe_sort as values"
)

if not isinstance(values, np.ndarray) and not is_extension_array_dtype(values):
if not isinstance(values, (np.ndarray, ABCExtensionArray)):
# don't convert to string types
dtype, _ = infer_dtype_from_array(values)
values = np.asarray(values, dtype=dtype)
Expand Down
11 changes: 8 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2664,6 +2664,11 @@ def _union(self, other, sort):
return self._shallow_copy(result)

def _wrap_setop_result(self, other, result):
if isinstance(self, (ABCDatetimeIndex, ABCTimedeltaIndex)) and isinstance(
result, np.ndarray
):
result = type(self._data)._simple_new(result, dtype=self.dtype)

name = get_op_result_name(self, other)
if isinstance(result, Index):
if result.name != name:
Expand Down Expand Up @@ -2740,10 +2745,10 @@ def intersection(self, other, sort=False):
indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0])
indexer = indexer[indexer != -1]

result = other.take(indexer)
result = other.take(indexer)._values

if sort is None:
result = algos.safe_sort(result.values)
result = algos.safe_sort(result)

return self._wrap_setop_result(other, result)

Expand Down Expand Up @@ -2800,7 +2805,7 @@ def difference(self, other, sort=None):
indexer = indexer.take((indexer != -1).nonzero()[0])

label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True)
the_diff = this.values.take(label_diff)
the_diff = this._values.take(label_diff)
if sort is None:
try:
the_diff = algos.safe_sort(the_diff)
Expand Down
3 changes: 0 additions & 3 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,6 @@ def _shallow_copy(self, values=None, name: Label = lib.no_default):
name = self.name if name is lib.no_default else name

if values is not None:
# TODO: We would rather not get here
if isinstance(values, np.ndarray):
values = type(self._data)(values, dtype=self.dtype)
return self._simple_new(values, name=name)

result = self._simple_new(self._data, name=name)
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/indexes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ def test_get_unique_index(self, index):
vals[0] = np.nan

vals_unique = vals[:2]
if index.dtype.kind in ["m", "M"]:
# i.e. needs_i8_conversion but not period_dtype, as above
vals = type(index._data)._simple_new(vals, dtype=index.dtype)
vals_unique = type(index._data)._simple_new(vals_unique, dtype=index.dtype)
idx_nan = index._shallow_copy(vals)
idx_unique_nan = index._shallow_copy(vals_unique)
assert idx_unique_nan.is_unique is True
Expand Down