Skip to content

Commit 7bdf69f

Browse files
TST: add CoW tests for xs() and get() (#51292)
1 parent 94f9412 commit 7bdf69f

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

pandas/tests/copy_view/test_methods.py

+90
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22
import pytest
33

4+
from pandas.errors import SettingWithCopyWarning
5+
6+
import pandas as pd
47
from pandas import (
58
DataFrame,
69
Index,
@@ -1308,3 +1311,90 @@ def test_isetitem(using_copy_on_write):
13081311
assert np.shares_memory(get_array(df, "c"), get_array(df2, "c"))
13091312
else:
13101313
assert not np.shares_memory(get_array(df, "c"), get_array(df2, "c"))
1314+
1315+
1316+
@pytest.mark.parametrize("key", ["a", ["a"]])
1317+
def test_get(using_copy_on_write, key):
1318+
df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
1319+
df_orig = df.copy()
1320+
1321+
result = df.get(key)
1322+
1323+
if using_copy_on_write:
1324+
assert np.shares_memory(get_array(result, "a"), get_array(df, "a"))
1325+
result.iloc[0] = 0
1326+
assert not np.shares_memory(get_array(result, "a"), get_array(df, "a"))
1327+
tm.assert_frame_equal(df, df_orig)
1328+
else:
1329+
# for non-CoW it depends on whether we got a Series or DataFrame if it
1330+
# is a view or copy or triggers a warning or not
1331+
warn = SettingWithCopyWarning if isinstance(key, list) else None
1332+
with pd.option_context("chained_assignment", "warn"):
1333+
with tm.assert_produces_warning(warn):
1334+
result.iloc[0] = 0
1335+
1336+
if isinstance(key, list):
1337+
tm.assert_frame_equal(df, df_orig)
1338+
else:
1339+
assert df.iloc[0, 0] == 0
1340+
1341+
1342+
@pytest.mark.parametrize("axis, key", [(0, 0), (1, "a")])
1343+
@pytest.mark.parametrize(
1344+
"dtype", ["int64", "float64"], ids=["single-block", "mixed-block"]
1345+
)
1346+
def test_xs(using_copy_on_write, using_array_manager, axis, key, dtype):
1347+
single_block = (dtype == "int64") and not using_array_manager
1348+
is_view = single_block or (using_array_manager and axis == 1)
1349+
df = DataFrame(
1350+
{"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)}
1351+
)
1352+
df_orig = df.copy()
1353+
1354+
result = df.xs(key, axis=axis)
1355+
1356+
if axis == 1 or single_block:
1357+
assert np.shares_memory(get_array(df, "a"), get_array(result))
1358+
elif using_copy_on_write:
1359+
assert result._mgr._has_no_reference(0)
1360+
1361+
if using_copy_on_write or is_view:
1362+
result.iloc[0] = 0
1363+
else:
1364+
with pd.option_context("chained_assignment", "warn"):
1365+
with tm.assert_produces_warning(SettingWithCopyWarning):
1366+
result.iloc[0] = 0
1367+
1368+
if using_copy_on_write or (not single_block and axis == 0):
1369+
tm.assert_frame_equal(df, df_orig)
1370+
else:
1371+
assert df.iloc[0, 0] == 0
1372+
1373+
1374+
@pytest.mark.parametrize("axis", [0, 1])
1375+
@pytest.mark.parametrize("key, level", [("l1", 0), (2, 1)])
1376+
def test_xs_multiindex(using_copy_on_write, using_array_manager, key, level, axis):
1377+
arr = np.arange(18).reshape(6, 3)
1378+
index = MultiIndex.from_product([["l1", "l2"], [1, 2, 3]], names=["lev1", "lev2"])
1379+
df = DataFrame(arr, index=index, columns=list("abc"))
1380+
if axis == 1:
1381+
df = df.transpose().copy()
1382+
df_orig = df.copy()
1383+
1384+
result = df.xs(key, level=level, axis=axis)
1385+
1386+
if level == 0:
1387+
assert np.shares_memory(
1388+
get_array(df, df.columns[0]), get_array(result, result.columns[0])
1389+
)
1390+
1391+
warn = (
1392+
SettingWithCopyWarning
1393+
if not using_copy_on_write and not using_array_manager
1394+
else None
1395+
)
1396+
with pd.option_context("chained_assignment", "warn"):
1397+
with tm.assert_produces_warning(warn):
1398+
result.iloc[0, 0] = 0
1399+
1400+
tm.assert_frame_equal(df, df_orig)

pandas/tests/copy_view/util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def get_array(obj, col=None):
1010
which triggers tracking references / CoW (and we might be testing that
1111
this is done by some other operation).
1212
"""
13-
if isinstance(obj, Series) and (obj is None or obj.name == col):
13+
if isinstance(obj, Series) and (col is None or obj.name == col):
1414
return obj._values
1515
assert col is not None
1616
icol = obj.columns.get_loc(col)

0 commit comments

Comments
 (0)