Commit 26e42d2

TST (string dtype): fix sql xfails with using_infer_string
1 parent 9b16b9e commit 26e42d2
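For context, "using_infer_string" in the commit title refers to pandas' opt-in future string dtype. A minimal sketch of enabling it (option name as shipped in pandas 2.x; this snippet is illustrative and not part of the commit):

import pandas as pd

# Opt in to the future string dtype: string data is then inferred as a
# string dtype instead of object across constructors and IO readers.
pd.set_option("future.infer_string", True)

s = pd.Series(["a", "b", None])
print(s.dtype)  # a string dtype rather than object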

File tree

4 files changed (+25 / -13 lines):

pandas/core/dtypes/cast.py (+2)
pandas/core/internals/construction.py (+3 / -2)
pandas/io/sql.py (+13 / -2)
pandas/tests/io/test_sql.py (+7 / -9)

pandas/core/dtypes/cast.py

+2 lines

@@ -1162,6 +1162,7 @@ def convert_dtypes(
 
 def maybe_infer_to_datetimelike(
     value: npt.NDArray[np.object_],
+    convert_to_nullable_dtype: bool = False,
 ) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray:
     """
     we might have a array (or single object) that is datetime like,
@@ -1199,6 +1200,7 @@ def maybe_infer_to_datetimelike(
         # numpy would have done it for us.
         convert_numeric=False,
         convert_non_numeric=True,
+        convert_to_nullable_dtype=convert_to_nullable_dtype,
         dtype_if_all_nat=np.dtype("M8[s]"),
     )
 
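A hedged sketch of how the new convert_to_nullable_dtype keyword is meant to be called. maybe_infer_to_datetimelike is an internal helper, so the import path and exact behavior can differ between pandas versions; this is illustrative only:

import numpy as np
from pandas.core.dtypes.cast import maybe_infer_to_datetimelike

values = np.array(["x", "y", None], dtype=object)
# With the keyword enabled, the underlying maybe_convert_objects pass is
# allowed to return nullable dtypes instead of leaving data as object.
result = maybe_infer_to_datetimelike(values, convert_to_nullable_dtype=True)
print(type(result), getattr(result, "dtype", None))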

pandas/core/internals/construction.py

+3 / -2 lines

@@ -966,8 +966,9 @@ def convert(arr):
             if dtype is None:
                 if arr.dtype == np.dtype("O"):
                     # i.e. maybe_convert_objects didn't convert
-                    arr = maybe_infer_to_datetimelike(arr)
-                    if dtype_backend != "numpy" and arr.dtype == np.dtype("O"):
+                    convert_to_nullable_dtype = dtype_backend != "numpy"
+                    arr = maybe_infer_to_datetimelike(arr, convert_to_nullable_dtype)
+                    if convert_to_nullable_dtype and arr.dtype == np.dtype("O"):
                         new_dtype = StringDtype()
                         arr_cls = new_dtype.construct_array_type()
                         arr = arr_cls._from_sequence(arr, dtype=new_dtype)
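To see where the flag matters, here is a hedged sketch against the internal convert_object_array helper that contains the hunk above (internal API, signature may differ across versions): with a non-numpy dtype_backend, object arrays of strings are now expected to come back as a string/nullable dtype rather than object.

import numpy as np
from pandas.core.internals.construction import convert_object_array

content = [np.array(["x", "y", None], dtype=object)]
# dtype_backend != "numpy" now also flows into maybe_infer_to_datetimelike
# via convert_to_nullable_dtype, ahead of the StringDtype fallback.
[converted] = convert_object_array(content, dtype=None, dtype_backend="numpy_nullable")
print(converted.dtype)  # expected: a string/nullable dtype, not object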

pandas/io/sql.py

+13 / -2 lines

@@ -58,6 +58,7 @@
     Series,
 )
 from pandas.core.arrays import ArrowExtensionArray
+from pandas.core.arrays.string_ import StringDtype
 from pandas.core.base import PandasObject
 import pandas.core.common as com
 from pandas.core.common import maybe_make_list
@@ -177,6 +178,7 @@ def _convert_arrays_to_dataframe(
                 pa_array = pa_array.cast(pa.string())
             result_arrays.append(ArrowExtensionArray(pa_array))
         arrays = result_arrays  # type: ignore[assignment]
+
     if arrays:
         return DataFrame._from_arrays(
             arrays, columns=columns, index=range(idx_len), verify_integrity=False
@@ -1316,11 +1318,11 @@ def _harmonize_columns(
                 elif dtype_backend == "numpy" and col_type is float:
                     # floats support NA, can always convert!
                     self.frame[col_name] = df_col.astype(col_type)
-
                 elif dtype_backend == "numpy" and len(df_col) == df_col.count():
                     # No NA values, can convert ints and bools
                     if col_type is np.dtype("int64") or col_type is bool:
                         self.frame[col_name] = df_col.astype(col_type)
+
             except KeyError:
                 pass  # this column not in results
 
@@ -1403,6 +1405,7 @@ def _get_dtype(self, sqltype):
             DateTime,
             Float,
             Integer,
+            String,
         )
 
         if isinstance(sqltype, Float):
@@ -1422,6 +1425,10 @@ def _get_dtype(self, sqltype):
             return date
         elif isinstance(sqltype, Boolean):
             return bool
+        elif isinstance(sqltype, String):
+            if using_string_dtype():
+                return StringDtype(na_value=np.nan)
+
         return object
 
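The String branch above maps SQLAlchemy string columns to a NaN-backed StringDtype when the string option is active. A hedged illustration of that dtype (the na_value argument is only available on pandas versions that ship the new string dtype):

import numpy as np
import pandas as pd

dtype = pd.StringDtype(na_value=np.nan)
ser = pd.Series(["a", None, "c"], dtype=dtype)
print(ser.dtype)   # NaN-backed string dtype
print(ser.isna())  # the missing entry is stored as NaN rather than pd.NA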

@@ -2205,7 +2212,7 @@ def read_table(
         elif using_string_dtype():
             from pandas.io._util import arrow_string_types_mapper
 
-            arrow_string_types_mapper()
+            mapping = arrow_string_types_mapper()
         else:
             mapping = None
 
@@ -2286,6 +2293,10 @@ def read_query(
             from pandas.io._util import _arrow_dtype_mapping
 
             mapping = _arrow_dtype_mapping().get
+        elif using_string_dtype():
+            from pandas.io._util import arrow_string_types_mapper
+
+            mapping = arrow_string_types_mapper()
         else:
             mapping = None
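Taken together, read_table now assigns the arrow string type mapping it previously computed and discarded, read_query gains the same branch, and _get_dtype covers the SQLAlchemy path. A hedged end-to-end sketch of the intended effect (table and column names are illustrative, and sqlite3 exercises the fallback rather than the SQLAlchemy path):

import sqlite3

import pandas as pd

pd.set_option("future.infer_string", True)

with sqlite3.connect(":memory:") as conn:
    pd.DataFrame({"name": ["a", "b", None]}).to_sql("demo", conn, index=False)
    result = pd.read_sql("SELECT * FROM demo", conn)

print(result.dtypes)  # "name" is expected to read back as a string dtype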

pandas/tests/io/test_sql.py

+7 / -9 lines

@@ -18,8 +18,6 @@
 import numpy as np
 import pytest
 
-from pandas._config import using_string_dtype
-
 from pandas._libs import lib
 from pandas.compat import pa_version_under14p1
 from pandas.compat._optional import import_optional_dependency
@@ -60,7 +58,6 @@
     pytest.mark.filterwarnings(
         "ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
     ),
-    pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
 ]
 
 
@@ -3554,7 +3551,8 @@ def test_read_sql_dtype_backend(
         result = getattr(pd, func)(
             f"Select * from {table}", conn, dtype_backend=dtype_backend
         )
-        expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
+    expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
+
     tm.assert_frame_equal(result, expected)
 
     if "adbc" in conn_name:
@@ -3604,7 +3602,7 @@ def test_read_sql_dtype_backend_table(
 
     with pd.option_context("mode.string_storage", string_storage):
         result = getattr(pd, func)(table, conn, dtype_backend=dtype_backend)
-        expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
+    expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
     tm.assert_frame_equal(result, expected)
 
     if "adbc" in conn_name:
@@ -4120,7 +4118,7 @@ def tquery(query, con=None):
 def test_xsqlite_basic(sqlite_buildin):
     frame = DataFrame(
         np.random.default_rng(2).standard_normal((10, 4)),
-        columns=Index(list("ABCD"), dtype=object),
+        columns=Index(list("ABCD")),
         index=date_range("2000-01-01", periods=10, freq="B"),
     )
     assert sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 10
@@ -4147,7 +4145,7 @@ def test_xsqlite_basic(sqlite_buildin):
 def test_xsqlite_write_row_by_row(sqlite_buildin):
     frame = DataFrame(
         np.random.default_rng(2).standard_normal((10, 4)),
-        columns=Index(list("ABCD"), dtype=object),
+        columns=Index(list("ABCD")),
         index=date_range("2000-01-01", periods=10, freq="B"),
     )
     frame.iloc[0, 0] = np.nan
@@ -4170,7 +4168,7 @@ def test_xsqlite_write_row_by_row(sqlite_buildin):
 def test_xsqlite_execute(sqlite_buildin):
     frame = DataFrame(
         np.random.default_rng(2).standard_normal((10, 4)),
-        columns=Index(list("ABCD"), dtype=object),
+        columns=Index(list("ABCD")),
        index=date_range("2000-01-01", periods=10, freq="B"),
     )
     create_sql = sql.get_schema(frame, "test")
@@ -4191,7 +4189,7 @@ def test_xsqlite_execute(sqlite_buildin):
 def test_xsqlite_schema(sqlite_buildin):
     frame = DataFrame(
         np.random.default_rng(2).standard_normal((10, 4)),
-        columns=Index(list("ABCD"), dtype=object),
+        columns=Index(list("ABCD")),
         index=date_range("2000-01-01", periods=10, freq="B"),
     )
     create_sql = sql.get_schema(frame, "test")
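The xsqlite tests above drop the explicit dtype=object on the column Index, presumably because under the future string option an Index of single-character labels is inferred as a string Index and now round-trips through SQL as such. A small illustration:

import pandas as pd

pd.set_option("future.infer_string", True)

idx = pd.Index(list("ABCD"))
print(idx.dtype)  # a string dtype rather than object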
