Skip to content

Commit 033ac9c

Browse files
committed
Setitem-based where
1 parent e9665b8 commit 033ac9c

File tree

10 files changed

+174
-156
lines changed

10 files changed

+174
-156
lines changed

doc/source/whatsnew/v0.24.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,7 @@ Datetimelike
13101310
- Bug in :class:`DatetimeIndex` where calling ``np.array(dtindex, dtype=object)`` would incorrectly return an array of ``long`` objects (:issue:`23524`)
13111311
- Bug in :class:`Index` where passing a timezone-aware :class:`DatetimeIndex` and `dtype=object` would incorrectly raise a ``ValueError`` (:issue:`23524`)
13121312
- Bug in :class:`Index` where calling ``np.array(dtindex, dtype=object)`` on a timezone-naive :class:`DatetimeIndex` would return an array of ``datetime`` objects instead of :class:`Timestamp` objects, potentially losing nanosecond portions of the timestamps (:issue:`23524`)
1313+
- Bug in :class:`Categorical.__setitem__` not allowing setting with another ``Categorical`` when both are undordered and have the same categories, but in a different order (:issue:`24142`)
13131314

13141315
Timedelta
13151316
^^^^^^^^^

pandas/core/arrays/base.py

-36
Original file line numberDiff line numberDiff line change
@@ -662,42 +662,6 @@ def take(self, indices, allow_fill=False, fill_value=None):
662662
# pandas.api.extensions.take
663663
raise AbstractMethodError(self)
664664

665-
def where(self, cond, other):
666-
"""
667-
Replace values where the condition is False.
668-
669-
.. versionadded :: 0.24.0
670-
671-
Parameters
672-
----------
673-
cond : ndarray or ExtensionArray
674-
The mask indicating which values should be kept (True)
675-
or replaced from `other` (False).
676-
677-
other : ndarray, ExtensionArray, or scalar
678-
Entries where `cond` is False are replaced with
679-
corresponding value from `other`.
680-
681-
Notes
682-
-----
683-
Note that `cond` and `other` *cannot* be a Series, Index, or callable.
684-
When used from, e.g., :meth:`Series.where`, pandas will unbox
685-
Series and Indexes, and will apply callables before they arrive here.
686-
687-
Returns
688-
-------
689-
ExtensionArray
690-
Same dtype as the original.
691-
692-
See Also
693-
--------
694-
Series.where : Similar method for Series.
695-
DataFrame.where : Similar method for DataFrame.
696-
"""
697-
return type(self)._from_sequence(np.where(cond, self, other),
698-
dtype=self.dtype,
699-
copy=False)
700-
701665
def copy(self, deep=False):
702666
# type: (bool) -> ExtensionArray
703667
"""

pandas/core/arrays/categorical.py

+11-44
Original file line numberDiff line numberDiff line change
@@ -1906,49 +1906,6 @@ def take_nd(self, indexer, allow_fill=None, fill_value=None):
19061906

19071907
take = take_nd
19081908

1909-
def where(self, cond, other):
1910-
# n.b. this now preserves the type
1911-
codes = self._codes
1912-
object_msg = (
1913-
"Implicitly converting categorical to object-dtype ndarray. "
1914-
"The values `{}' are not present in this categorical's "
1915-
"categories. A future version of pandas will raise a ValueError "
1916-
"when 'other' contains different categories.\n\n"
1917-
"To preserve the current behavior, add the new categories to "
1918-
"the categorical before calling 'where', or convert the "
1919-
"categorical to a different dtype."
1920-
)
1921-
1922-
if is_scalar(other) and isna(other):
1923-
other = -1
1924-
elif is_scalar(other):
1925-
item = self.categories.get_indexer([other]).item()
1926-
1927-
if item == -1:
1928-
# note: when removing this, also remove CategoricalBlock.where
1929-
warn(object_msg.format(other), FutureWarning, stacklevel=2)
1930-
return np.where(cond, self, other)
1931-
1932-
other = item
1933-
1934-
elif is_categorical_dtype(other):
1935-
if not is_dtype_equal(self, other):
1936-
extra = list(other.categories.difference(self.categories))
1937-
warn(object_msg.format(compat.reprlib.repr(extra)),
1938-
FutureWarning,
1939-
stacklevel=2)
1940-
return np.where(cond, self, other)
1941-
other = _get_codes_for_values(other, self.categories)
1942-
# get the codes from other that match our categories
1943-
pass
1944-
else:
1945-
other = np.where(isna(other), -1, other)
1946-
1947-
new_codes = np.where(cond, codes, other)
1948-
return type(self).from_codes(new_codes,
1949-
categories=self.categories,
1950-
ordered=self.ordered)
1951-
19521909
def _slice(self, slicer):
19531910
"""
19541911
Return a slice of myself.
@@ -2121,11 +2078,21 @@ def __setitem__(self, key, value):
21212078
`Categorical` does not have the same categories
21222079
"""
21232080

2081+
if isinstance(value, (ABCIndexClass, ABCSeries)):
2082+
value = value.array
2083+
21242084
# require identical categories set
21252085
if isinstance(value, Categorical):
2126-
if not value.categories.equals(self.categories):
2086+
if not is_dtype_equal(self, value):
21272087
raise ValueError("Cannot set a Categorical with another, "
21282088
"without identical categories")
2089+
if not self.categories.equals(value.categories):
2090+
new_codes = _recode_for_categories(
2091+
value.codes, value.categories, self.categories
2092+
)
2093+
value = Categorical.from_codes(new_codes,
2094+
categories=self.categories,
2095+
ordered=self.ordered)
21292096

21302097
rvalue = value if is_list_like(value) else [value]
21312098

pandas/core/arrays/interval.py

-11
Original file line numberDiff line numberDiff line change
@@ -777,17 +777,6 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None,
777777

778778
return self._shallow_copy(left_take, right_take)
779779

780-
def where(self, cond, other):
781-
if is_scalar(other) and isna(other):
782-
lother = rother = other
783-
else:
784-
self._check_closed_matches(other, name='other')
785-
lother = other.left
786-
rother = other.right
787-
left = np.where(cond, self.left, lother)
788-
right = np.where(cond, self.right, rother)
789-
return self._shallow_copy(left, right)
790-
791780
def value_counts(self, dropna=True):
792781
"""
793782
Returns a Series containing counts of each interval.

