Skip to content

Commit 487c585

Browse files
String dtype: fix pyarrow-based IO + update tests (#59478)
1 parent 328e79d commit 487c585

File tree

6 files changed

+79
-48
lines changed

6 files changed

+79
-48
lines changed

pandas/io/_util.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def _arrow_dtype_mapping() -> dict:
2727
pa.string(): pd.StringDtype(),
2828
pa.float32(): pd.Float32Dtype(),
2929
pa.float64(): pd.Float64Dtype(),
30+
pa.string(): pd.StringDtype(),
31+
pa.large_string(): pd.StringDtype(),
3032
}
3133

3234

pandas/tests/io/test_feather.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,15 @@
55
import numpy as np
66
import pytest
77

8-
from pandas._config import using_string_dtype
9-
108
import pandas as pd
119
import pandas._testing as tm
1210

1311
from pandas.io.feather_format import read_feather, to_feather # isort:skip
1412

15-
pytestmark = [
16-
pytest.mark.filterwarnings(
17-
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
18-
),
19-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
20-
]
13+
pytestmark = pytest.mark.filterwarnings(
14+
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
15+
)
16+
2117

2218
pa = pytest.importorskip("pyarrow")
2319

@@ -150,8 +146,8 @@ def test_path_pathlib(self):
150146
def test_passthrough_keywords(self):
151147
df = pd.DataFrame(
152148
1.1 * np.arange(120).reshape((30, 4)),
153-
columns=pd.Index(list("ABCD"), dtype=object),
154-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
149+
columns=pd.Index(list("ABCD")),
150+
index=pd.Index([f"i-{i}" for i in range(30)]),
155151
).reset_index()
156152
self.check_round_trip(df, write_kwargs={"version": 1})
157153

@@ -165,7 +161,9 @@ def test_http_path(self, feather_file, httpserver):
165161
res = read_feather(httpserver.url)
166162
tm.assert_frame_equal(expected, res)
167163

168-
def test_read_feather_dtype_backend(self, string_storage, dtype_backend):
164+
def test_read_feather_dtype_backend(
165+
self, string_storage, dtype_backend, using_infer_string
166+
):
169167
# GH#50765
170168
df = pd.DataFrame(
171169
{
@@ -187,7 +185,10 @@ def test_read_feather_dtype_backend(self, string_storage, dtype_backend):
187185

188186
if dtype_backend == "pyarrow":
189187
pa = pytest.importorskip("pyarrow")
190-
string_dtype = pd.ArrowDtype(pa.string())
188+
if using_infer_string:
189+
string_dtype = pd.ArrowDtype(pa.large_string())
190+
else:
191+
string_dtype = pd.ArrowDtype(pa.string())
191192
else:
192193
string_dtype = pd.StringDtype(string_storage)
193194

@@ -214,6 +215,10 @@ def test_read_feather_dtype_backend(self, string_storage, dtype_backend):
214215
}
215216
)
216217

218+
if using_infer_string:
219+
expected.columns = expected.columns.astype(
220+
pd.StringDtype(string_storage, na_value=np.nan)
221+
)
217222
tm.assert_frame_equal(result, expected)
218223

219224
def test_int_columns_and_index(self):

pandas/tests/io/test_fsspec.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def test_excel_options(fsspectest):
176176
assert fsspectest.test[0] == "read"
177177

178178

179-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
179+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet")
180180
def test_to_parquet_new_file(cleared_fs, df1):
181181
"""Regression test for writing to a not-yet-existent GCS Parquet file."""
182182
pytest.importorskip("fastparquet")
@@ -205,7 +205,7 @@ def test_arrowparquet_options(fsspectest):
205205
assert fsspectest.test[0] == "parquet_read"
206206

207207

208-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
208+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet")
209209
def test_fastparquet_options(fsspectest):
210210
"""Regression test for writing to a not-yet-existent GCS Parquet file."""
211211
pytest.importorskip("fastparquet")
@@ -263,7 +263,7 @@ def test_s3_protocols(s3_public_bucket_with_data, tips_file, protocol, s3so):
263263
)
264264

265265

