|
89 | 89 | import pandas as pd
|
90 | 90 | from pandas.core import arraylike, indexing, missing, nanops
|
91 | 91 | import pandas.core.algorithms as algos
|
| 92 | +from pandas.core.arrays import ExtensionArray |
92 | 93 | from pandas.core.base import PandasObject, SelectionMixin
|
93 | 94 | import pandas.core.common as com
|
94 |
| -from pandas.core.construction import create_series_with_explicit_dtype |
| 95 | +from pandas.core.construction import create_series_with_explicit_dtype, extract_array |
95 | 96 | from pandas.core.flags import Flags
|
96 | 97 | from pandas.core.indexes import base as ibase
|
97 | 98 | from pandas.core.indexes.api import (
|
@@ -8788,6 +8789,9 @@ def _where(
|
8788 | 8789 | """
|
8789 | 8790 | inplace = validate_bool_kwarg(inplace, "inplace")
|
8790 | 8791 |
|
| 8792 | + if axis is not None: |
| 8793 | + axis = self._get_axis_number(axis) |
| 8794 | + |
8791 | 8795 | # align the cond to same shape as myself
|
8792 | 8796 | cond = com.apply_if_callable(cond, self)
|
8793 | 8797 | if isinstance(cond, NDFrame):
|
@@ -8827,22 +8831,39 @@ def _where(
|
8827 | 8831 | if other.ndim <= self.ndim:
|
8828 | 8832 |
|
8829 | 8833 | _, other = self.align(
|
8830 |
| - other, join="left", axis=axis, level=level, fill_value=np.nan |
| 8834 | + other, |
| 8835 | + join="left", |
| 8836 | + axis=axis, |
| 8837 | + level=level, |
| 8838 | + fill_value=np.nan, |
| 8839 | + copy=False, |
8831 | 8840 | )
|
8832 | 8841 |
|
8833 | 8842 | # if we are NOT aligned, raise as we cannot where index
|
8834 |
| - if axis is None and not all( |
8835 |
| - other._get_axis(i).equals(ax) for i, ax in enumerate(self.axes) |
8836 |
| - ): |
| 8843 | + if axis is None and not other._indexed_same(self): |
8837 | 8844 | raise InvalidIndexError
|
8838 | 8845 |
|
| 8846 | + elif other.ndim < self.ndim: |
| 8847 | + # TODO(EA2D): avoid object-dtype cast in EA case GH#38729 |
| 8848 | + other = other._values |
| 8849 | + if axis == 0: |
| 8850 | + other = np.reshape(other, (-1, 1)) |
| 8851 | + elif axis == 1: |
| 8852 | + other = np.reshape(other, (1, -1)) |
| 8853 | + |
| 8854 | + other = np.broadcast_to(other, self.shape) |
| 8855 | + |
8839 | 8856 | # slice me out of the other
|
8840 | 8857 | else:
|
8841 | 8858 | raise NotImplementedError(
|
8842 | 8859 | "cannot align with a higher dimensional NDFrame"
|
8843 | 8860 | )
|
8844 | 8861 |
|
8845 |
| - if isinstance(other, np.ndarray): |
| 8862 | + if not isinstance(other, (MultiIndex, NDFrame)): |
| 8863 | + # mainly just catching Index here |
| 8864 | + other = extract_array(other, extract_numpy=True) |
| 8865 | + |
| 8866 | + if isinstance(other, (np.ndarray, ExtensionArray)): |
8846 | 8867 |
|
8847 | 8868 | if other.shape != self.shape:
|
8848 | 8869 |
|
@@ -8887,10 +8908,10 @@ def _where(
|
8887 | 8908 | else:
|
8888 | 8909 | align = self._get_axis_number(axis) == 1
|
8889 | 8910 |
|
8890 |
| - if align and isinstance(other, NDFrame): |
8891 |
| - other = other.reindex(self._info_axis, axis=self._info_axis_number) |
8892 | 8911 | if isinstance(cond, NDFrame):
|
8893 |
| - cond = cond.reindex(self._info_axis, axis=self._info_axis_number) |
| 8912 | + cond = cond.reindex( |
| 8913 | + self._info_axis, axis=self._info_axis_number, copy=False |
| 8914 | + ) |
8894 | 8915 |
|
8895 | 8916 | block_axis = self._get_block_manager_axis(axis)
|
8896 | 8917 |
|
|
0 commit comments