Skip to content

Commit 28a0f66

Browse files
authored
BUG: Index.where casting ints to str (pandas-dev#37591)
1 parent 891b16f commit 28a0f66

File tree

10 files changed

+49
-44
lines changed

10 files changed

+49
-44
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -456,6 +456,7 @@ Indexing
456456
- Bug in :meth:`Series.loc.__getitem__` with a non-unique :class:`MultiIndex` and an empty-list indexer (:issue:`13691`)
457457
- Bug in indexing on a :class:`Series` or :class:`DataFrame` with a :class:`MultiIndex` with a level named "0" (:issue:`37194`)
458458
- Bug in :meth:`Series.__getitem__` when using an unsigned integer array as an indexer giving incorrect results or segfaulting instead of raising ``KeyError`` (:issue:`37218`)
459+
- Bug in :meth:`Index.where` incorrectly casting numeric values to strings (:issue:`37591`)
459460

460461
Missing
461462
^^^^^^^

pandas/core/indexes/base.py

+5 -13
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,6 @@
4040
ensure_int64,
4141
ensure_object,
4242
ensure_platform_int,
43-
is_bool,
4443
is_bool_dtype,
4544
is_categorical_dtype,
4645
is_datetime64_any_dtype,
@@ -4079,23 +4078,16 @@ def where(self, cond, other=None):
40794078
if other is None:
40804079
other = self._na_value
40814080

4082-
dtype = self.dtype
40834081
values = self.values
40844082

4085-
if is_bool(other) or is_bool_dtype(other):
4086-
4087-
# bools force casting
4088-
values = values.astype(object)
4089-
dtype = None
4083+
try:
4084+
self._validate_fill_value(other)
4085+
except (ValueError, TypeError):
4086+
return self.astype(object).where(cond, other)
40904087

40914088
values = np.where(cond, values, other)
40924089

4093-
if self._is_numeric_dtype and np.any(isna(values)):
4094-
# We can't coerce to the numeric dtype of "self" (unless
4095-
# it's float) if there are NaN values in our output.
4096-
dtype = None
4097-
4098-
return Index(values, dtype=dtype, name=self.name)
4090+
return Index(values, name=self.name)
40994091

41004092
# construction helpers
41014093
@final

pandas/core/indexes/datetimelike.py

+2 -9
Original file line number | Diff line number | Diff line change
@@ -482,16 +482,9 @@ def isin(self, values, level=None):
482482

483483
@Appender(Index.where.__doc__)
484484
def where(self, cond, other=None):
485-
values = self._data._ndarray
485+
other = self._data._validate_setitem_value(other)
486486

487-
try:
488-
other = self._data._validate_setitem_value(other)
489-
except (TypeError, ValueError) as err:
490-
# Includes tzawareness mismatch and IncompatibleFrequencyError
491-
oth = getattr(other, "dtype", other)
492-
raise TypeError(f"Where requires matching dtype, not {oth}") from err
493-
494-
result = np.where(cond, values, other)
487+
result = np.where(cond, self._data._ndarray, other)
495488
arr = self._data._from_backing_data(result)
496489
return type(self)._simple_new(arr, name=self.name)
497490

pandas/core/indexes/numeric.py

+2
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,8 @@ def _validate_fill_value(self, value):
121121
# force conversion to object
122122
# so we don't lose the bools
123123
raise TypeError
124+
if isinstance(value, str):
125+
raise TypeError
124126

125127
return value
126128

(Newly added file; its path heading was not captured in this extract.)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,13 @@
1+
import numpy as np
2+
3+
from pandas import Index
4+
import pandas._testing as tm
5+
6+
7+
class TestWhere:
8+
def test_where_intlike_str_doesnt_cast_ints(self):
9+
idx = Index(range(3))
10+
mask = np.array([True, False, True])
11+
res = idx.where(mask, "2")
12+
expected = Index([0, "2", 2])
13+
tm.assert_index_equal(res, expected)

pandas/tests/indexes/datetimelike.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -143,10 +143,9 @@ def test_where_cast_str(self):
143143
result = index.where(mask, [str(index[0])])
144144
tm.assert_index_equal(result, expected)
145145

146-
msg = "Where requires matching dtype, not foo"
146+
msg = "value should be a '.*', 'NaT', or array of those"
147147
with pytest.raises(TypeError, match=msg):
148148
index.where(mask, "foo")
149149

150-
msg = r"Where requires matching dtype, not \['foo'\]"
151150
with pytest.raises(TypeError, match=msg):
152151
index.where(mask, ["foo"])

pandas/tests/indexes/datetimes/test_indexing.py

+9 -7
Original file line number | Diff line number | Diff line change
@@ -177,24 +177,26 @@ def test_where_invalid_dtypes(self):
177177

178178
i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist())
179179

180-
with pytest.raises(TypeError, match="Where requires matching dtype"):
180+
msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
181+
msg2 = "Cannot compare tz-naive and tz-aware datetime-like objects"
182+
with pytest.raises(TypeError, match=msg2):
181183
# passing tz-naive ndarray to tzaware DTI
182184
dti.where(notna(i2), i2.values)
183185

184-
with pytest.raises(TypeError, match="Where requires matching dtype"):
186+
with pytest.raises(TypeError, match=msg2):
185187
# passing tz-aware DTI to tznaive DTI
186188
dti.tz_localize(None).where(notna(i2), i2)
187189

188-
with pytest.raises(TypeError, match="Where requires matching dtype"):
190+
with pytest.raises(TypeError, match=msg):
189191
dti.where(notna(i2), i2.tz_localize(None).to_period("D"))
190192

