Skip to content

Commit 79ce72b

Browse files
phofl, jbrockmendel, mroeschke
authored
ENH: allow opt-in to inferring pyarrow strings (#54430)
* ENH: allow opt-in to inferring pyarrow strings
* Remove comments and add tests
* Add json tests
* Update
* Update pandas/_libs/lib.pyx (Co-authored-by: Matthew Roeschke <[email protected]>)
* Update
* Add test

---------

Co-authored-by: Brock <[email protected]>
Co-authored-by: Matthew Roeschke <[email protected]>
1 parent cf741a4 commit 79ce72b

File tree

9 files changed

+178
-0
lines changed

9 files changed

+178
-0
lines changed

pandas/_libs/lib.pyx

+22
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,7 @@ cdef class Seen:
12991299
bint datetimetz_ # seen_datetimetz
13001300
bint period_ # seen_period
13011301
bint interval_ # seen_interval
1302+
bint str_ # seen_str
13021303

13031304
def __cinit__(self, bint coerce_numeric=False):
13041305
"""
@@ -1325,6 +1326,7 @@ cdef class Seen:
13251326
self.datetimetz_ = False
13261327
self.period_ = False
13271328
self.interval_ = False
1329+
self.str_ = False
13281330
self.coerce_numeric = coerce_numeric
13291331

13301332
cdef bint check_uint64_conflict(self) except -1:
@@ -2615,6 +2617,13 @@ def maybe_convert_objects(ndarray[object] objects,
26152617
else:
26162618
seen.object_ = True
26172619
break
2620+
elif isinstance(val, str):
2621+
if convert_non_numeric:
2622+
seen.str_ = True
2623+
break
2624+
else:
2625+
seen.object_ = True
2626+
break
26182627
else:
26192628
seen.object_ = True
26202629
break
@@ -2669,6 +2678,19 @@ def maybe_convert_objects(ndarray[object] objects,
26692678
return pi._data
26702679
seen.object_ = True
26712680

2681+
elif seen.str_:
2682+
if is_string_array(objects, skipna=True):
2683+
from pandas._config import get_option
2684+
opt = get_option("future.infer_string")
2685+
if opt is True:
2686+
import pyarrow as pa
2687+
2688+
from pandas.core.dtypes.dtypes import ArrowDtype
2689+
2690+
dtype = ArrowDtype(pa.string())
2691+
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
2692+
2693+
seen.object_ = True
26722694
elif seen.interval_:
26732695
if is_interval_array(objects):
26742696
from pandas import IntervalIndex

pandas/core/config_init.py

+11
Original file line numberDiff line numberDiff line change
@@ -889,3 +889,14 @@ def register_converter_cb(key) -> None:
889889
styler_environment,
890890
validator=is_instance_factory([type(None), str]),
891891
)
892+
893+
894+
with cf.config_prefix("future"):
    # Registers the opt-in "infer_string" option under the "future" prefix.
    # It defaults to False, so string inference stays unchanged unless the
    # user enables it explicitly.
    #
    # NOTE(review): the consumers of this option compare with `opt is True`.
    # If `is_one_of_factory` validates by plain membership, values such as 1
    # would pass validation yet never enable the feature -- confirm whether a
    # strict boolean validator is wanted here.
    cf.register_option(
        "infer_string",
        False,
        (
            "Whether to infer sequence of str objects as pyarrow string "
            "dtype, which will be the default in pandas 3.0 "
            "(at which point this option will be deprecated)."
        ),
        validator=is_one_of_factory([True, False]),
    )

pandas/core/dtypes/cast.py

+8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import numpy as np
2020

21+
from pandas._config import get_option
22+
2123
from pandas._libs import lib
2224
from pandas._libs.missing import (
2325
NA,
@@ -796,6 +798,12 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
796798
# coming out as np.str_!
797799

798800
dtype = _dtype_obj
801+
opt = get_option("future.infer_string")
802+
if opt is True:
803+
import pyarrow as pa
804+
805+
pa_dtype = pa.string()
806+
dtype = ArrowDtype(pa_dtype)
799807

800808
elif isinstance(val, (np.datetime64, dt.datetime)):
801809
try:

pandas/tests/frame/test_constructors.py

+35
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,41 @@ def test_construct_with_strings_and_none(self):
26852685
expected = DataFrame({"a": ["1", "2", None]}, dtype="str")
26862686
tm.assert_frame_equal(df, expected)
26872687

2688+
def test_frame_string_inference(self):
    # GH#54430
    # DataFrame construction under ``future.infer_string``: all-str data,
    # column labels and index labels are expected to come out as
    # ArrowDtype(pa.string()); mixed data or an explicit ``dtype="object"``
    # is expected to keep object-dtype values.
    pa = pytest.importorskip("pyarrow")
    string_dtype = pd.ArrowDtype(pa.string())

    # Case 1: only string values, default index.
    expected = DataFrame(
        {"a": ["a", "b"]},
        dtype=string_dtype,
        columns=Index(["a"], dtype=string_dtype),
    )
    with pd.option_context("future.infer_string", True):
        result = DataFrame({"a": ["a", "b"]})
    tm.assert_frame_equal(result, expected)

    # Case 2: string values together with string index labels.
    expected = DataFrame(
        {"a": ["a", "b"]},
        dtype=string_dtype,
        columns=Index(["a"], dtype=string_dtype),
        index=Index(["x", "y"], dtype=string_dtype),
    )
    with pd.option_context("future.infer_string", True):
        result = DataFrame({"a": ["a", "b"]}, index=["x", "y"])
    tm.assert_frame_equal(result, expected)

    # Case 3: a str mixed with an int -> values stay object, while the
    # column labels are still expected to be pyarrow strings.
    expected = DataFrame(
        {"a": ["a", 1]},
        dtype="object",
        columns=Index(["a"], dtype=string_dtype),
    )
    with pd.option_context("future.infer_string", True):
        result = DataFrame({"a": ["a", 1]})
    tm.assert_frame_equal(result, expected)

    # Case 4: an explicitly requested object dtype is honoured for the values.
    expected = DataFrame(
        {"a": ["a", "b"]},
        dtype="object",
        columns=Index(["a"], dtype=string_dtype),
    )
    with pd.option_context("future.infer_string", True):
        result = DataFrame({"a": ["a", "b"]}, dtype="object")
    tm.assert_frame_equal(result, expected)
2722+
26882723

26892724
class TestDataFrameConstructorIndexInference:
26902725
def test_frame_from_dict_of_series_overlapping_monthly_period_indexes(self):

pandas/tests/indexes/base_class/test_constructors.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
import pandas as pd
45
from pandas import (
56
Index,
67
MultiIndex,
@@ -42,3 +43,17 @@ def test_construct_empty_tuples(self, tuple_list):
4243
expected = MultiIndex.from_tuples(tuple_list)
4344

4445
tm.assert_index_equal(result, expected)
46+
47+
def test_index_string_inference(self):
    # GH#54430
    # Index construction under ``future.infer_string``: all-str input is
    # expected to give ArrowDtype(pa.string()); mixed input stays object.
    pa = pytest.importorskip("pyarrow")
    string_dtype = pd.ArrowDtype(pa.string())

    # Only strings -> pyarrow-backed string Index.
    expected = Index(["a", "b"], dtype=string_dtype)
    with pd.option_context("future.infer_string", True):
        result = Index(["a", "b"])
    tm.assert_index_equal(result, expected)

    # A str next to an int -> object Index.
    expected = Index(["a", 1], dtype="object")
    with pd.option_context("future.infer_string", True):
        result = Index(["a", 1])
    tm.assert_index_equal(result, expected)

pandas/tests/io/json/test_pandas.py

+17
Original file line numberDiff line numberDiff line change
@@ -2096,3 +2096,20 @@ def test_pyarrow_engine_lines_false():
20962096
out = ser.to_json()
20972097
with pytest.raises(ValueError, match="currently pyarrow engine only supports"):
20982098
read_json(out, engine="pyarrow", lines=False)
2099+
2100+
2101+
def test_json_roundtrip_string_inference(orient):
    # Round-trip a frame of strings through JSON and read it back with
    # ``future.infer_string`` enabled; values, index and columns are all
    # expected to be ArrowDtype(pa.string()).
    #
    # NOTE(review): ``orient`` is accepted but never used below -- neither
    # ``to_json`` nor ``read_json`` receives it. Presumably it is a
    # parametrized fixture, in which case the same scenario simply runs once
    # per parameter; confirm whether it should be passed through.
    pa = pytest.importorskip("pyarrow")

    # The source frame is built and serialised outside the option context.
    source = DataFrame(
        [["a", "b"], ["c", "d"]],
        index=["row 1", "row 2"],
        columns=["col 1", "col 2"],
    )
    payload = source.to_json()

    with pd.option_context("future.infer_string", True):
        result = read_json(StringIO(payload))

    string_dtype = pd.ArrowDtype(pa.string())
    expected = DataFrame(
        [["a", "b"], ["c", "d"]],
        dtype=string_dtype,
        index=pd.Index(["row 1", "row 2"], dtype=string_dtype),
        columns=pd.Index(["col 1", "col 2"], dtype=string_dtype),
    )
    tm.assert_frame_equal(result, expected)

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

+21
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,24 @@ def test_ea_int_avoid_overflow(all_parsers):
538538
}
539539
)
540540
tm.assert_frame_equal(result, expected)
541+
542+
543+
def test_string_inference(all_parsers):
    # GH#54430
    # read_csv under ``future.infer_string``: the string column and the
    # column labels are expected to be ArrowDtype(pa.string()); the numeric
    # column is built from plain ints in the expected frame.
    pa = pytest.importorskip("pyarrow")
    string_dtype = pd.ArrowDtype(pa.string())

    data = """a,b
x,1
y,2"""
    parser = all_parsers
    if parser.engine == "pyarrow":
        # The pyarrow engine is skipped here (marked as a follow-up).
        pytest.skip("TODO: Follow up")

    with pd.option_context("future.infer_string", True):
        result = parser.read_csv(StringIO(data))

    expected_columns = pd.Index(["a", "b"], dtype=string_dtype)
    expected = DataFrame(
        {"a": pd.Series(["x", "y"], dtype=string_dtype), "b": [1, 2]},
        columns=expected_columns,
    )
    tm.assert_frame_equal(result, expected)

pandas/tests/io/test_sql.py

+17
Original file line numberDiff line numberDiff line change
@@ -2922,6 +2922,23 @@ def test_read_sql_dtype_backend_table(self, string_storage, func):
29222922
# GH#50048 Not supported for sqlite
29232923
pass
29242924

2925+
def test_read_sql_string_inference(self):
    # GH#54430
    # Write a frame of strings to the test connection, then read the table
    # back with ``future.infer_string`` enabled; values and column labels are
    # expected to be ArrowDtype(pa.string()).
    pa = pytest.importorskip("pyarrow")
    table_name = "test"

    # The table is written outside the option context.
    source = DataFrame({"a": ["x", "y"]})
    source.to_sql(table_name, self.conn, index=False, if_exists="replace")

    with pd.option_context("future.infer_string", True):
        result = read_sql_table(table_name, self.conn)

    string_dtype = pd.ArrowDtype(pa.string())
    expected = DataFrame(
        {"a": ["x", "y"]},
        dtype=string_dtype,
        columns=Index(["a"], dtype=string_dtype),
    )

    tm.assert_frame_equal(result, expected)
2941+
29252942

29262943
@pytest.mark.db
29272944
class TestMySQLAlchemy(_TestSQLAlchemy):

pandas/tests/series/test_constructors.py

+32
Original file line numberDiff line numberDiff line change
@@ -2075,6 +2075,38 @@ def test_series_from_index_dtype_equal_does_not_copy(self):
20752075
ser.iloc[0] = 100
20762076
tm.assert_index_equal(idx, expected)
20772077

2078+
def test_series_string_inference(self):
    # GH#54430
    # Series construction under ``future.infer_string``: all-str input is
    # expected to give ArrowDtype(pa.string()); mixed input stays object.
    pa = pytest.importorskip("pyarrow")
    string_dtype = pd.ArrowDtype(pa.string())

    # Only strings -> pyarrow-backed string Series.
    expected = Series(["a", "b"], dtype=string_dtype)
    with pd.option_context("future.infer_string", True):
        result = Series(["a", "b"])
    tm.assert_series_equal(result, expected)

    # A str next to an int -> object Series.
    expected = Series(["a", 1], dtype="object")
    with pd.option_context("future.infer_string", True):
        result = Series(["a", 1])
    tm.assert_series_equal(result, expected)
2091+
2092+
@pytest.mark.parametrize("na_value", [None, np.nan, pd.NA])
def test_series_string_with_na_inference(self, na_value):
    # GH#54430
    # A str alongside a missing-value marker (None, np.nan or pd.NA) is
    # still expected to be inferred as ArrowDtype(pa.string()) when
    # ``future.infer_string`` is enabled.
    pa = pytest.importorskip("pyarrow")
    string_dtype = pd.ArrowDtype(pa.string())
    values = ["a", na_value]

    expected = Series(values, dtype=string_dtype)
    with pd.option_context("future.infer_string", True):
        result = Series(values)
    tm.assert_series_equal(result, expected)
2101+
2102+
def test_series_string_inference_scalar(self):
    # GH#54430
    # A scalar str broadcast over an index is expected to be inferred as
    # ArrowDtype(pa.string()) when ``future.infer_string`` is enabled.
    pa = pytest.importorskip("pyarrow")
    string_dtype = pd.ArrowDtype(pa.string())

    expected = Series("a", index=[1], dtype=string_dtype)
    with pd.option_context("future.infer_string", True):
        result = Series("a", index=[1])
    tm.assert_series_equal(result, expected)
2109+
20782110

20792111
class TestSeriesConstructorIndexCoercion:
20802112
def test_series_constructor_datetimelike_index_coercion(self):

0 commit comments

Comments
 (0)