diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 83a9c0ba61c2d..ef402fce642b9 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -756,9 +756,7 @@ def _validate_shift_value(self, fill_value): return self._unbox(fill_value) - def _validate_scalar( - self, value, msg: Optional[str] = None, cast_str: bool = False - ): + def _validate_scalar(self, value, msg: Optional[str] = None): """ Validate that the input value can be cast to our scalar_type. @@ -769,14 +767,12 @@ def _validate_scalar( Message to raise in TypeError on invalid input. If not provided, `value` is cast to a str and used as the message. - cast_str : bool, default False - Whether to try to parse string input to scalar_type. Returns ------- self._scalar_type or NaT """ - if cast_str and isinstance(value, str): + if isinstance(value, str): # NB: Careful about tzawareness try: value = self._scalar_from_string(value) @@ -798,9 +794,7 @@ def _validate_scalar( return value - def _validate_listlike( - self, value, opname: str, cast_str: bool = False, allow_object: bool = False - ): + def _validate_listlike(self, value, opname: str, allow_object: bool = False): if isinstance(value, type(self)): return value @@ -809,7 +803,7 @@ def _validate_listlike( value = array(value) value = extract_array(value, extract_numpy=True) - if cast_str and is_dtype_equal(value.dtype, "string"): + if is_dtype_equal(value.dtype, "string"): # We got a StringArray try: # TODO: Could use from_sequence_of_strings if implemented @@ -839,9 +833,9 @@ def _validate_listlike( def _validate_searchsorted_value(self, value): msg = "searchsorted requires compatible dtype or scalar" if not is_list_like(value): - value = self._validate_scalar(value, msg, cast_str=True) + value = self._validate_scalar(value, msg) else: - value = self._validate_listlike(value, "searchsorted", cast_str=True) + value = self._validate_listlike(value, "searchsorted") rv = self._unbox(value) return self._rebox_native(rv) @@ -852,15 +846,15 @@ def _validate_setitem_value(self, value): f"or array of those. Got '{type(value).__name__}' instead." ) if is_list_like(value): - value = self._validate_listlike(value, "setitem", cast_str=True) + value = self._validate_listlike(value, "setitem") else: - value = self._validate_scalar(value, msg, cast_str=True) + value = self._validate_scalar(value, msg) return self._unbox(value, setitem=True) def _validate_insert_value(self, value): msg = f"cannot insert {type(self).__name__} with incompatible label" - value = self._validate_scalar(value, msg, cast_str=False) + value = self._validate_scalar(value, msg) self._check_compatible_with(value, setitem=True) # TODO: if we dont have compat, should we raise or astype(object)? diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index d2162d987ccd6..5128c644e6bcb 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -646,7 +646,7 @@ def _wrap_joined_index(self, joined: np.ndarray, other): def _convert_arr_indexer(self, keyarr): try: return self._data._validate_listlike( - keyarr, "convert_arr_indexer", cast_str=True, allow_object=True + keyarr, "convert_arr_indexer", allow_object=True ) except (ValueError, TypeError): return com.asarray_tuplesafe(keyarr) diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index 854c4e33eca01..7e635e55288e5 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -217,7 +217,7 @@ def get_loc(self, key, method=None, tolerance=None): raise InvalidIndexError(key) try: - key = self._data._validate_scalar(key, cast_str=True) + key = self._data._validate_scalar(key) except TypeError as err: raise KeyError(key) from err diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 3f5ab5baa7d69..91bcdf32603f4 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -160,6 +160,16 @@ def test_take_fill(self): result = arr.take([-1, 1], allow_fill=True, fill_value=pd.NaT) assert result[0] is pd.NaT + def test_take_fill_str(self, arr1d): + # Cast str fill_value matching other fill_value-taking methods + result = arr1d.take([-1, 1], allow_fill=True, fill_value=str(arr1d[-1])) + expected = arr1d[[-1, 1]] + tm.assert_equal(result, expected) + + msg = r"'fill_value' should be a <.*>\. Got 'foo'" + with pytest.raises(ValueError, match=msg): + arr1d.take([-1, 1], allow_fill=True, fill_value="foo") + def test_concat_same_type(self): data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9 diff --git a/pandas/tests/indexes/datetimelike.py b/pandas/tests/indexes/datetimelike.py index 71ae1d6bda9c7..df857cce05bbb 100644 --- a/pandas/tests/indexes/datetimelike.py +++ b/pandas/tests/indexes/datetimelike.py @@ -115,3 +115,24 @@ def test_not_equals_numeric(self): assert not index.equals(pd.Index(index.asi8)) assert not index.equals(pd.Index(index.asi8.astype("u8"))) assert not index.equals(pd.Index(index.asi8).astype("f8")) + + def test_where_cast_str(self): + index = self.create_index() + + mask = np.ones(len(index), dtype=bool) + mask[-1] = False + + result = index.where(mask, str(index[0])) + expected = index.where(mask, index[0]) + tm.assert_index_equal(result, expected) + + result = index.where(mask, [str(index[0])]) + tm.assert_index_equal(result, expected) + + msg = "Where requires matching dtype, not foo" + with pytest.raises(TypeError, match=msg): + index.where(mask, "foo") + + msg = r"Where requires matching dtype, not \['foo'\]" + with pytest.raises(TypeError, match=msg): + index.where(mask, ["foo"])