Skip to content

Commit 014d8ea

Browse files
authored
BUG: flex op with DataFrame, Series and ea vs ndarray (pandas-dev#34277)
1 parent 238e04f commit 014d8ea

File tree

3 files changed

+63
-27
lines changed

3 files changed

+63
-27
lines changed

pandas/core/ops/__init__.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
)
5555

5656
if TYPE_CHECKING:
57-
from pandas import DataFrame # noqa:F401
57+
from pandas import DataFrame, Series # noqa:F401
5858

5959
# -----------------------------------------------------------------------------
6060
# constants
@@ -459,19 +459,7 @@ def _combine_series_frame(left, right, func, axis: int):
459459
# We assume that self.align(other, ...) has already been called
460460

461461
rvalues = right._values
462-
if isinstance(rvalues, np.ndarray):
463-
# TODO(EA2D): no need to special-case with 2D EAs
464-
# We can operate block-wise
465-
if axis == 0:
466-
rvalues = rvalues.reshape(-1, 1)
467-
else:
468-
rvalues = rvalues.reshape(1, -1)
469-
470-
rvalues = np.broadcast_to(rvalues, left.shape)
471-
472-
array_op = get_array_op(func)
473-
bm = left._mgr.apply(array_op, right=rvalues.T, align_keys=["right"])
474-
return type(left)(bm)
462+
assert not isinstance(rvalues, np.ndarray) # handled by align_series_as_frame
475463

476464
if axis == 0:
477465
new_data = dispatch_to_series(left, right, func)
@@ -567,6 +555,7 @@ def to_series(right):
567555
left, right = left.align(
568556
right, join="outer", axis=axis, level=level, copy=False
569557
)
558+
right = _maybe_align_series_as_frame(left, right, axis)
570559

571560
return left, right
572561

@@ -627,6 +616,25 @@ def _frame_arith_method_with_reindex(
627616
return result.reindex(join_columns, axis=1)
628617

629618

619+
def _maybe_align_series_as_frame(frame: "DataFrame", series: "Series", axis: int):
620+
"""
621+
If the Series operand is not EA-dtype, we can broadcast to 2D and operate
622+
blockwise.
623+
"""
624+
rvalues = series._values
625+
if not isinstance(rvalues, np.ndarray):
626+
# TODO(EA2D): no need to special-case with 2D EAs
627+
return series
628+
629+
if axis == 0:
630+
rvalues = rvalues.reshape(-1, 1)
631+
else:
632+
rvalues = rvalues.reshape(1, -1)
633+
634+
rvalues = np.broadcast_to(rvalues, frame.shape)
635+
return type(frame)(rvalues, index=frame.index, columns=frame.columns)
636+
637+
630638
def _arith_method_FRAME(cls: Type["DataFrame"], op, special: bool):
631639
# This is the only function where `special` can be either True or False
632640
op_name = _get_op_name(op, special)
@@ -648,6 +656,11 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
648656
):
649657
return _frame_arith_method_with_reindex(self, other, op)
650658

659+
if isinstance(other, ABCSeries) and fill_value is not None:
660+
# TODO: We could allow this in cases where we end up going
661+
# through the DataFrame path
662+
raise NotImplementedError(f"fill_value {fill_value} not supported.")
663+
651664
# TODO: why are we passing flex=True instead of flex=not special?
652665
# 15 tests fail if we pass flex=not special instead
653666
self, other = _align_method_FRAME(self, other, axis, flex=True, level=level)
@@ -657,9 +670,6 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
657670
new_data = self._combine_frame(other, na_op, fill_value)
658671

659672
elif isinstance(other, ABCSeries):
660-
if fill_value is not None:
661-
raise NotImplementedError(f"fill_value {fill_value} not supported.")
662-
663673
axis = self._get_axis_number(axis) if axis is not None else 1
664674
new_data = _combine_series_frame(self, other, op, axis=axis)
665675
else:

pandas/tests/arithmetic/test_timedelta64.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -1470,8 +1470,6 @@ def test_td64arr_add_sub_object_array(self, box_with_array):
14701470
[pd.Timedelta(days=2), pd.Timedelta(days=4), pd.Timestamp("2000-01-07")]
14711471
)
14721472
expected = tm.box_expected(expected, box_with_array)
1473-
if box_with_array is pd.DataFrame:
1474-
expected = expected.astype(object)
14751473
tm.assert_equal(result, expected)
14761474

