Skip to content

Commit 7d5d123

Browse files
authored
Backport PR #52058 on branch 2.0.x (BUG: to_sql with ArrowExtensionArray) (#52124)
* Backport PR #52058: BUG: to_sql with ArrowExtensionArray
* _data
1 parent 22e7c08 commit 7d5d123

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

pandas/core/arrays/arrow/array.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2075,7 +2075,10 @@ def _dt_round(
20752075
return self._round_temporally("round", freq, ambiguous, nonexistent)
20762076

20772077
def _dt_to_pydatetime(self):
2078-
return np.array(self._data.to_pylist(), dtype=object)
2078+
data = self._data.to_pylist()
2079+
if self._dtype.pyarrow_dtype.unit == "ns":
2080+
data = [ts.to_pydatetime(warn=False) for ts in data]
2081+
return np.array(data, dtype=object)
20792082

20802083
def _dt_tz_localize(
20812084
self,

pandas/io/sql.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -961,14 +961,16 @@ def insert_data(self) -> tuple[list[str], list[np.ndarray]]:
961961
data_list: list[np.ndarray] = [None] * ncols # type: ignore[list-item]
962962

963963
for i, (_, ser) in enumerate(temp.items()):
964-
vals = ser._values
965-
if vals.dtype.kind == "M":
966-
d = vals.to_pydatetime()
967-
elif vals.dtype.kind == "m":
964+
if ser.dtype.kind == "M":
965+
d = ser.dt.to_pydatetime()
966+
elif ser.dtype.kind == "m":
967+
vals = ser._values
968+
if isinstance(vals, ArrowExtensionArray):
969+
vals = vals.to_numpy(dtype=np.dtype("m8[ns]"))
968970
# store as integers, see GH#6921, GH#7076
969971
d = vals.view("i8").astype(object)
970972
else:
971-
d = vals.astype(object)
973+
d = ser._values.astype(object)
972974

973975
assert isinstance(d, np.ndarray), type(d)
974976

pandas/tests/extension/test_arrow.py

+1
Original file line numberDiff line numberDiff line change
@@ -2270,6 +2270,7 @@ def test_dt_to_pydatetime():
22702270
result = ser.dt.to_pydatetime()
22712271
expected = np.array(data, dtype=object)
22722272
tm.assert_numpy_array_equal(result, expected)
2273+
assert all(type(res) is datetime for res in result)
22732274

22742275
expected = ser.astype("datetime64[ns]").dt.to_pydatetime()
22752276
tm.assert_numpy_array_equal(result, expected)

pandas/tests/io/test_sql.py

+21
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
date,
2525
datetime,
2626
time,
27+
timedelta,
2728
)
2829
from io import StringIO
2930
from pathlib import Path
@@ -549,6 +550,26 @@ def test_dataframe_to_sql(conn, test_frame1, request):
549550
test_frame1.to_sql("test", conn, if_exists="append", index=False)
550551

551552

553+
@pytest.mark.db
554+
@pytest.mark.parametrize("conn", all_connectable)
555+
def test_dataframe_to_sql_arrow_dtypes(conn, request):
556+
# GH 52046
557+
pytest.importorskip("pyarrow")
558+
df = DataFrame(
559+
{
560+
"int": pd.array([1], dtype="int8[pyarrow]"),
561+
"datetime": pd.array(
562+
[datetime(2023, 1, 1)], dtype="timestamp[ns][pyarrow]"
563+
),
564+
"timedelta": pd.array([timedelta(1)], dtype="duration[ns][pyarrow]"),
565+
"string": pd.array(["a"], dtype="string[pyarrow]"),
566+
}
567+
)
568+
conn = request.getfixturevalue(conn)
569+
with tm.assert_produces_warning(UserWarning, match="the 'timedelta'"):
570+
df.to_sql("test_arrow", conn, if_exists="replace", index=False)
571+
572+
552573
@pytest.mark.db
553574
@pytest.mark.parametrize("conn", all_connectable)
554575
@pytest.mark.parametrize("method", [None, "multi"])

0 commit comments

Comments
 (0)