diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 6c5154b90d614..37013a5d1fb8f 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -134,7 +134,7 @@ def merge( right_index: bool = False, sort: bool = False, suffixes: Suffixes = ("_x", "_y"), - copy: bool = True, + copy: bool | None = None, indicator: str | bool = False, validate: str | None = None, ) -> DataFrame: @@ -744,7 +744,7 @@ def _reindex_and_concat( join_index: Index, left_indexer: npt.NDArray[np.intp] | None, right_indexer: npt.NDArray[np.intp] | None, - copy: bool, + copy: bool | None, ) -> DataFrame: """ reindex along index and concat along columns. @@ -793,7 +793,7 @@ def _reindex_and_concat( result = concat([left, right], axis=1, copy=copy) return result - def get_result(self, copy: bool = True) -> DataFrame: + def get_result(self, copy: bool | None = True) -> DataFrame: if self.indicator: self.left, self.right = self._indicator_pre_merge(self.left, self.right) @@ -1800,7 +1800,7 @@ def __init__( sort=True, # factorize sorts ) - def get_result(self, copy: bool = True) -> DataFrame: + def get_result(self, copy: bool | None = True) -> DataFrame: join_index, left_indexer, right_indexer = self._get_join_info() llabels, rlabels = _items_overlap_with_suffix( diff --git a/pandas/tests/copy_view/test_functions.py b/pandas/tests/copy_view/test_functions.py index 569cbc4ad7583..ffc80d2d11798 100644 --- a/pandas/tests/copy_view/test_functions.py +++ b/pandas/tests/copy_view/test_functions.py @@ -1,9 +1,12 @@ import numpy as np +import pandas.util._test_decorators as td + from pandas import ( DataFrame, Series, concat, + merge, ) import pandas._testing as tm from pandas.tests.copy_view.util import get_array @@ -177,3 +180,63 @@ def test_concat_mixed_series_frame(using_copy_on_write): if using_copy_on_write: assert not np.shares_memory(get_array(result, "a"), get_array(df, "a")) tm.assert_frame_equal(result, expected) + + +@td.skip_copy_on_write_not_yet_implemented # TODO(CoW) +def test_merge_on_key(using_copy_on_write): + df1 = DataFrame({"key": ["a", "b", "c"], "a": [1, 2, 3]}) + df2 = DataFrame({"key": ["a", "b", "c"], "b": [4, 5, 6]}) + df1_orig = df1.copy() + df2_orig = df2.copy() + + result = merge(df1, df2, on="key") + + if using_copy_on_write: + assert np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + assert not np.shares_memory(get_array(result, "key"), get_array(df1, "key")) + assert not np.shares_memory(get_array(result, "key"), get_array(df2, "key")) + else: + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert not np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + result.iloc[0, 1] = 0 + if using_copy_on_write: + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + result.iloc[0, 2] = 0 + if using_copy_on_write: + assert not np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + tm.assert_frame_equal(df1, df1_orig) + tm.assert_frame_equal(df2, df2_orig) + + +def test_merge_on_index(using_copy_on_write): + df1 = DataFrame({"a": [1, 2, 3]}) + df2 = DataFrame({"b": [4, 5, 6]}) + df1_orig = df1.copy() + df2_orig = df2.copy() + + result = merge(df1, df2, left_index=True, right_index=True) + + if using_copy_on_write: + assert np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + else: + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert not np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + result.iloc[0, 0] = 0 + if using_copy_on_write: + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + result.iloc[0, 1] = 0 + if using_copy_on_write: + assert not np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + tm.assert_frame_equal(df1, df1_orig) + tm.assert_frame_equal(df2, df2_orig) + + +# TODO(CoW) add merge tests where one of left/right isn't copied