diff --git a/pandas/core/computation/expressions.py b/pandas/core/computation/expressions.py index ada983e9e4fad..fdc299ccdfde8 100644 --- a/pandas/core/computation/expressions.py +++ b/pandas/core/computation/expressions.py @@ -12,8 +12,6 @@ from pandas._config import get_option -from pandas._libs.lib import values_from_object - from pandas.core.dtypes.generic import ABCDataFrame from pandas.core.computation.check import _NUMEXPR_INSTALLED @@ -123,26 +121,19 @@ def _evaluate_numexpr(op, op_str, a, b): def _where_standard(cond, a, b): - return np.where( - values_from_object(cond), values_from_object(a), values_from_object(b) - ) + # Caller is responsible for calling values_from_object if necessary + return np.where(cond, a, b) def _where_numexpr(cond, a, b): + # Caller is responsible for calling values_from_object if necessary result = None if _can_use_numexpr(None, "where", a, b, "where"): - cond_value = getattr(cond, "values", cond) - a_value = getattr(a, "values", a) - b_value = getattr(b, "values", b) result = ne.evaluate( "where(cond_value, a_value, b_value)", - local_dict={ - "cond_value": cond_value, - "a_value": a_value, - "b_value": b_value, - }, + local_dict={"cond_value": cond, "a_value": a, "b_value": b}, casting="safe", ) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index a93211edf162b..80369081f96d6 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1379,8 +1379,7 @@ def where( if not hasattr(cond, "shape"): raise ValueError("where must have a condition that is ndarray like") - # our where function - def func(cond, values, other): + def where_func(cond, values, other): if not ( (self.is_integer or self.is_bool) @@ -1391,8 +1390,11 @@ def func(cond, values, other): if not self._can_hold_element(other): raise TypeError if lib.is_scalar(other) and isinstance(values, np.ndarray): + # convert datetime to datetime64, timedelta to timedelta64 other = convert_scalar(values, other) + # By the time we get here, we should have all Series/Index + # args extracted to ndarray fastres = expressions.where(cond, values, other) return fastres @@ -1402,7 +1404,7 @@ def func(cond, values, other): # see if we can operate on the entire block, or need item-by-item # or if we are a single block (ndim == 1) try: - result = func(cond, values, other) + result = where_func(cond, values, other) except TypeError: # we cannot coerce, return a compat dtype