
Commit 2304e71

WillAyd authored and jorisvandenbossche committed
String dtype: enable in SQL IO + resolve all xfails (pandas-dev#60255)
(cherry picked from commit ba4d1cf)
1 parent 2054463 commit 2304e71

File tree

4 files changed, +38 -15 lines

pandas/core/dtypes/cast.py

+3 -1

@@ -1163,6 +1163,7 @@ def convert_dtypes(

 def maybe_infer_to_datetimelike(
     value: npt.NDArray[np.object_],
+    convert_to_nullable_dtype: bool = False,
 ) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray:
     """
     we might have a array (or single object) that is datetime like,
@@ -1200,7 +1201,8 @@ def maybe_infer_to_datetimelike(
         # numpy would have done it for us.
         convert_numeric=False,
         convert_non_numeric=True,
-        dtype_if_all_nat=np.dtype("M8[ns]"),
+        convert_to_nullable_dtype=convert_to_nullable_dtype,
+        dtype_if_all_nat=np.dtype("M8[s]"),
     )

pandas/core/internals/construction.py

+3 -2

@@ -1042,8 +1042,9 @@ def convert(arr):
             if dtype is None:
                 if arr.dtype == np.dtype("O"):
                     # i.e. maybe_convert_objects didn't convert
-                    arr = maybe_infer_to_datetimelike(arr)
-                    if dtype_backend != "numpy" and arr.dtype == np.dtype("O"):
+                    convert_to_nullable_dtype = dtype_backend != "numpy"
+                    arr = maybe_infer_to_datetimelike(arr, convert_to_nullable_dtype)
+                    if convert_to_nullable_dtype and arr.dtype == np.dtype("O"):
                         new_dtype = StringDtype()
                         arr_cls = new_dtype.construct_array_type()
                         arr = arr_cls._from_sequence(arr, dtype=new_dtype)
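
The changes in pandas/core/dtypes/cast.py and pandas/core/internals/construction.py thread a convert_to_nullable_dtype flag from convert_object_array through maybe_infer_to_datetimelike, so leftover object arrays of strings can be converted to the new string dtype during DataFrame construction. A minimal sketch of the user-visible effect on the builtin-sqlite3 fallback path, assuming a pandas build that includes this change and has the future.infer_string option enabled (the table and column names are invented for illustration):

import sqlite3

import pandas as pd

pd.set_option("future.infer_string", True)

# Round-trip a text column through the DB-API (non-SQLAlchemy) fallback path.
conn = sqlite3.connect(":memory:")
pd.DataFrame({"name": ["a", "b", None], "value": [1, 2, 3]}).to_sql(
    "demo", conn, index=False
)
result = pd.read_sql("SELECT name, value FROM demo", conn)
conn.close()

# With the option enabled, the text column is expected to come back as the
# NaN-backed string dtype rather than object, which is what this commit enables.
print(result.dtypes)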

pandas/io/sql.py

+19 -2

@@ -46,6 +46,8 @@
 from pandas.core.dtypes.common import (
     is_dict_like,
     is_list_like,
+    is_object_dtype,
+    is_string_dtype,
 )
 from pandas.core.dtypes.dtypes import (
     ArrowDtype,
@@ -59,6 +61,7 @@
     Series,
 )
 from pandas.core.arrays import ArrowExtensionArray
+from pandas.core.arrays.string_ import StringDtype
 from pandas.core.base import PandasObject
 import pandas.core.common as com
 from pandas.core.common import maybe_make_list
@@ -1331,7 +1334,12 @@ def _harmonize_columns(
                 elif dtype_backend == "numpy" and col_type is float:
                     # floats support NA, can always convert!
                     self.frame[col_name] = df_col.astype(col_type, copy=False)
-
+                elif (
+                    using_string_dtype()
+                    and is_string_dtype(col_type)
+                    and is_object_dtype(self.frame[col_name])
+                ):
+                    self.frame[col_name] = df_col.astype(col_type, copy=False)
                 elif dtype_backend == "numpy" and len(df_col) == df_col.count():
                     # No NA values, can convert ints and bools
                     if col_type is np.dtype("int64") or col_type is bool:
@@ -1418,6 +1426,7 @@ def _get_dtype(self, sqltype):
             DateTime,
             Float,
             Integer,
+            String,
         )

         if isinstance(sqltype, Float):
@@ -1437,6 +1446,10 @@ def _get_dtype(self, sqltype):
             return date
         elif isinstance(sqltype, Boolean):
             return bool
+        elif isinstance(sqltype, String):
+            if using_string_dtype():
+                return StringDtype(na_value=np.nan)
+
         return object

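
For the SQLAlchemy path, _get_dtype now maps SQL String column types to StringDtype(na_value=np.nan) when using_string_dtype() is true, and _harmonize_columns casts the raw object column to that dtype. A hedged sketch of the resulting behaviour, assuming SQLAlchemy is installed and the future.infer_string option is enabled (the table and column names are invented for illustration):

import pandas as pd
import sqlalchemy as sa

pd.set_option("future.infer_string", True)

engine = sa.create_engine("sqlite://")  # in-memory SQLite via SQLAlchemy
pd.DataFrame({"city": ["Oslo", "Lima", None]}).to_sql("places", engine, index=False)

# read_sql_table() reflects the table, so its VARCHAR column goes through the
# String branch added to _get_dtype() and is harmonized to the string dtype.
result = pd.read_sql_table("places", engine)
print(result["city"].dtype)  # expected: the NaN-backed string dtype, not object

The remaining two hunks in pandas/io/sql.py give the ADBC read_table/read_query paths the same treatment via arrow_string_types_mapper():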

@@ -2218,7 +2231,7 @@ def read_table(
         elif using_string_dtype():
             from pandas.io._util import arrow_string_types_mapper

-            arrow_string_types_mapper()
+            mapping = arrow_string_types_mapper()
         else:
             mapping = None

@@ -2299,6 +2312,10 @@ def read_query(
             from pandas.io._util import _arrow_dtype_mapping

             mapping = _arrow_dtype_mapping().get
+        elif using_string_dtype():
+            from pandas.io._util import arrow_string_types_mapper
+
+            mapping = arrow_string_types_mapper()
         else:
             mapping = None

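
A rough sketch of that ADBC path, assuming adbc-driver-sqlite and pyarrow are installed and the string option is enabled (the table and column names are invented; exact driver behaviour may differ by version):

import pandas as pd
from adbc_driver_sqlite import dbapi  # optional dependency

pd.set_option("future.infer_string", True)

with dbapi.connect() as conn:  # in-memory SQLite database via ADBC
    pd.DataFrame({"label": ["x", "y", None]}).to_sql("t", conn, index=False)
    result = pd.read_sql("SELECT label FROM t", conn)

# read_query() now passes arrow_string_types_mapper() to to_pandas(), so the
# Arrow string column is expected to land on the pandas string dtype.
print(result["label"].dtype)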

pandas/tests/io/test_sql.py

+13 -10

@@ -63,7 +63,7 @@
     pytest.mark.filterwarnings(
         "ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
     ),
-    pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
+    pytest.mark.single_cpu,
 ]

@@ -685,6 +685,7 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine):

 @pytest.fixture
 def postgresql_adbc_conn():
+    pytest.importorskip("pyarrow")
     pytest.importorskip("adbc_driver_postgresql")
     from adbc_driver_postgresql import dbapi

@@ -817,6 +818,7 @@ def sqlite_conn_types(sqlite_engine_types):

 @pytest.fixture
 def sqlite_adbc_conn():
+    pytest.importorskip("pyarrow")
     pytest.importorskip("adbc_driver_sqlite")
     from adbc_driver_sqlite import dbapi

@@ -986,13 +988,13 @@ def test_dataframe_to_sql(conn, test_frame1, request):

 @pytest.mark.parametrize("conn", all_connectable)
 def test_dataframe_to_sql_empty(conn, test_frame1, request):
-    if conn == "postgresql_adbc_conn":
+    if conn == "postgresql_adbc_conn" and not using_string_dtype():
         request.node.add_marker(
             pytest.mark.xfail(
-                reason="postgres ADBC driver cannot insert index with null type",
-                strict=True,
+                reason="postgres ADBC driver < 1.2 cannot insert index with null type",
             )
         )
+
     # GH 51086 if conn is sqlite_engine
     conn = request.getfixturevalue(conn)
     empty_df = test_frame1.iloc[:0]
@@ -3571,7 +3573,8 @@ def test_read_sql_dtype_backend(
         result = getattr(pd, func)(
             f"Select * from {table}", conn, dtype_backend=dtype_backend
         )
-        expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
+    expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
+
     tm.assert_frame_equal(result, expected)

     if "adbc" in conn_name:
@@ -3621,7 +3624,7 @@ def test_read_sql_dtype_backend_table(

     with pd.option_context("mode.string_storage", string_storage):
         result = getattr(pd, func)(table, conn, dtype_backend=dtype_backend)
-        expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
+    expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
     tm.assert_frame_equal(result, expected)

     if "adbc" in conn_name:
@@ -4150,7 +4153,7 @@ def tquery(query, con=None):
 def test_xsqlite_basic(sqlite_buildin):
     frame = DataFrame(
         np.random.default_rng(2).standard_normal((10, 4)),
-        columns=Index(list("ABCD"), dtype=object),
+        columns=Index(list("ABCD")),
         index=date_range("2000-01-01", periods=10, freq="B"),
     )
     assert sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 10
@@ -4177,7 +4180,7 @@ def test_xsqlite_basic(sqlite_buildin):
 def test_xsqlite_write_row_by_row(sqlite_buildin):
     frame = DataFrame(
         np.random.default_rng(2).standard_normal((10, 4)),
-        columns=Index(list("ABCD"), dtype=object),
+        columns=Index(list("ABCD")),
         index=date_range("2000-01-01", periods=10, freq="B"),
     )
     frame.iloc[0, 0] = np.nan
@@ -4200,7 +4203,7 @@ def test_xsqlite_write_row_by_row(sqlite_buildin):
 def test_xsqlite_execute(sqlite_buildin):
     frame = DataFrame(
         np.random.default_rng(2).standard_normal((10, 4)),
-        columns=Index(list("ABCD"), dtype=object),
+        columns=Index(list("ABCD")),
         index=date_range("2000-01-01", periods=10, freq="B"),
     )
     create_sql = sql.get_schema(frame, "test")
@@ -4221,7 +4224,7 @@ def test_xsqlite_execute(sqlite_buildin):
 def test_xsqlite_schema(sqlite_buildin):
     frame = DataFrame(
         np.random.default_rng(2).standard_normal((10, 4)),
-        columns=Index(list("ABCD"), dtype=object),
+        columns=Index(list("ABCD")),
         index=date_range("2000-01-01", periods=10, freq="B"),
     )
     create_sql = sql.get_schema(frame, "test")
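
The dtype=object arguments are dropped from the test_xsqlite_* frames presumably because, with the string option enabled, an Index of labels infers the new string dtype by default, so forcing object would no longer match what round-trips through SQL. A quick illustration, assuming the future.infer_string option is available:

import pandas as pd

pd.set_option("future.infer_string", True)

print(pd.Index(list("ABCD")).dtype)                # expected: the new string dtype
print(pd.Index(list("ABCD"), dtype=object).dtype)  # object, as before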
