Skip to content

Commit e989e31

Browse files
authored
DEPR: try_cast kwarg in mask, where (#38836)
1 parent 6b61b10 commit e989e31

File tree

6 files changed

+56
-30
lines changed

6 files changed

+56
-30
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ Deprecations
155155
- Deprecating allowing scalars passed to the :class:`Categorical` constructor (:issue:`38433`)
156156
- Deprecated allowing subclass-specific keyword arguments in the :class:`Index` constructor, use the specific subclass directly instead (:issue:`14093`,:issue:`21311`,:issue:`22315`,:issue:`26974`)
157157
- Deprecated ``astype`` of datetimelike (``timedelta64[ns]``, ``datetime64[ns]``, ``Datetime64TZDtype``, ``PeriodDtype``) to integer dtypes, use ``values.view(...)`` instead (:issue:`38544`)
158-
-
158+
- Deprecated keyword ``try_cast`` in :meth:`Series.where`, :meth:`Series.mask`, :meth:`DataFrame.where`, :meth:`DataFrame.mask`; cast results manually if desired (:issue:`38836`)
159159
-
160160

161161
.. ---------------------------------------------------------------------------

pandas/core/generic.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -8781,7 +8781,6 @@ def _where(
87818781
axis=None,
87828782
level=None,
87838783
errors="raise",
8784-
try_cast=False,
87858784
):
87868785
"""
87878786
Equivalent to public method `where`, except that `other` is not
@@ -8932,7 +8931,6 @@ def _where(
89328931
cond=cond,
89338932
align=align,
89348933
errors=errors,
8935-
try_cast=try_cast,
89368934
axis=block_axis,
89378935
)
89388936
result = self._constructor(new_data)
@@ -8954,7 +8952,7 @@ def where(
89548952
axis=None,
89558953
level=None,
89568954
errors="raise",
8957-
try_cast=False,
8955+
try_cast=lib.no_default,
89588956
):
89598957
"""
89608958
Replace values where the condition is {cond_rev}.
@@ -8986,9 +8984,12 @@ def where(
89868984
- 'raise' : allow exceptions to be raised.
89878985
- 'ignore' : suppress exceptions. On error return original object.
89888986
8989-
try_cast : bool, default False
8987+
try_cast : bool, default None
89908988
Try to cast the result back to the input type (if possible).
89918989
8990+
.. deprecated:: 1.3.0
8991+
Manually cast back if necessary.
8992+
89928993
Returns
89938994
-------
89948995
Same type as caller or None if ``inplace=True``.
@@ -9077,9 +9078,16 @@ def where(
90779078
4 True True
90789079
"""
90799080
other = com.apply_if_callable(other, self)
9080-
return self._where(
9081-
cond, other, inplace, axis, level, errors=errors, try_cast=try_cast
9082-
)
9081+
9082+
if try_cast is not lib.no_default:
9083+
warnings.warn(
9084+
"try_cast keyword is deprecated and will be removed in a "
9085+
"future version",
9086+
FutureWarning,
9087+
stacklevel=2,
9088+
)
9089+
9090+
return self._where(cond, other, inplace, axis, level, errors=errors)
90839091

90849092
@final
90859093
@doc(
@@ -9098,12 +9106,20 @@ def mask(
90989106
axis=None,
90999107
level=None,
91009108
errors="raise",
9101-
try_cast=False,
9109+
try_cast=lib.no_default,
91029110
):
91039111

91049112
inplace = validate_bool_kwarg(inplace, "inplace")
91059113
cond = com.apply_if_callable(cond, self)
91069114

9115+
if try_cast is not lib.no_default:
9116+
warnings.warn(
9117+
"try_cast keyword is deprecated and will be removed in a "
9118+
"future version",
9119+
FutureWarning,
9120+
stacklevel=2,
9121+
)
9122+
91079123
# see gh-21891
91089124
if not hasattr(cond, "__invert__"):
91099125
cond = np.array(cond)
@@ -9114,7 +9130,6 @@ def mask(
91149130
inplace=inplace,
91159131
axis=axis,
91169132
level=level,
9117-
try_cast=try_cast,
91189133
errors=errors,
91199134
)
91209135

pandas/core/internals/blocks.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -1290,9 +1290,7 @@ def _maybe_reshape_where_args(self, values, other, cond, axis):
12901290

12911291
return other, cond
12921292

1293-
def where(
1294-
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
1295-
) -> List["Block"]:
1293+
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:
12961294
"""
12971295
evaluate the block; return result block(s) from the result
12981296
@@ -1303,7 +1301,6 @@ def where(
13031301
errors : str, {'raise', 'ignore'}, default 'raise'
13041302
- ``raise`` : allow exceptions to be raised
13051303
- ``ignore`` : suppress exceptions. On error return original object
1306-
try_cast: bool, default False
13071304
axis : int, default 0
13081305
13091306
Returns
@@ -1342,9 +1339,7 @@ def where(
13421339
# we cannot coerce, return a compat dtype
13431340
# we are explicitly ignoring errors
13441341
block = self.coerce_to_target_dtype(other)
1345-
blocks = block.where(
1346-
orig_other, cond, errors=errors, try_cast=try_cast, axis=axis
1347-
)
1342+
blocks = block.where(orig_other, cond, errors=errors, axis=axis)
13481343
return self._maybe_downcast(blocks, "infer")
13491344

13501345
if not (
@@ -1825,9 +1820,7 @@ def shift(
18251820
)
18261821
]
18271822

1828-
def where(
1829-
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
1830-
) -> List["Block"]:
1823+
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:
18311824

18321825
cond = _extract_bool_array(cond)
18331826
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))
@@ -2075,9 +2068,7 @@ def to_native_types(self, na_rep="NaT", **kwargs):
20752068
result = arr._format_native_types(na_rep=na_rep, **kwargs)
20762069
return self.make_block(result)
20772070

2078-
def where(
2079-
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
2080-
) -> List["Block"]:
2071+
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:
20812072
# TODO(EA2D): reshape unnecessary with 2D EAs
20822073
arr = self.array_values().reshape(self.shape)
20832074

@@ -2086,9 +2077,7 @@ def where(
20862077
try:
20872078
res_values = arr.T.where(cond, other).T
20882079
except (ValueError, TypeError):
2089-
return super().where(
2090-
other, cond, errors=errors, try_cast=try_cast, axis=axis
2091-
)
2080+
return super().where(other, cond, errors=errors, axis=axis)
20922081

20932082
# TODO(EA2D): reshape not needed with 2D EAs
20942083
res_values = res_values.reshape(self.values.shape)

pandas/core/internals/managers.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,7 @@ def get_axe(block, qs, axes):
542542
def isna(self, func) -> "BlockManager":
543543
return self.apply("apply", func=func)
544544

545-
def where(
546-
self, other, cond, align: bool, errors: str, try_cast: bool, axis: int
547-
) -> "BlockManager":
545+
def where(self, other, cond, align: bool, errors: str, axis: int) -> "BlockManager":
548546
if align:
549547
align_keys = ["other", "cond"]
550548
else:
@@ -557,7 +555,6 @@ def where(
557555
other=other,
558556
cond=cond,
559557
errors=errors,
560-
try_cast=try_cast,
561558
axis=axis,
562559
)
563560

pandas/tests/frame/indexing/test_mask.py

+13
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,16 @@ def test_mask_dtype_conversion(self):
8383
expected = bools.astype(float).mask(mask)
8484
result = bools.mask(mask)
8585
tm.assert_frame_equal(result, expected)
86+
87+
88+
def test_mask_try_cast_deprecated(frame_or_series):
89+
90+
obj = DataFrame(np.random.randn(4, 3))
91+
if frame_or_series is not DataFrame:
92+
obj = obj[0]
93+
94+
mask = obj > 0
95+
96+
with tm.assert_produces_warning(FutureWarning):
97+
# try_cast keyword deprecated
98+
obj.mask(mask, -1, try_cast=True)

pandas/tests/frame/indexing/test_where.py

+12
Original file line numberDiff line numberDiff line change
@@ -672,3 +672,15 @@ def test_where_ea_other(self):
672672
expected["B"] = expected["B"].astype(object)
673673
result = df.where(mask, ser2, axis=1)
674674
tm.assert_frame_equal(result, expected)
675+
676+
677+
def test_where_try_cast_deprecated(frame_or_series):
678+
obj = DataFrame(np.random.randn(4, 3))
679+
if frame_or_series is not DataFrame:
680+
obj = obj[0]
681+
682+
mask = obj > 0
683+
684+
with tm.assert_produces_warning(FutureWarning):
685+
# try_cast keyword deprecated
686+
obj.where(mask, -1, try_cast=False)

0 commit comments

Comments
 (0)