Skip to content

Commit d208c36

Browse files
jbrockmendel authored and Kevin D Smith committed
Standardize cast_str behavior in all datetimelike fill_value validators (pandas-dev#36746)
* Standardize cast_str behavior in all datetimelike fill_value validators
* CLN: remove cast_str kwarg
1 parent 5e703ce commit d208c36

File tree

5 files changed

+42
-17
lines changed

5 files changed

+42
-17
lines changed

pandas/core/arrays/datetimelike.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -765,9 +765,7 @@ def _validate_shift_value(self, fill_value):
765765

766766
return self._unbox(fill_value)
767767

768-
def _validate_scalar(
769-
self, value, msg: Optional[str] = None, cast_str: bool = False
770-
):
768+
def _validate_scalar(self, value, msg: Optional[str] = None):
771769
"""
772770
Validate that the input value can be cast to our scalar_type.
773771
@@ -778,14 +776,12 @@ def _validate_scalar(
778776
Message to raise in TypeError on invalid input.
779777
If not provided, `value` is cast to a str and used
780778
as the message.
781-
cast_str : bool, default False
782-
Whether to try to parse string input to scalar_type.
783779
784780
Returns
785781
-------
786782
self._scalar_type or NaT
787783
"""
788-
if cast_str and isinstance(value, str):
784+
if isinstance(value, str):
789785
# NB: Careful about tzawareness
790786
try:
791787
value = self._scalar_from_string(value)
@@ -807,9 +803,7 @@ def _validate_scalar(
807803

808804
return value
809805

810-
def _validate_listlike(
811-
self, value, opname: str, cast_str: bool = False, allow_object: bool = False
812-
):
806+
def _validate_listlike(self, value, opname: str, allow_object: bool = False):
813807
if isinstance(value, type(self)):
814808
return value
815809

@@ -818,7 +812,7 @@ def _validate_listlike(
818812
value = array(value)
819813
value = extract_array(value, extract_numpy=True)
820814

821-
if cast_str and is_dtype_equal(value.dtype, "string"):
815+
if is_dtype_equal(value.dtype, "string"):
822816
# We got a StringArray
823817
try:
824818
# TODO: Could use from_sequence_of_strings if implemented
@@ -848,9 +842,9 @@ def _validate_listlike(
848842
def _validate_searchsorted_value(self, value):
849843
msg = "searchsorted requires compatible dtype or scalar"
850844
if not is_list_like(value):
851-
value = self._validate_scalar(value, msg, cast_str=True)
845+
value = self._validate_scalar(value, msg)
852846
else:
853-
value = self._validate_listlike(value, "searchsorted", cast_str=True)
847+
value = self._validate_listlike(value, "searchsorted")
854848

855849
rv = self._unbox(value)
856850
return self._rebox_native(rv)
@@ -861,15 +855,15 @@ def _validate_setitem_value(self, value):
861855
f"or array of those. Got '{type(value).__name__}' instead."
862856
)
863857
if is_list_like(value):
864-
value = self._validate_listlike(value, "setitem", cast_str=True)
858+
value = self._validate_listlike(value, "setitem")
865859
else:
866-
value = self._validate_scalar(value, msg, cast_str=True)
860+
value = self._validate_scalar(value, msg)
867861

868862
return self._unbox(value, setitem=True)
869863

870864
def _validate_insert_value(self, value):
871865
msg = f"cannot insert {type(self).__name__} with incompatible label"
872-
value = self._validate_scalar(value, msg, cast_str=False)
866+
value = self._validate_scalar(value, msg)
873867

874868
self._check_compatible_with(value, setitem=True)
875869
# TODO: if we dont have compat, should we raise or astype(object)?

pandas/core/indexes/datetimelike.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def _wrap_joined_index(self, joined: np.ndarray, other):
646646
def _convert_arr_indexer(self, keyarr):
647647
try:
648648
return self._data._validate_listlike(
649-
keyarr, "convert_arr_indexer", cast_str=True, allow_object=True
649+
keyarr, "convert_arr_indexer", allow_object=True
650650
)
651651
except (ValueError, TypeError):
652652
return com.asarray_tuplesafe(keyarr)

pandas/core/indexes/timedeltas.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def get_loc(self, key, method=None, tolerance=None):
217217
raise InvalidIndexError(key)
218218

219219
try:
220-
key = self._data._validate_scalar(key, cast_str=True)
220+
key = self._data._validate_scalar(key)
221221
except TypeError as err:
222222
raise KeyError(key) from err
223223

pandas/tests/arrays/test_datetimelike.py

+10
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,16 @@ def test_take_fill(self):
160160
result = arr.take([-1, 1], allow_fill=True, fill_value=pd.NaT)
161161
assert result[0] is pd.NaT
162162

163+
def test_take_fill_str(self, arr1d):
164+
# Cast str fill_value matching other fill_value-taking methods
165+
result = arr1d.take([-1, 1], allow_fill=True, fill_value=str(arr1d[-1]))
166+
expected = arr1d[[-1, 1]]
167+
tm.assert_equal(result, expected)
168+
169+
msg = r"'fill_value' should be a <.*>\. Got 'foo'"
170+
with pytest.raises(ValueError, match=msg):
171+
arr1d.take([-1, 1], allow_fill=True, fill_value="foo")
172+
163173
def test_concat_same_type(self):
164174
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
165175

pandas/tests/indexes/datetimelike.py

+21
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,24 @@ def test_not_equals_numeric(self):
115115
assert not index.equals(pd.Index(index.asi8))
116116
assert not index.equals(pd.Index(index.asi8.astype("u8")))
117117
assert not index.equals(pd.Index(index.asi8).astype("f8"))
118+
119+
def test_where_cast_str(self):
120+
index = self.create_index()
121+
122+
mask = np.ones(len(index), dtype=bool)
123+
mask[-1] = False
124+
125+
result = index.where(mask, str(index[0]))
126+
expected = index.where(mask, index[0])
127+
tm.assert_index_equal(result, expected)
128+
129+
result = index.where(mask, [str(index[0])])
130+
tm.assert_index_equal(result, expected)
131+
132+
msg = "Where requires matching dtype, not foo"
133+
with pytest.raises(TypeError, match=msg):
134+
index.where(mask, "foo")
135+
136+
msg = r"Where requires matching dtype, not \['foo'\]"
137+
with pytest.raises(TypeError, match=msg):
138+
index.where(mask, ["foo"])

0 commit comments

Comments
 (0)