266-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
266+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet")
267267
@pytest.mark.single_cpu
268268
def test_s3_parquet(s3_public_bucket, s3so, df1):
269269
pytest.importorskip("fastparquet")

pandas/tests/io/test_gcs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def test_to_csv_compression_encoding_gcs(
208208
tm.assert_frame_equal(df, read_df)
209209

210210

211-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
211+
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet")
212212
def test_to_parquet_gcs_new_file(monkeypatch, tmpdir):
213213
"""Regression test for writing to a not-yet-existent GCS Parquet file."""
214214
pytest.importorskip("fastparquet")

pandas/tests/io/test_orc.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import numpy as np
1010
import pytest
1111

12-
from pandas._config import using_string_dtype
13-
1412
import pandas as pd
1513
from pandas import read_orc
1614
import pandas._testing as tm
@@ -20,20 +18,17 @@
2018

2119
import pyarrow as pa
2220

23-
pytestmark = [
24-
pytest.mark.filterwarnings(
25-
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
26-
),
27-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
28-
]
21+
pytestmark = pytest.mark.filterwarnings(
22+
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
23+
)
2924

3025

3126
@pytest.fixture
3227
def dirpath(datapath):
3328
return datapath("io", "data", "orc")
3429

3530

36-
def test_orc_reader_empty(dirpath):
31+
def test_orc_reader_empty(dirpath, using_infer_string):
3732
columns = [
3833
"boolean1",
3934
"byte1",
@@ -54,11 +49,12 @@ def test_orc_reader_empty(dirpath):
5449
"float32",
5550
"float64",
5651
"object",
57-
"object",
52+
"str" if using_infer_string else "object",
5853
]
5954
expected = pd.DataFrame(index=pd.RangeIndex(0))
6055
for colname, dtype in zip(columns, dtypes):
6156
expected[colname] = pd.Series(dtype=dtype)
57+
expected.columns = expected.columns.astype("str")
6258

6359
inputfile = os.path.join(dirpath, "TestOrcFile.emptyFile.orc")
6460
got = read_orc(inputfile, columns=columns)
@@ -305,7 +301,7 @@ def test_orc_writer_dtypes_not_supported(orc_writer_dtypes_not_supported):
305301
df.to_orc()
306302

307303

308-
def test_orc_dtype_backend_pyarrow():
304+
def test_orc_dtype_backend_pyarrow(using_infer_string):
309305
pytest.importorskip("pyarrow")
310306
df = pd.DataFrame(
311307
{
@@ -338,6 +334,13 @@ def test_orc_dtype_backend_pyarrow():
338334
for col in df.columns
339335
}
340336
)
337+
if using_infer_string:
338+
# ORC does not preserve distinction between string and large string
339+
# -> the default large string comes back as string
340+
string_dtype = pd.ArrowDtype(pa.string())
341+
expected["string"] = expected["string"].astype(string_dtype)
342+
expected["string_with_nan"] = expected["string_with_nan"].astype(string_dtype)
343+
expected["string_with_none"] = expected["string_with_none"].astype(string_dtype)
341344

342345
tm.assert_frame_equal(result, expected)
343346

pandas/tests/io/test_parquet.py

+42-21
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
pytest.mark.filterwarnings(
5353
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
5454
),
55-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
5655
]
5756

5857

