Skip to content

Commit 3dfed3f

Browse files
authored
ENH: to_sql returns rowcount (#45137)
1 parent deca954 commit 3dfed3f

File tree

4 files changed

+213
-118
lines changed

4 files changed

+213
-118
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ Other enhancements
273273
- :meth:`is_list_like` now identifies duck-arrays as list-like unless ``.ndim == 0`` (:issue:`35131`)
274274
- :class:`ExtensionDtype` and :class:`ExtensionArray` are now (de)serialized when exporting a :class:`DataFrame` with :meth:`DataFrame.to_json` using ``orient='table'`` (:issue:`20612`, :issue:`44705`).
275275
- Add support for `Zstandard <http://facebook.github.io/zstd/>`_ compression to :meth:`DataFrame.to_pickle`/:meth:`read_pickle` and friends (:issue:`43925`)
276-
-
276+
- :meth:`DataFrame.to_sql` now returns the number of written rows as an ``int`` (:issue:`23998`)
277277

278278

279279
.. ---------------------------------------------------------------------------

pandas/core/generic.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -2767,7 +2767,7 @@ def to_sql(
27672767
chunksize=None,
27682768
dtype: DtypeArg | None = None,
27692769
method=None,
2770-
) -> None:
2770+
) -> int | None:
27712771
"""
27722772
Write records stored in a DataFrame to a SQL database.
27732773
@@ -2820,6 +2820,20 @@ def to_sql(
28202820
Details and a sample callable implementation can be found in the
28212821
section :ref:`insert method <io.sql.method>`.
28222822
2823+
Returns
2824+
-------
2825+
None or int
2826+
Number of rows affected by to_sql. None is returned if the callable
2827+
passed into ``method`` does not return the number of rows.
2828+
2829+
The returned number of affected rows is the sum of the ``rowcount``
2830+
attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable, which may not
2831+
reflect the exact number of written rows as stipulated in the
2832+
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
2833+
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__ documentation.
2834+
2835+
.. versionadded:: 1.4.0
2836+
28232837
Raises
28242838
------
28252839
ValueError
@@ -2859,6 +2873,7 @@ def to_sql(
28592873
2 User 3
28602874
28612875
>>> df.to_sql('users', con=engine)
2876+
3
28622877
>>> engine.execute("SELECT * FROM users").fetchall()
28632878
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3')]
28642879
@@ -2867,12 +2882,14 @@ def to_sql(
28672882
>>> with engine.begin() as connection:
28682883
... df1 = pd.DataFrame({'name' : ['User 4', 'User 5']})
28692884
... df1.to_sql('users', con=connection, if_exists='append')
2885+
2
28702886
28712887
This is allowed to support operations that require that the same
28722888
DBAPI connection is used for the entire operation.
28732889
28742890
>>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']})
28752891
>>> df2.to_sql('users', con=engine, if_exists='append')
2892+
2
28762893
>>> engine.execute("SELECT * FROM users").fetchall()
28772894
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3'),
28782895
(0, 'User 4'), (1, 'User 5'), (0, 'User 6'),
@@ -2882,6 +2899,7 @@ def to_sql(
28822899
28832900
>>> df2.to_sql('users', con=engine, if_exists='replace',
28842901
... index_label='id')
2902+
2
28852903
>>> engine.execute("SELECT * FROM users").fetchall()
28862904
[(0, 'User 6'), (1, 'User 7')]
28872905
@@ -2900,13 +2918,14 @@ def to_sql(
29002918
>>> from sqlalchemy.types import Integer
29012919
>>> df.to_sql('integers', con=engine, index=False,
29022920
... dtype={"A": Integer()})
2921+
3
29032922
29042923
>>> engine.execute("SELECT * FROM integers").fetchall()
29052924
[(1,), (None,), (2,)]
2906-
"""
2925+
""" # noqa:E501
29072926
from pandas.io import sql
29082927

2909-
sql.to_sql(
2928+
return sql.to_sql(
29102929
self,
29112930
name,
29122931
con,

pandas/io/sql.py

+52-23
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,7 @@ def read_sql(
512512
>>> df = pd.DataFrame(data=[[0, '10/11/12'], [1, '12/11/10']],
513513
... columns=['int_column', 'date_column'])
514514
>>> df.to_sql('test_data', conn)
515+
2
515516
516517
>>> pd.read_sql('SELECT int_column, date_column FROM test_data', conn)
517518
int_column date_column
@@ -611,7 +612,7 @@ def to_sql(
611612
method: str | None = None,
612613
engine: str = "auto",
613614
**engine_kwargs,
614-
) -> None:
615+
) -> int | None:
615616
"""
616617
Write records stored in a DataFrame to a SQL database.
617618
@@ -650,8 +651,8 @@ def to_sql(
650651
Controls the SQL insertion clause used:
651652
652653
- None : Uses standard SQL ``INSERT`` clause (one per row).
653-
- 'multi': Pass multiple values in a single ``INSERT`` clause.
654-
- callable with signature ``(pd_table, conn, keys, data_iter)``.
654+
- ``'multi'``: Pass multiple values in a single ``INSERT`` clause.
655+
- callable with signature ``(pd_table, conn, keys, data_iter) -> int | None``.
655656
656657
Details and a sample callable implementation can be found in the
657658
section :ref:`insert method <io.sql.method>`.
@@ -664,7 +665,23 @@ def to_sql(
664665
665666
**engine_kwargs
666667
Any additional kwargs are passed to the engine.
667-
"""
668+
669+
Returns
670+
-------
671+
None or int
672+
Number of rows affected by to_sql. None is returned if the callable
673+
passed into ``method`` does not return the number of rows.
674+
675+
.. versionadded:: 1.4.0
676+
677+
Notes
678+
-----
679+
The returned number of affected rows is the sum of the ``rowcount`` attribute of ``sqlite3.Cursor``
680+
or SQLAlchemy connectable. The returned value may not reflect the exact number of written
681+
rows as stipulated in the
682+
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
683+
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__ documentation.
684+
""" # noqa:E501
668685
if if_exists not in ("fail", "replace", "append"):
669686
raise ValueError(f"'{if_exists}' is not valid for if_exists")
670687

@@ -677,7 +694,7 @@ def to_sql(
677694
"'frame' argument should be either a Series or a DataFrame"
678695
)
679696

680-
pandas_sql.to_sql(
697+
return pandas_sql.to_sql(
681698
frame,
682699
name,
683700
if_exists=if_exists,
@@ -817,7 +834,7 @@ def create(self):
817834
else:
818835
self._execute_create()
819836

820-
def _execute_insert(self, conn, keys: list[str], data_iter):
837+
def _execute_insert(self, conn, keys: list[str], data_iter) -> int:
821838
"""
822839
Execute SQL statement inserting data
823840
@@ -830,9 +847,10 @@ def _execute_insert(self, conn, keys: list[str], data_iter):
830847
Each item contains a list of values to be inserted
831848
"""
832849
data = [dict(zip(keys, row)) for row in data_iter]
833-
conn.execute(self.table.insert(), data)
850+
result = conn.execute(self.table.insert(), data)
851+
return result.rowcount
834852

835-
def _execute_insert_multi(self, conn, keys: list[str], data_iter):
853+
def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int:
836854
"""
837855
Alternative to _execute_insert for DBs that support multivalue INSERT.
838856
@@ -845,7 +863,8 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter):
845863

846864
data = [dict(zip(keys, row)) for row in data_iter]
847865
stmt = insert(self.table).values(data)
848-
conn.execute(stmt)
866+
result = conn.execute(stmt)
867+
return result.rowcount
849868

850869
def insert_data(self):
851870
if self.index is not None:
@@ -885,7 +904,9 @@ def insert_data(self):
885904

886905
return column_names, data_list
887906

888-
def insert(self, chunksize: int | None = None, method: str | None = None):
907+
def insert(
908+
self, chunksize: int | None = None, method: str | None = None
909+
) -> int | None:
889910

890911
# set insert method
891912
if method is None:
@@ -902,15 +923,15 @@ def insert(self, chunksize: int | None = None, method: str | None = None):
902923
nrows = len(self.frame)
903924

904925
if nrows == 0:
905-
return
926+
return 0
906927

907928
if chunksize is None:
908929
chunksize = nrows
909930
elif chunksize == 0:
910931
raise ValueError("chunksize argument should be non-zero")
911932

912933
chunks = (nrows // chunksize) + 1
913-
934+
total_inserted = 0
914935
with self.pd_sql.run_transaction() as conn:
915936
for i in range(chunks):
916937
start_i = i * chunksize
@@ -919,7 +940,12 @@ def insert(self, chunksize: int | None = None, method: str | None = None):
919940
break
920941

921942
chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list))
922-
exec_insert(conn, keys, chunk_iter)
943+
num_inserted = exec_insert(conn, keys, chunk_iter)
944+
if num_inserted is None:
945+
total_inserted = None
946+
else:
947+
total_inserted += num_inserted
948+
return total_inserted
923949

924950
def _query_iterator(
925951
self,
@@ -1239,7 +1265,7 @@ def to_sql(
12391265
chunksize=None,
12401266
dtype: DtypeArg | None = None,
12411267
method=None,
1242-
):
1268+
) -> int | None:
12431269
raise ValueError(
12441270
"PandasSQL must be created with an SQLAlchemy "
12451271
"connectable or sqlite connection"
@@ -1258,7 +1284,7 @@ def insert_records(
12581284
chunksize=None,
12591285
method=None,
12601286
**engine_kwargs,
1261-
):
1287+
) -> int | None:
12621288
"""
12631289
Inserts data into already-prepared table
12641290
"""
@@ -1282,11 +1308,11 @@ def insert_records(
12821308
chunksize=None,
12831309
method=None,
12841310
**engine_kwargs,
1285-
):
1311+
) -> int | None:
12861312
from sqlalchemy import exc
12871313

12881314
try:
1289-
table.insert(chunksize=chunksize, method=method)
1315+
return table.insert(chunksize=chunksize, method=method)
12901316
except exc.SQLAlchemyError as err:
12911317
# GH34431
12921318
# https://stackoverflow.com/a/67358288/6067848
@@ -1643,7 +1669,7 @@ def to_sql(
16431669
method=None,
16441670
engine="auto",
16451671
**engine_kwargs,
1646-
):
1672+
) -> int | None:
16471673
"""
16481674
Write records stored in a DataFrame to a SQL database.
16491675
@@ -1704,7 +1730,7 @@ def to_sql(
17041730
dtype=dtype,
17051731
)
17061732

1707-
sql_engine.insert_records(
1733+
total_inserted = sql_engine.insert_records(
17081734
table=table,
17091735
con=self.connectable,
17101736
frame=frame,
@@ -1717,6 +1743,7 @@ def to_sql(
17171743
)
17181744

17191745
self.check_case_sensitive(name=name, schema=schema)
1746+
return total_inserted
17201747

17211748
@property
17221749
def tables(self):
@@ -1859,14 +1886,16 @@ def insert_statement(self, *, num_rows: int):
18591886
)
18601887
return insert_statement
18611888

1862-
def _execute_insert(self, conn, keys, data_iter):
1889+
def _execute_insert(self, conn, keys, data_iter) -> int:
18631890
data_list = list(data_iter)
18641891
conn.executemany(self.insert_statement(num_rows=1), data_list)
1892+
return conn.rowcount
18651893

1866-
def _execute_insert_multi(self, conn, keys, data_iter):
1894+
def _execute_insert_multi(self, conn, keys, data_iter) -> int:
18671895
data_list = list(data_iter)
18681896
flattened_data = [x for row in data_list for x in row]
18691897
conn.execute(self.insert_statement(num_rows=len(data_list)), flattened_data)
1898+
return conn.rowcount
18701899

18711900
def _create_table_setup(self):
18721901
"""
@@ -2088,7 +2117,7 @@ def to_sql(
20882117
dtype: DtypeArg | None = None,
20892118
method=None,
20902119
**kwargs,
2091-
):
2120+
) -> int | None:
20922121
"""
20932122
Write records stored in a DataFrame to a SQL database.
20942123
@@ -2153,7 +2182,7 @@ def to_sql(
21532182
dtype=dtype,
21542183
)
21552184
table.create()
2156-
table.insert(chunksize, method)
2185+
return table.insert(chunksize, method)
21572186

21582187
def has_table(self, name: str, schema: str | None = None):
21592188

0 commit comments

Comments
 (0)