Skip to content

ENH: to_sql returns rowcount #45137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ Other enhancements
- :meth:`is_list_like` now identifies duck-arrays as list-like unless ``.ndim == 0`` (:issue:`35131`)
- :class:`ExtensionDtype` and :class:`ExtensionArray` are now (de)serialized when exporting a :class:`DataFrame` with :meth:`DataFrame.to_json` using ``orient='table'`` (:issue:`20612`, :issue:`44705`).
- Add support for `Zstandard <http://facebook.github.io/zstd/>`_ compression to :meth:`DataFrame.to_pickle`/:meth:`read_pickle` and friends (:issue:`43925`)
-
- :meth:`DataFrame.to_sql` now returns an ``int`` of the number of written rows (:issue:`23998`)


.. ---------------------------------------------------------------------------
Expand Down
25 changes: 22 additions & 3 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2767,7 +2767,7 @@ def to_sql(
chunksize=None,
dtype: DtypeArg | None = None,
method=None,
) -> None:
) -> int | None:
"""
Write records stored in a DataFrame to a SQL database.
Expand Down Expand Up @@ -2820,6 +2820,20 @@ def to_sql(
Details and a sample callable implementation can be found in the
section :ref:`insert method <io.sql.method>`.
Returns
-------
None or int
Number of rows affected by to_sql. None is returned if the callable
passed into ``method`` does not return the number of rows.
The number of rows affected is the sum of the ``rowcount`` attribute of
``sqlite3.Cursor`` or SQLAlchemy connectable, which may not reflect the
exact number of written rows, as stipulated in the
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__
documentation.
.. versionadded:: 1.4.0
Raises
------
ValueError
Expand Down Expand Up @@ -2859,6 +2873,7 @@ def to_sql(
2 User 3
>>> df.to_sql('users', con=engine)
3
>>> engine.execute("SELECT * FROM users").fetchall()
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3')]
Expand All @@ -2867,12 +2882,14 @@ def to_sql(
>>> with engine.begin() as connection:
... df1 = pd.DataFrame({'name' : ['User 4', 'User 5']})
... df1.to_sql('users', con=connection, if_exists='append')
2
This is allowed to support operations that require that the same
DBAPI connection is used for the entire operation.
>>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']})
>>> df2.to_sql('users', con=engine, if_exists='append')
2
>>> engine.execute("SELECT * FROM users").fetchall()
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3'),
(0, 'User 4'), (1, 'User 5'), (0, 'User 6'),
Expand All @@ -2882,6 +2899,7 @@ def to_sql(
>>> df2.to_sql('users', con=engine, if_exists='replace',
... index_label='id')
2
>>> engine.execute("SELECT * FROM users").fetchall()
[(0, 'User 6'), (1, 'User 7')]
Expand All @@ -2900,13 +2918,14 @@ def to_sql(
>>> from sqlalchemy.types import Integer
>>> df.to_sql('integers', con=engine, index=False,
... dtype={"A": Integer()})
3
>>> engine.execute("SELECT * FROM integers").fetchall()
[(1,), (None,), (2,)]
"""
""" # noqa:E501
from pandas.io import sql

sql.to_sql(
return sql.to_sql(
self,
name,
con,
Expand Down
75 changes: 52 additions & 23 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,7 @@ def read_sql(
>>> df = pd.DataFrame(data=[[0, '10/11/12'], [1, '12/11/10']],
... columns=['int_column', 'date_column'])
>>> df.to_sql('test_data', conn)
2
>>> pd.read_sql('SELECT int_column, date_column FROM test_data', conn)
int_column date_column
Expand Down Expand Up @@ -611,7 +612,7 @@ def to_sql(
method: str | None = None,
engine: str = "auto",
**engine_kwargs,
) -> None:
) -> int | None:
"""
Write records stored in a DataFrame to a SQL database.
Expand Down Expand Up @@ -650,8 +651,8 @@ def to_sql(
Controls the SQL insertion clause used:
- None : Uses standard SQL ``INSERT`` clause (one per row).
- 'multi': Pass multiple values in a single ``INSERT`` clause.
- callable with signature ``(pd_table, conn, keys, data_iter)``.
- ``'multi'``: Pass multiple values in a single ``INSERT`` clause.
- callable with signature ``(pd_table, conn, keys, data_iter) -> int | None``.
Details and a sample callable implementation can be found in the
section :ref:`insert method <io.sql.method>`.
Expand All @@ -664,7 +665,23 @@ def to_sql(
**engine_kwargs
Any additional kwargs are passed to the engine.
"""
Returns
-------
None or int
Number of rows affected by to_sql. None is returned if the callable
passed into ``method`` does not return the number of rows.
.. versionadded:: 1.4.0
Notes
-----
The returned number of rows affected is the sum of the ``rowcount`` attribute of
``sqlite3.Cursor`` or SQLAlchemy connectable. The returned value may not reflect the
exact number of written rows, as stipulated in the
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__
documentation.
""" # noqa:E501
if if_exists not in ("fail", "replace", "append"):
raise ValueError(f"'{if_exists}' is not valid for if_exists")

Expand All @@ -677,7 +694,7 @@ def to_sql(
"'frame' argument should be either a Series or a DataFrame"
)

pandas_sql.to_sql(
return pandas_sql.to_sql(
frame,
name,
if_exists=if_exists,
Expand Down Expand Up @@ -817,7 +834,7 @@ def create(self):
else:
self._execute_create()

def _execute_insert(self, conn, keys: list[str], data_iter):
def _execute_insert(self, conn, keys: list[str], data_iter) -> int:
"""
Execute SQL statement inserting data
Expand All @@ -830,9 +847,10 @@ def _execute_insert(self, conn, keys: list[str], data_iter):
Each item contains a list of values to be inserted
"""
data = [dict(zip(keys, row)) for row in data_iter]
conn.execute(self.table.insert(), data)
result = conn.execute(self.table.insert(), data)
return result.rowcount

def _execute_insert_multi(self, conn, keys: list[str], data_iter):
def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int:
"""
Alternative to _execute_insert for DBs that support multivalue INSERT.
Expand All @@ -845,7 +863,8 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter):

data = [dict(zip(keys, row)) for row in data_iter]
stmt = insert(self.table).values(data)
conn.execute(stmt)
result = conn.execute(stmt)
return result.rowcount

def insert_data(self):
if self.index is not None:
Expand Down Expand Up @@ -885,7 +904,9 @@ def insert_data(self):

return column_names, data_list

def insert(self, chunksize: int | None = None, method: str | None = None):
def insert(
self, chunksize: int | None = None, method: str | None = None
) -> int | None:

# set insert method
if method is None:
Expand All @@ -902,15 +923,15 @@ def insert(self, chunksize: int | None = None, method: str | None = None):
nrows = len(self.frame)

if nrows == 0:
return
return 0

if chunksize is None:
chunksize = nrows
elif chunksize == 0:
raise ValueError("chunksize argument should be non-zero")

chunks = (nrows // chunksize) + 1

total_inserted = 0
with self.pd_sql.run_transaction() as conn:
for i in range(chunks):
start_i = i * chunksize
Expand All @@ -919,7 +940,12 @@ def insert(self, chunksize: int | None = None, method: str | None = None):
break

chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list))
exec_insert(conn, keys, chunk_iter)
num_inserted = exec_insert(conn, keys, chunk_iter)
if num_inserted is None:
total_inserted = None
else:
total_inserted += num_inserted
return total_inserted

def _query_iterator(
self,
Expand Down Expand Up @@ -1239,7 +1265,7 @@ def to_sql(
chunksize=None,
dtype: DtypeArg | None = None,
method=None,
):
) -> int | None:
raise ValueError(
"PandasSQL must be created with an SQLAlchemy "
"connectable or sqlite connection"
Expand All @@ -1258,7 +1284,7 @@ def insert_records(
chunksize=None,
method=None,
**engine_kwargs,
):
) -> int | None:
"""
Inserts data into already-prepared table
"""
Expand All @@ -1282,11 +1308,11 @@ def insert_records(
chunksize=None,
method=None,
**engine_kwargs,
):
) -> int | None:
from sqlalchemy import exc

try:
table.insert(chunksize=chunksize, method=method)
return table.insert(chunksize=chunksize, method=method)
except exc.SQLAlchemyError as err:
# GH34431
# https://stackoverflow.com/a/67358288/6067848
Expand Down Expand Up @@ -1643,7 +1669,7 @@ def to_sql(
method=None,
engine="auto",
**engine_kwargs,
):
) -> int | None:
"""
Write records stored in a DataFrame to a SQL database.
Expand Down Expand Up @@ -1704,7 +1730,7 @@ def to_sql(
dtype=dtype,
)

sql_engine.insert_records(
total_inserted = sql_engine.insert_records(
table=table,
con=self.connectable,
frame=frame,
Expand All @@ -1717,6 +1743,7 @@ def to_sql(
)

self.check_case_sensitive(name=name, schema=schema)
return total_inserted

@property
def tables(self):
Expand Down Expand Up @@ -1859,14 +1886,16 @@ def insert_statement(self, *, num_rows: int):
)
return insert_statement

def _execute_insert(self, conn, keys, data_iter):
def _execute_insert(self, conn, keys, data_iter) -> int:
data_list = list(data_iter)
conn.executemany(self.insert_statement(num_rows=1), data_list)
return conn.rowcount

def _execute_insert_multi(self, conn, keys, data_iter):
def _execute_insert_multi(self, conn, keys, data_iter) -> int:
data_list = list(data_iter)
flattened_data = [x for row in data_list for x in row]
conn.execute(self.insert_statement(num_rows=len(data_list)), flattened_data)
return conn.rowcount

def _create_table_setup(self):
"""
Expand Down Expand Up @@ -2088,7 +2117,7 @@ def to_sql(
dtype: DtypeArg | None = None,
method=None,
**kwargs,
):
) -> int | None:
"""
Write records stored in a DataFrame to a SQL database.
Expand Down Expand Up @@ -2153,7 +2182,7 @@ def to_sql(
dtype=dtype,
)
table.create()
table.insert(chunksize, method)
return table.insert(chunksize, method)

def has_table(self, name: str, schema: str | None = None):

Expand Down
Loading