191-
with pytest.raises(TypeError, match="Where requires matching dtype"):
193+
with pytest.raises(TypeError, match=msg):
192194
dti.where(notna(i2), i2.asi8.view("timedelta64[ns]"))
193195

194-
with pytest.raises(TypeError, match="Where requires matching dtype"):
196+
with pytest.raises(TypeError, match=msg):
195197
dti.where(notna(i2), i2.asi8)
196198

197-
with pytest.raises(TypeError, match="Where requires matching dtype"):
199+
with pytest.raises(TypeError, match=msg):
198200
# non-matching scalar
199201
dti.where(notna(i2), pd.Timedelta(days=4))
200202

@@ -203,7 +205,7 @@ def test_where_mismatched_nat(self, tz_aware_fixture):
203205
dti = pd.date_range("2013-01-01", periods=3, tz=tz)
204206
cond = np.array([True, False, True])
205207

206-
msg = "Where requires matching dtype"
208+
msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
207209
with pytest.raises(TypeError, match=msg):
208210
# wrong-dtyped NaT
209211
dti.where(cond, np.timedelta64("NaT", "ns"))

pandas/tests/indexes/period/test_indexing.py

+6 -5
Original file line number | Diff line number | Diff line change
@@ -545,24 +545,25 @@ def test_where_invalid_dtypes(self):
545545

546546
i2 = PeriodIndex([NaT, NaT] + pi[2:].tolist(), freq="D")
547547

548-
with pytest.raises(TypeError, match="Where requires matching dtype"):
548+
msg = "value should be a 'Period', 'NaT', or array of those"
549+
with pytest.raises(TypeError, match=msg):
549550
pi.where(notna(i2), i2.asi8)
550551

551-
with pytest.raises(TypeError, match="Where requires matching dtype"):
552+
with pytest.raises(TypeError, match=msg):
552553
pi.where(notna(i2), i2.asi8.view("timedelta64[ns]"))
553554

554-
with pytest.raises(TypeError, match="Where requires matching dtype"):
555+
with pytest.raises(TypeError, match=msg):
555556
pi.where(notna(i2), i2.to_timestamp("S"))
556557

557-
with pytest.raises(TypeError, match="Where requires matching dtype"):
558+
with pytest.raises(TypeError, match=msg):
558559
# non-matching scalar
559560
pi.where(notna(i2), Timedelta(days=4))
560561

561562
def test_where_mismatched_nat(self):
562563
pi = period_range("20130101", periods=5, freq="D")
563564
cond = np.array([True, False, True, True, False])
564565

565-
msg = "Where requires matching dtype"
566+
msg = "value should be a 'Period', 'NaT', or array of those"
566567
with pytest.raises(TypeError, match=msg):
567568
# wrong-dtyped NaT
568569
pi.where(cond, np.timedelta64("NaT", "ns"))

pandas/tests/indexes/timedeltas/test_indexing.py

+6 -5
Original file line number | Diff line number | Diff line change
@@ -150,24 +150,25 @@ def test_where_invalid_dtypes(self):
150150

151151
i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())
152152

153-
with pytest.raises(TypeError, match="Where requires matching dtype"):
153+
msg = "value should be a 'Timedelta', 'NaT', or array of those"
154+
with pytest.raises(TypeError, match=msg):
154155
tdi.where(notna(i2), i2.asi8)
155156

156-
with pytest.raises(TypeError, match="Where requires matching dtype"):
157+
with pytest.raises(TypeError, match=msg):
157158
tdi.where(notna(i2), i2 + pd.Timestamp.now())
158159

159-
with pytest.raises(TypeError, match="Where requires matching dtype"):
160+
with pytest.raises(TypeError, match=msg):
160161
tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))
161162

162-
with pytest.raises(TypeError, match="Where requires matching dtype"):
163+
with pytest.raises(TypeError, match=msg):
163164
# non-matching scalar
164165
tdi.where(notna(i2), pd.Timestamp.now())
165166

166167
def test_where_mismatched_nat(self):
167168
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
168169
cond = np.array([True, False, False])
169170

170-
msg = "Where requires matching dtype"
171+
msg = "value should be a 'Timedelta', 'NaT', or array of those"
171172
with pytest.raises(TypeError, match=msg):
172173
# wrong-dtyped NaT
173174
tdi.where(cond, np.datetime64("NaT", "ns"))

pandas/tests/indexing/test_coercion.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -780,7 +780,7 @@ def test_where_index_timedelta64(self, value):
780780
result = tdi.where(cond, value)
781781
tm.assert_index_equal(result, expected)
782782

783-
msg = "Where requires matching dtype"
783+
msg = "value should be a 'Timedelta', 'NaT', or array of thos"
784784
with pytest.raises(TypeError, match=msg):
785785
# wrong-dtyped NaT
786786
tdi.where(cond, np.datetime64("NaT", "ns"))
@@ -804,11 +804,12 @@ def test_where_index_period(self):
804804
tm.assert_index_equal(result, expected)
805805

806806
# Passing a mismatched scalar
807-
msg = "Where requires matching dtype"
807+
msg = "value should be a 'Period', 'NaT', or array of those"
808808
with pytest.raises(TypeError, match=msg):
809809
pi.where(cond, pd.Timedelta(days=4))
810810

811-
with pytest.raises(TypeError, match=msg):
811+
msg = r"Input has different freq=D from PeriodArray\(freq=Q-DEC\)"
812+
with pytest.raises(ValueError, match=msg):
812813
pi.where(cond, pd.Period("2020-04-21", "D"))
813814

814815

0 commit comments

Comments
 (0)