Skip to content

Commit 95c2a7d

Browse files
authored
TST: Fix shares_memory for arrow string dtype (#55823)
* TST: Fix shares_memory for arrow string dtype
* TST: Fix shares_memory for arrow string dtype
* TST: Fix shares_memory for arrow string dtype
* Fix mypy
1 parent 6493d2a commit 95c2a7d

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

pandas/_testing/__init__.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
is_float_dtype,
3131
is_sequence,
3232
is_signed_integer_dtype,
33+
is_string_dtype,
3334
is_unsigned_integer_dtype,
3435
pandas_dtype,
3536
)
@@ -1055,10 +1056,18 @@ def shares_memory(left, right) -> bool:
10551056
if isinstance(left, pd.core.arrays.IntervalArray):
10561057
return shares_memory(left._left, right) or shares_memory(left._right, right)
10571058

1058-
if isinstance(left, ExtensionArray) and left.dtype == "string[pyarrow]":
1059+
if (
1060+
isinstance(left, ExtensionArray)
1061+
and is_string_dtype(left.dtype)
1062+
and left.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined] # noqa: E501
1063+
):
10591064
# https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
10601065
left = cast("ArrowExtensionArray", left)
1061-
if isinstance(right, ExtensionArray) and right.dtype == "string[pyarrow]":
1066+
if (
1067+
isinstance(right, ExtensionArray)
1068+
and is_string_dtype(right.dtype)
1069+
and right.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined] # noqa: E501
1070+
):
10621071
right = cast("ArrowExtensionArray", right)
10631072
left_pa_data = left._pa_array
10641073
right_pa_data = right._pa_array

pandas/tests/util/test_shares_memory.py

+17
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pandas.util._test_decorators as td
2+
13
import pandas as pd
24
import pandas._testing as tm
35

@@ -11,3 +13,18 @@ def test_shares_memory_interval():
1113
assert tm.shares_memory(obj, obj[:2])
1214

1315
assert not tm.shares_memory(obj, obj._data.copy())
16+
17+
18+
@td.skip_if_no("pyarrow")
19+
def test_shares_memory_string():
20+
# GH#55823
21+
import pyarrow as pa
22+
23+
obj = pd.array(["a", "b"], dtype="string[pyarrow]")
24+
assert tm.shares_memory(obj, obj)
25+
26+
obj = pd.array(["a", "b"], dtype="string[pyarrow_numpy]")
27+
assert tm.shares_memory(obj, obj)
28+
29+
obj = pd.array(["a", "b"], dtype=pd.ArrowDtype(pa.string()))
30+
assert tm.shares_memory(obj, obj)

0 commit comments

Comments
 (0)