@@ -61,10 +60,17 @@
6160
params=[
6261
pytest.param(
6362
"fastparquet",
64-
marks=pytest.mark.skipif(
65-
not _HAVE_FASTPARQUET,
66-
reason="fastparquet is not installed",
67-
),
63+
marks=[
64+
pytest.mark.skipif(
65+
not _HAVE_FASTPARQUET,
66+
reason="fastparquet is not installed",
67+
),
68+
pytest.mark.xfail(
69+
using_string_dtype(),
70+
reason="TODO(infer_string) fastparquet",
71+
strict=False,
72+
),
73+
],
6874
),
6975
pytest.param(
7076
"pyarrow",
@@ -86,15 +92,22 @@ def pa():
8692

8793

8894
@pytest.fixture
89-
def fp():
95+
def fp(request):
9096
if not _HAVE_FASTPARQUET:
9197
pytest.skip("fastparquet is not installed")
98+
if using_string_dtype():
99+
request.applymarker(
100+
pytest.mark.xfail(reason="TODO(infer_string) fastparquet", strict=False)
101+
)
92102
return "fastparquet"
93103

94104

95105
@pytest.fixture
96106
def df_compat():
97-
return pd.DataFrame({"A": [1, 2, 3], "B": "foo"})
107+
# TODO(infer_string) should this give str columns?
108+
return pd.DataFrame(
109+
{"A": [1, 2, 3], "B": "foo"}, columns=pd.Index(["A", "B"], dtype=object)
110+
)
98111

99112

100113
@pytest.fixture
@@ -366,16 +379,6 @@ def check_external_error_on_write(self, df, engine, exc):
366379
with tm.external_error_raised(exc):
367380
to_parquet(df, path, engine, compression=None)
368381

369-
@pytest.mark.network
370-
@pytest.mark.single_cpu
371-
def test_parquet_read_from_url(self, httpserver, datapath, df_compat, engine):
372-
if engine != "auto":
373-
pytest.importorskip(engine)
374-
with open(datapath("io", "data", "parquet", "simple.parquet"), mode="rb") as f:
375-
httpserver.serve_content(content=f.read())
376-
df = read_parquet(httpserver.url)
377-
tm.assert_frame_equal(df, df_compat)
378-
379382

380383
class TestBasic(Base):
381384
def test_error(self, engine):
@@ -673,6 +676,16 @@ def test_read_empty_array(self, pa, dtype):
673676
df, pa, read_kwargs={"dtype_backend": "numpy_nullable"}, expected=expected
674677
)
675678

679+
@pytest.mark.network
680+
@pytest.mark.single_cpu
681+
def test_parquet_read_from_url(self, httpserver, datapath, df_compat, engine):
682+
if engine != "auto":
683+
pytest.importorskip(engine)
684+
with open(datapath("io", "data", "parquet", "simple.parquet"), mode="rb") as f:
685+
httpserver.serve_content(content=f.read())
686+
df = read_parquet(httpserver.url, engine=engine)
687+
tm.assert_frame_equal(df, df_compat)
688+
676689

677690
class TestParquetPyArrow(Base):
678691
@pytest.mark.xfail(reason="datetime_with_nat unit doesn't round-trip")
@@ -906,7 +919,7 @@ def test_write_with_schema(self, pa):
906919
out_df = df.astype(bool)
907920
check_round_trip(df, pa, write_kwargs={"schema": schema}, expected=out_df)
908921

909-
def test_additional_extension_arrays(self, pa):
922+
def test_additional_extension_arrays(self, pa, using_infer_string):
910923
# test additional ExtensionArrays that are supported through the
911924
# __arrow_array__ protocol
912925
pytest.importorskip("pyarrow")
@@ -917,17 +930,25 @@ def test_additional_extension_arrays(self, pa):
917930
"c": pd.Series(["a", None, "c"], dtype="string"),
918931
}
919932
)
920-
check_round_trip(df, pa)
933+
if using_infer_string:
934+
check_round_trip(df, pa, expected=df.astype({"c": "str"}))
935+
else:
936+
check_round_trip(df, pa)
921937

922938
df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")})
923939
check_round_trip(df, pa)
924940

925-
def test_pyarrow_backed_string_array(self, pa, string_storage):
941+
def test_pyarrow_backed_string_array(self, pa, string_storage, using_infer_string):
926942
# test ArrowStringArray supported through the __arrow_array__ protocol
927943
pytest.importorskip("pyarrow")
928944
df = pd.DataFrame({"a": pd.Series(["a", None, "c"], dtype="string[pyarrow]")})
929945
with pd.option_context("string_storage", string_storage):
930-
check_round_trip(df, pa, expected=df.astype(f"string[{string_storage}]"))
946+
if using_infer_string:
947+
expected = df.astype("str")
948+
expected.columns = expected.columns.astype("str")
949+
else:
950+
expected = df.astype(f"string[{string_storage}]")
951+
check_round_trip(df, pa, expected=expected)
931952

932953
def test_additional_extension_types(self, pa):
933954
# test additional ExtensionArrays that are supported through the

0 commit comments

Comments (0)