Skip to content

Commit fa1e36c

Browse files
phofl authored and meeseeksmachine committed
Backport PR pandas-dev#54826: Infer large_string type as pyarrow_numpy strings
1 parent ed1f044 commit fa1e36c

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

pandas/core/arrays/string_arrow.py

+9
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,15 @@ def _str_rstrip(self, to_strip=None):
448448
class ArrowStringArrayNumpySemantics(ArrowStringArray):
449449
_storage = "pyarrow_numpy"
450450

451+
def __init__(self, values) -> None:
    """
    Initialize the array, down-casting pyarrow ``large_string`` input.

    Parameters
    ----------
    values
        Data handed on to the parent constructor. When it is a
        ``pa.Array`` or ``pa.ChunkedArray`` whose type is ``large_string``,
        it is first cast to ``pa.string()``.
    """
    # Runs before pyarrow (``pa``/``pc``) is touched below; presumably raises
    # when pyarrow is unavailable -- confirm against the helper's definition.
    _chk_pyarrow_available()

    is_arrow_data = isinstance(values, (pa.Array, pa.ChunkedArray))
    if is_arrow_data and pa.types.is_large_string(values.type):
        # Normalize large_string to string before delegating to the parent.
        values = pc.cast(values, pa.string())

    super().__init__(values)
459+
451460
@classmethod
452461
def _result_converter(cls, values, na=None):
453462
if not isna(na):

pandas/io/_util.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,7 @@ def _arrow_dtype_mapping() -> dict:
2828
def arrow_string_types_mapper() -> Callable:
    """
    Return a lookup from pyarrow string types to the pandas string dtype.

    Returns
    -------
    Callable
        The bound ``get`` method of a dict that maps both ``pa.string()``
        and ``pa.large_string()`` to ``StringDtype(storage="pyarrow_numpy")``;
        any other key yields ``None`` (standard ``dict.get`` behavior).
    """
    pa = import_optional_dependency("pyarrow")

    # Both the regular and the large string type resolve to the same storage.
    mapping = {
        arrow_type: pd.StringDtype(storage="pyarrow_numpy")
        for arrow_type in (pa.string(), pa.large_string())
    }
    return mapping.get

pandas/tests/arrays/string_/test_string_arrow.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
StringArray,
1313
StringDtype,
1414
)
15-
from pandas.core.arrays.string_arrow import ArrowStringArray
15+
from pandas.core.arrays.string_arrow import (
16+
ArrowStringArray,
17+
ArrowStringArrayNumpySemantics,
18+
)
1619

1720
skip_if_no_pyarrow = pytest.mark.skipif(
1821
pa_version_under7p0,
@@ -166,6 +169,9 @@ def test_pyarrow_not_installed_raises():
166169
with pytest.raises(ImportError, match=msg):
167170
ArrowStringArray([])
168171

172+
with pytest.raises(ImportError, match=msg):
173+
ArrowStringArrayNumpySemantics([])
174+
169175
with pytest.raises(ImportError, match=msg):
170176
ArrowStringArray._from_sequence(["a", None, "b"])
171177

pandas/tests/io/test_parquet.py

+19
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,25 @@ def test_roundtrip_decimal(self, tmp_path, pa):
11251125
expected = pd.DataFrame({"a": ["123"]}, dtype="string[python]")
11261126
tm.assert_frame_equal(result, expected)
11271127

1128+
def test_infer_string_large_string_type(self, tmp_path, pa):
    # GH#54798
    # A parquet column stored as pyarrow large_string must be read back as
    # "string[pyarrow_numpy]" when future.infer_string is enabled.
    import pyarrow as pa
    import pyarrow.parquet as pq

    path = tmp_path / "large_string.p"

    large_strings = pa.array([None, "b", "c"], pa.large_string())
    pq.write_table(pa.table({"a": large_strings}), path)

    with pd.option_context("future.infer_string", True):
        result = read_parquet(path)

    # Both the values and the column labels are expected to use the
    # pyarrow_numpy-backed string dtype.
    expected_columns = pd.Index(["a"], dtype="string[pyarrow_numpy]")
    expected = pd.DataFrame(
        data={"a": [None, "b", "c"]},
        dtype="string[pyarrow_numpy]",
        columns=expected_columns,
    )
    tm.assert_frame_equal(result, expected)
1146+
11281147

11291148
class TestParquetFastParquet(Base):
11301149
def test_basic(self, fp, df_full):

0 commit comments

Comments
 (0)