Skip to content

Commit 2856607

Browse files
authored
BUG: to_sql with method=callable not returning int raising TypeError (#47474)
1 parent 1924be3 commit 2856607

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

doc/source/whatsnew/v1.4.4.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Fixed regressions
2424
Bug fixes
2525
~~~~~~~~~
2626
- The :class:`errors.FutureWarning` raised when passing arguments (other than ``filepath_or_buffer``) as positional in :func:`read_csv` is now raised at the correct stacklevel (:issue:`47385`)
27-
-
27+
- Bug in :meth:`DataFrame.to_sql` when ``method`` was a ``callable`` that did not return an ``int`` and would raise a ``TypeError`` (:issue:`46891`)
2828

2929
.. ---------------------------------------------------------------------------
3030

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2754,7 +2754,7 @@ def to_sql(
27542754
-------
27552755
None or int
27562756
Number of rows affected by to_sql. None is returned if the callable
2757-
passed into ``method`` does not return the number of rows.
2757+
passed into ``method`` does not return an integer number of rows.
27582758
27592759
The number of returned rows affected is the sum of the ``rowcount``
27602760
attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not

pandas/io/sql.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from pandas.core.dtypes.common import (
3838
is_datetime64tz_dtype,
3939
is_dict_like,
40+
is_integer,
4041
is_list_like,
4142
)
4243
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -668,7 +669,7 @@ def to_sql(
668669
-------
669670
None or int
670671
Number of rows affected by to_sql. None is returned if the callable
671-
passed into ``method`` does not return the number of rows.
672+
passed into ``method`` does not return an integer number of rows.
672673
673674
.. versionadded:: 1.4.0
674675
@@ -933,7 +934,7 @@ def insert(
933934
raise ValueError("chunksize argument should be non-zero")
934935

935936
chunks = (nrows // chunksize) + 1
936-
total_inserted = 0
937+
total_inserted = None
937938
with self.pd_sql.run_transaction() as conn:
938939
for i in range(chunks):
939940
start_i = i * chunksize
@@ -943,10 +944,12 @@ def insert(
943944

944945
chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list))
945946
num_inserted = exec_insert(conn, keys, chunk_iter)
946-
if num_inserted is None:
947-
total_inserted = None
948-
else:
949-
total_inserted += num_inserted
947+
# GH 46891
948+
if is_integer(num_inserted):
949+
if total_inserted is None:
950+
total_inserted = num_inserted
951+
else:
952+
total_inserted += num_inserted
950953
return total_inserted
951954

952955
def _query_iterator(

pandas/tests/io/test_sql.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,8 @@ def test_read_procedure(conn, request):
620620

621621
@pytest.mark.db
622622
@pytest.mark.parametrize("conn", postgresql_connectable)
623-
def test_copy_from_callable_insertion_method(conn, request):
623+
@pytest.mark.parametrize("expected_count", [2, "Success!"])
624+
def test_copy_from_callable_insertion_method(conn, expected_count, request):
624625
# GH 8953
625626
# Example in io.rst found under _io.sql.method
626627
# not available in sqlite, mysql
@@ -641,10 +642,18 @@ def psql_insert_copy(table, conn, keys, data_iter):
641642

642643
sql_query = f"COPY {table_name} ({columns}) FROM STDIN WITH CSV"
643644
cur.copy_expert(sql=sql_query, file=s_buf)
645+
return expected_count
644646

645647
conn = request.getfixturevalue(conn)
646648
expected = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]})
647-
expected.to_sql("test_frame", conn, index=False, method=psql_insert_copy)
649+
result_count = expected.to_sql(
650+
"test_frame", conn, index=False, method=psql_insert_copy
651+
)
652+
# GH 46891
653+
if not isinstance(expected_count, int):
654+
assert result_count is None
655+
else:
656+
assert result_count == expected_count
648657
result = sql.read_sql_table("test_frame", conn)
649658
tm.assert_frame_equal(result, expected)
650659

0 commit comments

Comments
 (0)