Skip to content

Commit d746349

Browse files
phofljbrockmendel
authored andcommitted
ENH: Implement arrow string option for various I/O methods (pandas-dev#54431)
* ENH: Implement arrow string option for various I/O methods * ENH: allow opt-in to inferring pyarrow strings * Remove comments and add tests * Add string option to arrow parsers * Update * Update * Adjust csv * Update * Update * Add test * Fix mypy --------- Co-authored-by: Brock <[email protected]>
1 parent 83c26ac commit d746349

File tree

14 files changed

+134
-15
lines changed

14 files changed

+134
-15
lines changed

pandas/_config/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,8 @@ def using_copy_on_write() -> bool:
3838
def using_nullable_dtypes() -> bool:
3939
_mode_options = _global_config["mode"]
4040
return _mode_options["nullable_dtypes"]
41+
42+
43+
def using_pyarrow_string_dtype() -> bool:
44+
_mode_options = _global_config["future"]
45+
return _mode_options["infer_string"]

pandas/_libs/lib.pyx

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ from cython cimport (
3838
floating,
3939
)
4040

41+
from pandas._config import using_pyarrow_string_dtype
42+
4143
from pandas._libs.missing import check_na_tuples_nonequal
4244

4345
import_datetime()
@@ -2679,9 +2681,7 @@ def maybe_convert_objects(ndarray[object] objects,
26792681

26802682
elif seen.str_:
26812683
if is_string_array(objects, skipna=True):
2682-
from pandas._config import get_option
2683-
opt = get_option("future.infer_string")
2684-
if opt is True:
2684+
if using_pyarrow_string_dtype():
26852685
import pyarrow as pa
26862686

26872687
from pandas.core.dtypes.dtypes import ArrowDtype

pandas/core/dtypes/cast.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import numpy as np
2020

21-
from pandas._config import get_option
21+
from pandas._config import using_pyarrow_string_dtype
2222

2323
from pandas._libs import lib
2424
from pandas._libs.missing import (
@@ -798,8 +798,7 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
798798
# coming out as np.str_!
799799

800800
dtype = _dtype_obj
801-
opt = get_option("future.infer_string")
802-
if opt is True:
801+
if using_pyarrow_string_dtype():
803802
import pyarrow as pa
804803

805804
pa_dtype = pa.string()

pandas/io/_util.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Callable
4+
35
from pandas.compat._optional import import_optional_dependency
46

57
import pandas as pd
@@ -21,3 +23,9 @@ def _arrow_dtype_mapping() -> dict:
2123
pa.float32(): pd.Float32Dtype(),
2224
pa.float64(): pd.Float64Dtype(),
2325
}
26+
27+
28+
def arrow_string_types_mapper() -> Callable:
29+
pa = import_optional_dependency("pyarrow")
30+
31+
return {pa.string(): pd.ArrowDtype(pa.string())}.get

pandas/io/feather_format.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
Any,
77
)
88

9+
from pandas._config import using_pyarrow_string_dtype
10+
911
from pandas._libs import lib
1012
from pandas.compat._optional import import_optional_dependency
1113
from pandas.util._decorators import doc
@@ -15,6 +17,7 @@
1517
from pandas.core.api import DataFrame
1618
from pandas.core.shared_docs import _shared_docs
1719

20+
from pandas.io._util import arrow_string_types_mapper
1821
from pandas.io.common import get_handle
1922

2023
if TYPE_CHECKING:
@@ -119,7 +122,7 @@ def read_feather(
119122
with get_handle(
120123
path, "rb", storage_options=storage_options, is_text=False
121124
) as handles:
122-
if dtype_backend is lib.no_default:
125+
if dtype_backend is lib.no_default and not using_pyarrow_string_dtype():
123126
return feather.read_feather(
124127
handles.handle, columns=columns, use_threads=bool(use_threads)
125128
)
@@ -135,3 +138,8 @@ def read_feather(
135138

136139
elif dtype_backend == "pyarrow":
137140
return pa_table.to_pandas(types_mapper=pd.ArrowDtype)
141+
142+
elif using_pyarrow_string_dtype():
143+
return pa_table.to_pandas(types_mapper=arrow_string_types_mapper())
144+
else:
145+
raise NotImplementedError

pandas/io/orc.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
Literal,
1010
)
1111

12+
from pandas._config import using_pyarrow_string_dtype
13+
1214
from pandas._libs import lib
1315
from pandas.compat import pa_version_under8p0
1416
from pandas.compat._optional import import_optional_dependency
@@ -24,6 +26,7 @@
2426
import pandas as pd
2527
from pandas.core.indexes.api import default_index
2628

29+
from pandas.io._util import arrow_string_types_mapper
2730
from pandas.io.common import (
2831
get_handle,
2932
is_fsspec_url,
@@ -132,7 +135,11 @@ def read_orc(
132135
df = pa_table.to_pandas(types_mapper=mapping.get)
133136
return df
134137
else:
135-
return pa_table.to_pandas()
138+
if using_pyarrow_string_dtype():
139+
types_mapper = arrow_string_types_mapper()
140+
else:
141+
types_mapper = None
142+
return pa_table.to_pandas(types_mapper=types_mapper)
136143

137144

138145
def to_orc(

pandas/io/parquet.py

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import warnings
1313
from warnings import catch_warnings
1414

15+
from pandas._config import using_pyarrow_string_dtype
16+
1517
from pandas._libs import lib
1618
from pandas.compat._optional import import_optional_dependency
1719
from pandas.errors import AbstractMethodError
@@ -26,6 +28,7 @@
2628
)
2729
from pandas.core.shared_docs import _shared_docs
2830

31+
from pandas.io._util import arrow_string_types_mapper
2932
from pandas.io.common import (
3033
IOHandles,
3134
get_handle,
@@ -252,6 +255,8 @@ def read(
252255
to_pandas_kwargs["types_mapper"] = mapping.get
253256
elif dtype_backend == "pyarrow":
254257
to_pandas_kwargs["types_mapper"] = pd.ArrowDtype # type: ignore[assignment] # noqa: E501
258+
elif using_pyarrow_string_dtype():
259+
to_pandas_kwargs["types_mapper"] = arrow_string_types_mapper()
255260

256261
manager = get_option("mode.data_manager")
257262
if manager == "array":

pandas/io/parsers/arrow_parser_wrapper.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44

5+
from pandas._config import using_pyarrow_string_dtype
6+
57
from pandas._libs import lib
68
from pandas.compat._optional import import_optional_dependency
79

@@ -10,7 +12,10 @@
1012
import pandas as pd
1113
from pandas import DataFrame
1214

13-
from pandas.io._util import _arrow_dtype_mapping
15+
from pandas.io._util import (
16+
_arrow_dtype_mapping,
17+
arrow_string_types_mapper,
18+
)
1419
from pandas.io.parsers.base_parser import ParserBase
1520

1621
if TYPE_CHECKING:
@@ -215,6 +220,8 @@ def read(self) -> DataFrame:
215220
dtype_mapping = _arrow_dtype_mapping()
216221
dtype_mapping[pa.null()] = pd.Int64Dtype()
217222
frame = table.to_pandas(types_mapper=dtype_mapping.get)
223+
elif using_pyarrow_string_dtype():
224+
frame = table.to_pandas(types_mapper=arrow_string_types_mapper())
218225
else:
219226
frame = table.to_pandas()
220227
return self._finalize_pandas_output(frame)

pandas/io/pytables.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@
3030
from pandas._config import (
3131
config,
3232
get_option,
33+
using_pyarrow_string_dtype,
3334
)
3435

3536
from pandas._libs import (
3637
lib,
3738
writers as libwriters,
3839
)
40+
from pandas._libs.lib import is_string_array
3941
from pandas._libs.tslibs import timezones
4042
from pandas.compat._optional import import_optional_dependency
4143
from pandas.compat.pickle_compat import patch_pickle
@@ -66,6 +68,7 @@
6668
)
6769
from pandas.core.dtypes.missing import array_equivalent
6870

71+
import pandas as pd
6972
from pandas import (
7073
DataFrame,
7174
DatetimeIndex,
@@ -3219,7 +3222,12 @@ def read(
32193222
self.validate_read(columns, where)
32203223
index = self.read_index("index", start=start, stop=stop)
32213224
values = self.read_array("values", start=start, stop=stop)
3222-
return Series(values, index=index, name=self.name, copy=False)
3225+
result = Series(values, index=index, name=self.name, copy=False)
3226+
if using_pyarrow_string_dtype() and is_string_array(values, skipna=True):
3227+
import pyarrow as pa
3228+
3229+
result = result.astype(pd.ArrowDtype(pa.string()))
3230+
return result
32233231

32243232
# error: Signature of "write" incompatible with supertype "Fixed"
32253233
def write(self, obj, **kwargs) -> None: # type: ignore[override]
@@ -3287,6 +3295,10 @@ def read(
32873295

32883296
columns = items[items.get_indexer(blk_items)]
32893297
df = DataFrame(values.T, columns=columns, index=axes[1], copy=False)
3298+
if using_pyarrow_string_dtype() and is_string_array(values, skipna=True):
3299+
import pyarrow as pa
3300+
3301+
df = df.astype(pd.ArrowDtype(pa.string()))
32903302
dfs.append(df)
32913303

32923304
if len(dfs) > 0:
@@ -4668,7 +4680,15 @@ def read(
46684680
else:
46694681
# Categorical
46704682
df = DataFrame._from_arrays([values], columns=cols_, index=index_)
4671-
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
4683+
if not (using_pyarrow_string_dtype() and values.dtype.kind == "O"):
4684+
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
4685+
if using_pyarrow_string_dtype() and is_string_array(
4686+
values, # type: ignore[arg-type]
4687+
skipna=True,
4688+
):
4689+
import pyarrow as pa
4690+
4691+
df = df.astype(pd.ArrowDtype(pa.string()))
46724692
frames.append(df)
46734693

46744694
if len(frames) == 1:

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -547,15 +547,14 @@ def test_string_inference(all_parsers):
547547

548548
data = """a,b
549549
x,1
550-
y,2"""
550+
y,2
551+
,3"""
551552
parser = all_parsers
552-
if parser.engine == "pyarrow":
553-
pytest.skip("TODO: Follow up")
554553
with pd.option_context("future.infer_string", True):
555554
result = parser.read_csv(StringIO(data))
556555

557556
expected = DataFrame(
558-
{"a": pd.Series(["x", "y"], dtype=dtype), "b": [1, 2]},
557+
{"a": pd.Series(["x", "y", None], dtype=dtype), "b": [1, 2, 3]},
559558
columns=pd.Index(["a", "b"], dtype=dtype),
560559
)
561560
tm.assert_frame_equal(result, expected)

pandas/tests/io/pytables/test_read.py

+16
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,19 @@ def test_read_py2_hdf_file_in_py3(datapath):
388388
) as store:
389389
result = store["p"]
390390
tm.assert_frame_equal(result, expected)
391+
392+
393+
def test_read_infer_string(tmp_path, setup_path):
394+
# GH#54431
395+
pa = pytest.importorskip("pyarrow")
396+
df = DataFrame({"a": ["a", "b", None]})
397+
path = tmp_path / setup_path
398+
df.to_hdf(path, key="data", format="table")
399+
with pd.option_context("future.infer_string", True):
400+
result = read_hdf(path, key="data", mode="r")
401+
expected = DataFrame(
402+
{"a": ["a", "b", None]},
403+
dtype=pd.ArrowDtype(pa.string()),
404+
columns=Index(["a"], dtype=pd.ArrowDtype(pa.string())),
405+
)
406+
tm.assert_frame_equal(result, expected)

pandas/tests/io/test_feather.py

+14
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,17 @@ def test_invalid_dtype_backend(self):
219219
df.to_feather(path)
220220
with pytest.raises(ValueError, match=msg):
221221
read_feather(path, dtype_backend="numpy")
222+
223+
def test_string_inference(self, tmp_path):
224+
# GH#54431
225+
import pyarrow as pa
226+
227+
path = tmp_path / "test_string_inference.p"
228+
df = pd.DataFrame(data={"a": ["x", "y"]})
229+
df.to_feather(path)
230+
with pd.option_context("future.infer_string", True):
231+
result = read_feather(path)
232+
expected = pd.DataFrame(
233+
data={"a": ["x", "y"]}, dtype=pd.ArrowDtype(pa.string())
234+
)
235+
tm.assert_frame_equal(result, expected)

pandas/tests/io/test_orc.py

+15
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,18 @@ def test_invalid_dtype_backend():
415415
df.to_orc(path)
416416
with pytest.raises(ValueError, match=msg):
417417
read_orc(path, dtype_backend="numpy")
418+
419+
420+
def test_string_inference(tmp_path):
421+
# GH#54431
422+
path = tmp_path / "test_string_inference.p"
423+
df = pd.DataFrame(data={"a": ["x", "y"]})
424+
df.to_orc(path)
425+
with pd.option_context("future.infer_string", True):
426+
result = read_orc(path)
427+
expected = pd.DataFrame(
428+
data={"a": ["x", "y"]},
429+
dtype=pd.ArrowDtype(pa.string()),
430+
columns=pd.Index(["a"], dtype=pd.ArrowDtype(pa.string())),
431+
)
432+
tm.assert_frame_equal(result, expected)

pandas/tests/io/test_parquet.py

+16
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,22 @@ def test_df_attrs_persistence(self, tmp_path, pa):
11031103
new_df = read_parquet(path, engine=pa)
11041104
assert new_df.attrs == df.attrs
11051105

1106+
def test_string_inference(self, tmp_path, pa):
1107+
# GH#54431
1108+
import pyarrow as pa
1109+
1110+
path = tmp_path / "test_string_inference.p"
1111+
df = pd.DataFrame(data={"a": ["x", "y"]}, index=["a", "b"])
1112+
df.to_parquet(path, engine="pyarrow")
1113+
with pd.option_context("future.infer_string", True):
1114+
result = read_parquet(path, engine="pyarrow")
1115+
expected = pd.DataFrame(
1116+
data={"a": ["x", "y"]},
1117+
dtype=pd.ArrowDtype(pa.string()),
1118+
index=pd.Index(["a", "b"], dtype=pd.ArrowDtype(pa.string())),
1119+
)
1120+
tm.assert_frame_equal(result, expected)
1121+
11061122

11071123
class TestParquetFastParquet(Base):
11081124
def test_basic(self, fp, df_full):

0 commit comments

Comments
 (0)