Skip to content

BUG: to_sql with method=callable not returning int raising TypeError #47474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 3, 2022
Merged
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.4.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Fixed regressions
Bug fixes
~~~~~~~~~
- The :class:`errors.FutureWarning` raised when passing arguments (other than ``filepath_or_buffer``) as positional in :func:`read_csv` is now raised at the correct stacklevel (:issue:`47385`)
-
- Bug in :meth:`DataFrame.to_sql` when ``method`` was a ``callable`` that did not return an ``int`` and would raise a ``TypeError`` (:issue:`46891`)

.. ---------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2752,7 +2752,7 @@ def to_sql(
-------
None or int
Number of rows affected by to_sql. None is returned if the callable
passed into ``method`` does not return the number of rows.
passed into ``method`` does not return an integer number of rows.

The number of returned rows affected is the sum of the ``rowcount``
attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not
Expand Down
15 changes: 9 additions & 6 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from pandas.core.dtypes.common import (
is_datetime64tz_dtype,
is_dict_like,
is_integer,
is_list_like,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
Expand Down Expand Up @@ -668,7 +669,7 @@ def to_sql(
-------
None or int
Number of rows affected by to_sql. None is returned if the callable
passed into ``method`` does not return the number of rows.
passed into ``method`` does not return an integer number of rows.

.. versionadded:: 1.4.0

Expand Down Expand Up @@ -933,7 +934,7 @@ def insert(
raise ValueError("chunksize argument should be non-zero")

chunks = (nrows // chunksize) + 1
total_inserted = 0
total_inserted = None
with self.pd_sql.run_transaction() as conn:
for i in range(chunks):
start_i = i * chunksize
Expand All @@ -943,10 +944,12 @@ def insert(

chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list))
num_inserted = exec_insert(conn, keys, chunk_iter)
if num_inserted is None:
total_inserted = None
else:
total_inserted += num_inserted
# GH 46891
if is_integer(num_inserted):
if total_inserted is None:
total_inserted = num_inserted
else:
total_inserted += num_inserted
return total_inserted

def _query_iterator(
Expand Down
13 changes: 11 additions & 2 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,8 @@ def test_read_procedure(conn, request):

@pytest.mark.db
@pytest.mark.parametrize("conn", postgresql_connectable)
def test_copy_from_callable_insertion_method(conn, request):
@pytest.mark.parametrize("expected_count", [2, "Success!"])
def test_copy_from_callable_insertion_method(conn, expected_count, request):
# GH 8953
# Example in io.rst found under _io.sql.method
# not available in sqlite, mysql
Expand All @@ -641,10 +642,18 @@ def psql_insert_copy(table, conn, keys, data_iter):

sql_query = f"COPY {table_name} ({columns}) FROM STDIN WITH CSV"
cur.copy_expert(sql=sql_query, file=s_buf)
return expected_count

conn = request.getfixturevalue(conn)
expected = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]})
expected.to_sql("test_frame", conn, index=False, method=psql_insert_copy)
result_count = expected.to_sql(
"test_frame", conn, index=False, method=psql_insert_copy
)
# GH 46891
if not isinstance(expected_count, int):
assert result_count is None
else:
assert result_count == expected_count
result = sql.read_sql_table("test_frame", conn)
tm.assert_frame_equal(result, expected)

Expand Down