14771475
msg = "unsupported operand type|cannot subtract a datelike"
@@ -1486,8 +1484,6 @@ def test_td64arr_add_sub_object_array(self, box_with_array):
14861484
[pd.Timedelta(0), pd.Timedelta(0), pd.Timestamp("2000-01-01")]
14871485
)
14881486
expected = tm.box_expected(expected, box_with_array)
1489-
if box_with_array is pd.DataFrame:
1490-
expected = expected.astype(object)
14911487
tm.assert_equal(result, expected)
14921488

14931489

@@ -2012,7 +2008,7 @@ def test_td64arr_div_numeric_array(self, box_with_array, vector, any_real_dtype)
20122008
tm.assert_equal(result, expected)
20132009

20142010
pattern = (
2015-
"true_divide cannot use operands|"
2011+
"true_divide'? cannot use operands|"
20162012
"cannot perform __div__|"
20172013
"cannot perform __truediv__|"
20182014
"unsupported operand|"

pandas/tests/frame/test_arithmetic.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,20 @@ def test_df_flex_cmp_constant_return_types_empty(self, opname):
339339
result = getattr(empty, opname)(const).dtypes.value_counts()
340340
tm.assert_series_equal(result, pd.Series([2], index=[np.dtype(bool)]))
341341

342+
def test_df_flex_cmp_ea_dtype_with_ndarray_series(self):
343+
ii = pd.IntervalIndex.from_breaks([1, 2, 3])
344+
df = pd.DataFrame({"A": ii, "B": ii})
345+
346+
ser = pd.Series([0, 0])
347+
res = df.eq(ser, axis=0)
348+
349+
expected = pd.DataFrame({"A": [False, False], "B": [False, False]})
350+
tm.assert_frame_equal(res, expected)
351+
352+
ser2 = pd.Series([1, 2], index=["A", "B"])
353+
res2 = df.eq(ser2, axis=1)
354+
tm.assert_frame_equal(res2, expected)
355+
342356

343357
# -------------------------------------------------------------------
344358
# Arithmetic
@@ -1410,12 +1424,13 @@ def test_alignment_non_pandas(self):
14101424
range(1, 4),
14111425
]:
14121426

1413-
tm.assert_series_equal(
1414-
align(df, val, "index")[1], Series([1, 2, 3], index=df.index)
1415-
)
1416-
tm.assert_series_equal(
1417-
align(df, val, "columns")[1], Series([1, 2, 3], index=df.columns)
1427+
expected = DataFrame({"X": val, "Y": val, "Z": val}, index=df.index)
1428+
tm.assert_frame_equal(align(df, val, "index")[1], expected)
1429+
1430+
expected = DataFrame(
1431+
{"X": [1, 1, 1], "Y": [2, 2, 2], "Z": [3, 3, 3]}, index=df.index
14181432
)
1433+
tm.assert_frame_equal(align(df, val, "columns")[1], expected)
14191434

14201435
# length mismatch
14211436
msg = "Unable to coerce to Series, length must be 3: given 2"
@@ -1484,3 +1499,18 @@ def test_pow_nan_with_zero():
14841499

14851500
result = left["A"] ** right["A"]
14861501
tm.assert_series_equal(result, expected["A"])
1502+
1503+
1504+
def test_dataframe_series_extension_dtypes():
1505+
# https://github.com/pandas-dev/pandas/issues/34311
1506+
df = pd.DataFrame(np.random.randint(0, 100, (10, 3)), columns=["a", "b", "c"])
1507+
ser = pd.Series([1, 2, 3], index=["a", "b", "c"])
1508+
1509+
expected = df.to_numpy("int64") + ser.to_numpy("int64").reshape(-1, 3)
1510+
expected = pd.DataFrame(expected, columns=df.columns, dtype="Int64")
1511+
1512+
df_ea = df.astype("Int64")
1513+
result = df_ea + ser
1514+
tm.assert_frame_equal(result, expected)
1515+
result = df_ea + ser.astype("Int64")
1516+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)