pandas/core/arrays/period.py

-17
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66

7-
from pandas._libs import lib
87
from pandas._libs.tslibs import NaT, iNaT, period as libperiod
98
from pandas._libs.tslibs.fields import isleapyear_arr
109
from pandas._libs.tslibs.period import (
@@ -347,22 +346,6 @@ def to_timestamp(self, freq=None, how='start'):
347346
# --------------------------------------------------------------------
348347
# Array-like / EA-Interface Methods
349348

350-
def where(self, cond, other):
351-
# TODO(DatetimeArray): move to DatetimeLikeArrayMixin
352-
# n.b. _ndarray_values candidate.
353-
i8 = self.asi8
354-
if lib.is_scalar(other):
355-
if isna(other):
356-
other = iNaT
357-
elif isinstance(other, Period):
358-
self._check_compatible_with(other)
359-
other = other.ordinal
360-
elif isinstance(other, type(self)):
361-
self._check_compatible_with(other)
362-
other = other.asi8
363-
result = np.where(cond, i8, other)
364-
return type(self)._simple_new(result, dtype=self.dtype)
365-
366349
def _formatter(self, boxed=False):
367350
if boxed:
368351
return str

pandas/core/arrays/sparse.py

-19
Original file line numberDiff line numberDiff line change
@@ -704,11 +704,6 @@ def __array__(self, dtype=None, copy=True):
704704
out[self.sp_index.to_int_index().indices] = self.sp_values
705705
return out
706706

707-
def __setitem__(self, key, value):
708-
# I suppose we could allow setting of non-fill_value elements.
709-
msg = "SparseArray does not support item assignment via setitem"
710-
raise TypeError(msg)
711-
712707
@classmethod
713708
def _from_sequence(cls, scalars, dtype=None, copy=False):
714709
return cls(scalars, dtype=dtype)
@@ -1063,20 +1058,6 @@ def take(self, indices, allow_fill=False, fill_value=None):
10631058
return type(self)(result, fill_value=self.fill_value, kind=self.kind,
10641059
**kwargs)
10651060

1066-
def where(self, cond, other):
1067-
if is_scalar(other):
1068-
result_dtype = np.result_type(self.dtype.subtype, other)
1069-
elif isinstance(other, type(self)):
1070-
result_dtype = np.result_type(self.dtype.subtype,
1071-
other.dtype.subtype)
1072-
else:
1073-
result_dtype = np.result_type(self.dtype.subtype, other.dtype)
1074-
1075-
dtype = self.dtype.update_dtype(result_dtype)
1076-
# TODO: avoid converting to dense.
1077-
values = np.where(cond, self, other)
1078-
return type(self)(values, dtype=dtype)
1079-
10801061
def _take_with_fill(self, indices, fill_value=None):
10811062
if fill_value is None:
10821063
fill_value = self.dtype.na_value

pandas/core/internals/blocks.py

+76-7
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,7 @@ def shift(self, periods, axis=0):
19701970

19711971
def where(self, other, cond, align=True, errors='raise',
19721972
try_cast=False, axis=0, transpose=False):
1973+
# rough attempt to see if
19731974
if isinstance(other, (ABCIndexClass, ABCSeries)):
19741975
other = other.array
19751976

@@ -1989,7 +1990,33 @@ def where(self, other, cond, align=True, errors='raise',
19891990
# we want to replace that with the correct NA value
19901991
# for the type
19911992
other = self.dtype.na_value
1992-
result = self.values.where(cond, other)
1993+
1994+
if is_sparse(self.values):
1995+
# ugly workaround for ensure that the dtype is OK
1996+
# after we insert NaNs.
1997+
if is_sparse(other):
1998+
otype = other.dtype.subtype
1999+
else:
2000+
otype = other
2001+
dtype = self.dtype.update_dtype(
2002+
np.result_type(self.values.dtype.subtype, otype)
2003+
)
2004+
else:
2005+
dtype = self.dtype
2006+
2007+
if self._holder.__setitem__ is ExtensionArray.__setitem__:
2008+
# the array doesn't implement setitem, so convert to ndarray
2009+
result = self._holder._from_sequence(
2010+
np.where(cond, self.values, other),
2011+
dtype=dtype,
2012+
)
2013+
else:
2014+
result = self.values.copy()
2015+
icond = ~cond
2016+
if lib.is_scalar(other):
2017+
result[icond] = other
2018+
else:
2019+
result[icond] = other[icond]
19932020
return self.make_block_same_class(result, placement=self.mgr_locs)
19942021

19952022
@property
@@ -2673,13 +2700,55 @@ def concat_same_type(self, to_concat, placement=None):
26732700

26742701
def where(self, other, cond, align=True, errors='raise',
26752702
try_cast=False, axis=0, transpose=False):
2676-
result = super(CategoricalBlock, self).where(
2677-
other, cond, align, errors, try_cast, axis, transpose
2703+
# This can all be deleted in favor of ExtensionBlock.where once
2704+
# we enforce the deprecation.
2705+
object_msg = (
2706+
"Implicitly converting categorical to object-dtype ndarray. "
2707+
"The values `{}' are not present in this categorical's "
2708+
"categories. A future version of pandas will raise a ValueError "
2709+
"when 'other' contains different categories.\n\n"
2710+
"To preserve the current behavior, add the new categories to "
2711+
"the categorical before calling 'where', or convert the "
2712+
"categorical to a different dtype."
26782713
)
2679-
if result.values.dtype != self.values.dtype:
2680-
# For backwards compatability, we allow upcasting to object.
2681-
# This fallback will be removed in the future.
2682-
result = result.astype(object)
2714+
2715+
scalar_other = lib.is_scalar(other)
2716+
categorical_other = is_categorical_dtype(other)
2717+
if isinstance(other, ABCDataFrame):
2718+
# should be 1d
2719+
assert other.shape[1] == 1
2720+
other = other.iloc[:, 0]
2721+
2722+
if isinstance(other, (ABCSeries, ABCIndexClass)):
2723+
other = other._values
2724+
2725+
do_as_object = (
2726+
# Two categoricals with different dtype (ignoring order)
2727+
(categorical_other and not is_dtype_equal(self.values, other)) or
2728+
# a not-na scalar not present in our categories
2729+
(scalar_other and (other not in self.values.categories
2730+
and notna(other))) or
2731+
# an array not present in our categories
2732+
(not scalar_other and
2733+
(self.values.categories.get_indexer(
2734+
other[notna(other)]) < 0).any())
2735+
)
2736+
2737+
if do_as_object:
2738+
if scalar_other:
2739+
msg = object_msg.format(other)
2740+
else:
2741+
msg = compat.reprlib.repr(other)
2742+
2743+
warnings.warn(msg, FutureWarning, stacklevel=6)
2744+
result = self.astype(object).where(other, cond, align=align,
2745+
errors=errors,
2746+
try_cast=try_cast,
2747+
axis=axis, transpose=transpose)
2748+
else:
2749+
result = super(CategoricalBlock, self).where(
2750+
other, cond, align, errors, try_cast, axis, transpose
2751+
)
26832752
return result
26842753

26852754

0 commit comments

Comments
 (0)