From 0062bc08499cf5fe8a8cbaf2f92107ec434d60d6 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 30 Sep 2020 11:32:15 -0700 Subject: [PATCH 1/2] Standardize cast_str behavior in all datetimelike fill_value validators --- pandas/core/arrays/datetimelike.py | 8 ++++---- pandas/tests/arrays/test_datetimelike.py | 10 ++++++++++ pandas/tests/indexes/datetimelike.py | 21 +++++++++++++++++++++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 0f723546fb4c2..60b6b3caa695a 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -721,7 +721,7 @@ def _validate_fill_value(self, fill_value): f"Got '{str(fill_value)}'." ) try: - fill_value = self._validate_scalar(fill_value, msg) + fill_value = self._validate_scalar(fill_value, msg, cast_str=True) except TypeError as err: raise ValueError(msg) from err rv = self._unbox(fill_value) @@ -858,7 +858,7 @@ def _validate_setitem_value(self, value): def _validate_insert_value(self, value): msg = f"cannot insert {type(self).__name__} with incompatible label" - value = self._validate_scalar(value, msg, cast_str=False) + value = self._validate_scalar(value, msg, cast_str=True) self._check_compatible_with(value, setitem=True) # TODO: if we dont have compat, should we raise or astype(object)? @@ -870,9 +870,9 @@ def _validate_insert_value(self, value): def _validate_where_value(self, other): msg = f"Where requires matching dtype, not {type(other)}" if not is_list_like(other): - other = self._validate_scalar(other, msg) + other = self._validate_scalar(other, msg, cast_str=True) else: - other = self._validate_listlike(other, "where") + other = self._validate_listlike(other, "where", cast_str=True) return self._unbox(other, setitem=True) diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 3f5ab5baa7d69..91bcdf32603f4 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -160,6 +160,16 @@ def test_take_fill(self): result = arr.take([-1, 1], allow_fill=True, fill_value=pd.NaT) assert result[0] is pd.NaT + def test_take_fill_str(self, arr1d): + # Cast str fill_value matching other fill_value-taking methods + result = arr1d.take([-1, 1], allow_fill=True, fill_value=str(arr1d[-1])) + expected = arr1d[[-1, 1]] + tm.assert_equal(result, expected) + + msg = r"'fill_value' should be a <.*>\. Got 'foo'" + with pytest.raises(ValueError, match=msg): + arr1d.take([-1, 1], allow_fill=True, fill_value="foo") + def test_concat_same_type(self): data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9 diff --git a/pandas/tests/indexes/datetimelike.py b/pandas/tests/indexes/datetimelike.py index f667e5a610419..7189a7573241d 100644 --- a/pandas/tests/indexes/datetimelike.py +++ b/pandas/tests/indexes/datetimelike.py @@ -108,3 +108,24 @@ def test_getitem_preserves_freq(self): result = index[:] assert result.freq == index.freq + + def test_where_cast_str(self): + index = self.create_index() + + mask = np.ones(len(index), dtype=bool) + mask[-1] = False + + result = index.where(mask, str(index[0])) + expected = index.where(mask, index[0]) + tm.assert_index_equal(result, expected) + + result = index.where(mask, [str(index[0])]) + tm.assert_index_equal(result, expected) + + msg = "Where requires matching dtype, not foo" + with pytest.raises(TypeError, match=msg): + index.where(mask, "foo") + + msg = r"Where requires matching dtype, not \['foo'\]" + with pytest.raises(TypeError, match=msg): + index.where(mask, ["foo"]) From d7745b2e827a3987d6a7e52ff2da5a407e71de0d Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 2 Oct 2020 14:53:11 -0700 Subject: [PATCH 2/2] CLN: remove cast_str kwarg --- pandas/core/arrays/datetimelike.py | 30 ++++++++++++----------------- pandas/core/indexes/datetimelike.py | 2 +- pandas/core/indexes/timedeltas.py | 2 +- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 83cf8ac2a4090..ef402fce642b9 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -723,7 +723,7 @@ def _validate_fill_value(self, fill_value): f"Got '{str(fill_value)}'." ) try: - fill_value = self._validate_scalar(fill_value, msg, cast_str=True) + fill_value = self._validate_scalar(fill_value, msg) except TypeError as err: raise ValueError(msg) from err rv = self._unbox(fill_value) @@ -756,9 +756,7 @@ def _validate_shift_value(self, fill_value): return self._unbox(fill_value) - def _validate_scalar( - self, value, msg: Optional[str] = None, cast_str: bool = False - ): + def _validate_scalar(self, value, msg: Optional[str] = None): """ Validate that the input value can be cast to our scalar_type. @@ -769,14 +767,12 @@ def _validate_scalar( Message to raise in TypeError on invalid input. If not provided, `value` is cast to a str and used as the message. - cast_str : bool, default False - Whether to try to parse string input to scalar_type. Returns ------- self._scalar_type or NaT """ - if cast_str and isinstance(value, str): + if isinstance(value, str): # NB: Careful about tzawareness try: value = self._scalar_from_string(value) @@ -798,9 +794,7 @@ def _validate_scalar( return value - def _validate_listlike( - self, value, opname: str, cast_str: bool = False, allow_object: bool = False - ): + def _validate_listlike(self, value, opname: str, allow_object: bool = False): if isinstance(value, type(self)): return value @@ -809,7 +803,7 @@ def _validate_listlike( value = array(value) value = extract_array(value, extract_numpy=True) - if cast_str and is_dtype_equal(value.dtype, "string"): + if is_dtype_equal(value.dtype, "string"): # We got a StringArray try: # TODO: Could use from_sequence_of_strings if implemented @@ -839,9 +833,9 @@ def _validate_listlike( def _validate_searchsorted_value(self, value): msg = "searchsorted requires compatible dtype or scalar" if not is_list_like(value): - value = self._validate_scalar(value, msg, cast_str=True) + value = self._validate_scalar(value, msg) else: - value = self._validate_listlike(value, "searchsorted", cast_str=True) + value = self._validate_listlike(value, "searchsorted") rv = self._unbox(value) return self._rebox_native(rv) @@ -852,15 +846,15 @@ def _validate_setitem_value(self, value): f"or array of those. Got '{type(value).__name__}' instead." ) if is_list_like(value): - value = self._validate_listlike(value, "setitem", cast_str=True) + value = self._validate_listlike(value, "setitem") else: - value = self._validate_scalar(value, msg, cast_str=True) + value = self._validate_scalar(value, msg) return self._unbox(value, setitem=True) def _validate_insert_value(self, value): msg = f"cannot insert {type(self).__name__} with incompatible label" - value = self._validate_scalar(value, msg, cast_str=True) + value = self._validate_scalar(value, msg) self._check_compatible_with(value, setitem=True) # TODO: if we dont have compat, should we raise or astype(object)? @@ -872,9 +866,9 @@ def _validate_insert_value(self, value): def _validate_where_value(self, other): msg = f"Where requires matching dtype, not {type(other)}" if not is_list_like(other): - other = self._validate_scalar(other, msg, cast_str=True) + other = self._validate_scalar(other, msg) else: - other = self._validate_listlike(other, "where", cast_str=True) + other = self._validate_listlike(other, "where") return self._unbox(other, setitem=True) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index d2162d987ccd6..5128c644e6bcb 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -646,7 +646,7 @@ def _wrap_joined_index(self, joined: np.ndarray, other): def _convert_arr_indexer(self, keyarr): try: return self._data._validate_listlike( - keyarr, "convert_arr_indexer", cast_str=True, allow_object=True + keyarr, "convert_arr_indexer", allow_object=True ) except (ValueError, TypeError): return com.asarray_tuplesafe(keyarr) diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index 854c4e33eca01..7e635e55288e5 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -217,7 +217,7 @@ def get_loc(self, key, method=None, tolerance=None): raise InvalidIndexError(key) try: - key = self._data._validate_scalar(key, cast_str=True) + key = self._data._validate_scalar(key) except TypeError as err: raise KeyError(key) from err