Skip to content

Commit 3083ae9

Browse files
authored
BUG: to_sql with ArrowExtensionArray (#52058)
* BUG: to_sql with ArrowExtensionArray
* Remove unneeded fixture
* ns
* pandas CI not being set
* to string
* Undo ubuntu workflow
1 parent 14affe0 commit 3083ae9

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

pandas/core/arrays/arrow/array.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2091,7 +2091,10 @@ def _dt_round(
20912091
return self._round_temporally("round", freq, ambiguous, nonexistent)
20922092

20932093
def _dt_to_pydatetime(self):
2094-
return np.array(self._pa_array.to_pylist(), dtype=object)
2094+
data = self._pa_array.to_pylist()
2095+
if self._dtype.pyarrow_dtype.unit == "ns":
2096+
data = [ts.to_pydatetime(warn=False) for ts in data]
2097+
return np.array(data, dtype=object)
20952098

20962099
def _dt_tz_localize(
20972100
self,

pandas/io/sql.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -964,14 +964,16 @@ def insert_data(self) -> tuple[list[str], list[np.ndarray]]:
964964
data_list: list[np.ndarray] = [None] * ncols # type: ignore[list-item]
965965

966966
for i, (_, ser) in enumerate(temp.items()):
967-
vals = ser._values
968-
if vals.dtype.kind == "M":
969-
d = vals.to_pydatetime()
970-
elif vals.dtype.kind == "m":
967+
if ser.dtype.kind == "M":
968+
d = ser.dt.to_pydatetime()
969+
elif ser.dtype.kind == "m":
970+
vals = ser._values
971+
if isinstance(vals, ArrowExtensionArray):
972+
vals = vals.to_numpy(dtype=np.dtype("m8[ns]"))
971973
# store as integers, see GH#6921, GH#7076
972974
d = vals.view("i8").astype(object)
973975
else:
974-
d = vals.astype(object)
976+
d = ser._values.astype(object)
975977

976978
assert isinstance(d, np.ndarray), type(d)
977979

pandas/tests/extension/test_arrow.py

+1
Original file line numberDiff line numberDiff line change
@@ -2271,6 +2271,7 @@ def test_dt_to_pydatetime():
22712271
result = ser.dt.to_pydatetime()
22722272
expected = np.array(data, dtype=object)
22732273
tm.assert_numpy_array_equal(result, expected)
2274+
assert all(type(res) is datetime for res in result)
22742275

22752276
expected = ser.astype("datetime64[ns]").dt.to_pydatetime()
22762277
tm.assert_numpy_array_equal(result, expected)

pandas/tests/io/test_sql.py

+21
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
date,
2525
datetime,
2626
time,
27+
timedelta,
2728
)
2829
from io import StringIO
2930
from pathlib import Path
@@ -549,6 +550,26 @@ def test_dataframe_to_sql(conn, test_frame1, request):
549550
test_frame1.to_sql("test", conn, if_exists="append", index=False)
550551

551552

553+
@pytest.mark.db
554+
@pytest.mark.parametrize("conn", all_connectable)
555+
def test_dataframe_to_sql_arrow_dtypes(conn, request):
556+
# GH 52046
557+
pytest.importorskip("pyarrow")
558+
df = DataFrame(
559+
{
560+
"int": pd.array([1], dtype="int8[pyarrow]"),
561+
"datetime": pd.array(
562+
[datetime(2023, 1, 1)], dtype="timestamp[ns][pyarrow]"
563+
),
564+
"timedelta": pd.array([timedelta(1)], dtype="duration[ns][pyarrow]"),
565+
"string": pd.array(["a"], dtype="string[pyarrow]"),
566+
}
567+
)
568+
conn = request.getfixturevalue(conn)
569+
with tm.assert_produces_warning(UserWarning, match="the 'timedelta'"):
570+
df.to_sql("test_arrow", conn, if_exists="replace", index=False)
571+
572+
552573
@pytest.mark.db
553574
@pytest.mark.parametrize("conn", all_connectable)
554575
@pytest.mark.parametrize("method", [None, "multi"])

0 commit comments

Comments
 (0)