Skip to content

Commit 379feea

Browse files
committed
TST (string dtype): fix sql xfails with using_infer_string
1 parent 73da90c commit 379feea

File tree

4 files changed

+36
-14
lines changed

4 files changed

+36
-14
lines changed

pandas/core/dtypes/cast.py

+2
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,7 @@ def convert_dtypes(
11621162

11631163
def maybe_infer_to_datetimelike(
11641164
value: npt.NDArray[np.object_],
1165+
convert_to_nullable_dtype: bool = False,
11651166
) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray:
11661167
"""
11671168
we might have an array (or single object) that is datetime like,
@@ -1199,6 +1200,7 @@ def maybe_infer_to_datetimelike(
11991200
# numpy would have done it for us.
12001201
convert_numeric=False,
12011202
convert_non_numeric=True,
1203+
convert_to_nullable_dtype=convert_to_nullable_dtype,
12021204
dtype_if_all_nat=np.dtype("M8[s]"),
12031205
)
12041206

pandas/core/internals/construction.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -966,8 +966,9 @@ def convert(arr):
966966
if dtype is None:
967967
if arr.dtype == np.dtype("O"):
968968
# i.e. maybe_convert_objects didn't convert
969-
arr = maybe_infer_to_datetimelike(arr)
970-
if dtype_backend != "numpy" and arr.dtype == np.dtype("O"):
969+
convert_to_nullable_dtype = dtype_backend != "numpy"
970+
arr = maybe_infer_to_datetimelike(arr, convert_to_nullable_dtype)
971+
if convert_to_nullable_dtype and arr.dtype == np.dtype("O"):
971972
new_dtype = StringDtype()
972973
arr_cls = new_dtype.construct_array_type()
973974
arr = arr_cls._from_sequence(arr, dtype=new_dtype)

pandas/io/sql.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from pandas.core.dtypes.common import (
4646
is_dict_like,
4747
is_list_like,
48+
is_object_dtype,
49+
is_string_dtype,
4850
)
4951
from pandas.core.dtypes.dtypes import (
5052
ArrowDtype,
@@ -58,6 +60,7 @@
5860
Series,
5961
)
6062
from pandas.core.arrays import ArrowExtensionArray
63+
from pandas.core.arrays.string_ import StringDtype
6164
from pandas.core.base import PandasObject
6265
import pandas.core.common as com
6366
from pandas.core.common import maybe_make_list
@@ -1316,7 +1319,12 @@ def _harmonize_columns(
13161319
elif dtype_backend == "numpy" and col_type is float:
13171320
# floats support NA, can always convert!
13181321
self.frame[col_name] = df_col.astype(col_type)
1319-
1322+
elif (
1323+
using_string_dtype()
1324+
and is_string_dtype(col_type)
1325+
and is_object_dtype(self.frame[col_name])
1326+
):
1327+
self.frame[col_name] = df_col.astype(col_type)
13201328
elif dtype_backend == "numpy" and len(df_col) == df_col.count():
13211329
# No NA values, can convert ints and bools
13221330
if col_type is np.dtype("int64") or col_type is bool:
@@ -1403,6 +1411,7 @@ def _get_dtype(self, sqltype):
14031411
DateTime,
14041412
Float,
14051413
Integer,
1414+
String,
14061415
)
14071416

14081417
if isinstance(sqltype, Float):
@@ -1422,6 +1431,10 @@ def _get_dtype(self, sqltype):
14221431
return date
14231432
elif isinstance(sqltype, Boolean):
14241433
return bool
1434+
elif isinstance(sqltype, String):
1435+
if using_string_dtype():
1436+
return StringDtype(na_value=np.nan)
1437+
14251438
return object
14261439

14271440

@@ -2205,7 +2218,7 @@ def read_table(
22052218
elif using_string_dtype():
22062219
from pandas.io._util import arrow_string_types_mapper
22072220

2208-
arrow_string_types_mapper()
2221+
mapping = arrow_string_types_mapper()
22092222
else:
22102223
mapping = None
22112224

@@ -2286,6 +2299,10 @@ def read_query(
22862299
from pandas.io._util import _arrow_dtype_mapping
22872300

22882301
mapping = _arrow_dtype_mapping().get
2302+
elif using_string_dtype():
2303+
from pandas.io._util import arrow_string_types_mapper
2304+
2305+
mapping = arrow_string_types_mapper()
22892306
else:
22902307
mapping = None
22912308

pandas/tests/io/test_sql.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
pytest.mark.filterwarnings(
6161
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
6262
),
63-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
6463
]
6564

6665

@@ -685,6 +684,7 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine):
685684

686685
@pytest.fixture
687686
def postgresql_adbc_conn():
687+
pytest.importorskip("pyarrow")
688688
pytest.importorskip("adbc_driver_postgresql")
689689
from adbc_driver_postgresql import dbapi
690690

@@ -817,6 +817,7 @@ def sqlite_conn_types(sqlite_engine_types):
817817

818818
@pytest.fixture
819819
def sqlite_adbc_conn():
820+
pytest.importorskip("pyarrow")
820821
pytest.importorskip("adbc_driver_sqlite")
821822
from adbc_driver_sqlite import dbapi
822823

@@ -986,13 +987,13 @@ def test_dataframe_to_sql(conn, test_frame1, request):
986987

987988
@pytest.mark.parametrize("conn", all_connectable)
988989
def test_dataframe_to_sql_empty(conn, test_frame1, request):
989-
if conn == "postgresql_adbc_conn":
990+
if conn == "postgresql_adbc_conn" and not using_string_dtype():
990991
request.node.add_marker(
991992
pytest.mark.xfail(
992-
reason="postgres ADBC driver cannot insert index with null type",
993-
strict=True,
993+
reason="postgres ADBC driver < 1.2 cannot insert index with null type",
994994
)
995995
)
996+
996997
# GH 51086 if conn is sqlite_engine
997998
conn = request.getfixturevalue(conn)
998999
empty_df = test_frame1.iloc[:0]
@@ -3557,7 +3558,8 @@ def test_read_sql_dtype_backend(
35573558
result = getattr(pd, func)(
35583559
f"Select * from {table}", conn, dtype_backend=dtype_backend
35593560
)
3560-
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
3561+
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
3562+
35613563
tm.assert_frame_equal(result, expected)
35623564

35633565
if "adbc" in conn_name:
@@ -3607,7 +3609,7 @@ def test_read_sql_dtype_backend_table(
36073609

36083610
with pd.option_context("mode.string_storage", string_storage):
36093611
result = getattr(pd, func)(table, conn, dtype_backend=dtype_backend)
3610-
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
3612+
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
36113613
tm.assert_frame_equal(result, expected)
36123614

36133615
if "adbc" in conn_name:
@@ -4123,7 +4125,7 @@ def tquery(query, con=None):
41234125
def test_xsqlite_basic(sqlite_buildin):
41244126
frame = DataFrame(
41254127
np.random.default_rng(2).standard_normal((10, 4)),
4126-
columns=Index(list("ABCD"), dtype=object),
4128+
columns=Index(list("ABCD")),
41274129
index=date_range("2000-01-01", periods=10, freq="B"),
41284130
)
41294131
assert sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 10
@@ -4150,7 +4152,7 @@ def test_xsqlite_basic(sqlite_buildin):
41504152
def test_xsqlite_write_row_by_row(sqlite_buildin):
41514153
frame = DataFrame(
41524154
np.random.default_rng(2).standard_normal((10, 4)),
4153-
columns=Index(list("ABCD"), dtype=object),
4155+
columns=Index(list("ABCD")),
41544156
index=date_range("2000-01-01", periods=10, freq="B"),
41554157
)
41564158
frame.iloc[0, 0] = np.nan
@@ -4173,7 +4175,7 @@ def test_xsqlite_write_row_by_row(sqlite_buildin):
41734175
def test_xsqlite_execute(sqlite_buildin):
41744176
frame = DataFrame(
41754177
np.random.default_rng(2).standard_normal((10, 4)),
4176-
columns=Index(list("ABCD"), dtype=object),
4178+
columns=Index(list("ABCD")),
41774179
index=date_range("2000-01-01", periods=10, freq="B"),
41784180
)
41794181
create_sql = sql.get_schema(frame, "test")
@@ -4194,7 +4196,7 @@ def test_xsqlite_execute(sqlite_buildin):
41944196
def test_xsqlite_schema(sqlite_buildin):
41954197
frame = DataFrame(
41964198
np.random.default_rng(2).standard_normal((10, 4)),
4197-
columns=Index(list("ABCD"), dtype=object),
4199+
columns=Index(list("ABCD")),
41984200
index=date_range("2000-01-01", periods=10, freq="B"),
41994201
)
42004202
create_sql = sql.get_schema(frame, "test")

0 commit comments

Comments
 (0)