Skip to content

Commit 1539526

Browse files
authored
Infer large_string type as pyarrow_numpy strings (#54826)
1 parent 7688d52 commit 1539526

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

pandas/core/arrays/string_arrow.py

+9
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,15 @@ def _str_rstrip(self, to_strip=None):
450450
class ArrowStringArrayNumpySemantics(ArrowStringArray):
451451
_storage = "pyarrow_numpy"
452452

453+
def __init__(self, values) -> None:
454+
_chk_pyarrow_available()
455+
456+
if isinstance(values, (pa.Array, pa.ChunkedArray)) and pa.types.is_large_string(
457+
values.type
458+
):
459+
values = pc.cast(values, pa.string())
460+
super().__init__(values)
461+
453462
@classmethod
454463
def _result_converter(cls, values, na=None):
455464
if not isna(na):

pandas/io/_util.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,7 @@ def _arrow_dtype_mapping() -> dict:
2828
def arrow_string_types_mapper() -> Callable:
2929
pa = import_optional_dependency("pyarrow")
3030

31-
return {pa.string(): pd.StringDtype(storage="pyarrow_numpy")}.get
31+
return {
32+
pa.string(): pd.StringDtype(storage="pyarrow_numpy"),
33+
pa.large_string(): pd.StringDtype(storage="pyarrow_numpy"),
34+
}.get

pandas/tests/arrays/string_/test_string_arrow.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
StringArray,
1313
StringDtype,
1414
)
15-
from pandas.core.arrays.string_arrow import ArrowStringArray
15+
from pandas.core.arrays.string_arrow import (
16+
ArrowStringArray,
17+
ArrowStringArrayNumpySemantics,
18+
)
1619

1720
skip_if_no_pyarrow = pytest.mark.skipif(
1821
pa_version_under7p0,
@@ -166,6 +169,9 @@ def test_pyarrow_not_installed_raises():
166169
with pytest.raises(ImportError, match=msg):
167170
ArrowStringArray([])
168171

172+
with pytest.raises(ImportError, match=msg):
173+
ArrowStringArrayNumpySemantics([])
174+
169175
with pytest.raises(ImportError, match=msg):
170176
ArrowStringArray._from_sequence(["a", None, "b"])
171177

pandas/tests/io/test_parquet.py

+19
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,25 @@ def test_roundtrip_decimal(self, tmp_path, pa):
11391139
expected = pd.DataFrame({"a": ["123"]}, dtype="string[python]")
11401140
tm.assert_frame_equal(result, expected)
11411141

1142+
def test_infer_string_large_string_type(self, tmp_path, pa):
1143+
# GH#54798
1144+
import pyarrow as pa
1145+
import pyarrow.parquet as pq
1146+
1147+
path = tmp_path / "large_string.p"
1148+
1149+
table = pa.table({"a": pa.array([None, "b", "c"], pa.large_string())})
1150+
pq.write_table(table, path)
1151+
1152+
with pd.option_context("future.infer_string", True):
1153+
result = read_parquet(path)
1154+
expected = pd.DataFrame(
1155+
data={"a": [None, "b", "c"]},
1156+
dtype="string[pyarrow_numpy]",
1157+
columns=pd.Index(["a"], dtype="string[pyarrow_numpy]"),
1158+
)
1159+
tm.assert_frame_equal(result, expected)
1160+
11421161

11431162
class TestParquetFastParquet(Base):
11441163
def test_basic(self, fp, df_full):

0 commit comments

Comments
 (0)