diff --git a/doc/source/whatsnew/v1.4.4.rst b/doc/source/whatsnew/v1.4.4.rst index 0af25daf0468a..6ee140f59e096 100644 --- a/doc/source/whatsnew/v1.4.4.rst +++ b/doc/source/whatsnew/v1.4.4.rst @@ -24,7 +24,7 @@ Fixed regressions Bug fixes ~~~~~~~~~ - The :class:`errors.FutureWarning` raised when passing arguments (other than ``filepath_or_buffer``) as positional in :func:`read_csv` is now raised at the correct stacklevel (:issue:`47385`) -- +- Bug in :meth:`DataFrame.to_sql` when ``method`` was a ``callable`` that did not return an ``int`` and would raise a ``TypeError`` (:issue:`46891`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f896169d0ae44..647cc6b533275 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2752,7 +2752,7 @@ def to_sql( ------- None or int Number of rows affected by to_sql. None is returned if the callable - passed into ``method`` does not return the number of rows. + passed into ``method`` does not return an integer number of rows. The number of returned rows affected is the sum of the ``rowcount`` attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 24290b8370ed2..987a19ee0cf47 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -37,6 +37,7 @@ from pandas.core.dtypes.common import ( is_datetime64tz_dtype, is_dict_like, + is_integer, is_list_like, ) from pandas.core.dtypes.dtypes import DatetimeTZDtype @@ -668,7 +669,7 @@ def to_sql( ------- None or int Number of rows affected by to_sql. None is returned if the callable - passed into ``method`` does not return the number of rows. + passed into ``method`` does not return an integer number of rows. .. versionadded:: 1.4.0 @@ -933,7 +934,7 @@ def insert( raise ValueError("chunksize argument should be non-zero") chunks = (nrows // chunksize) + 1 - total_inserted = 0 + total_inserted = None with self.pd_sql.run_transaction() as conn: for i in range(chunks): start_i = i * chunksize @@ -943,10 +944,12 @@ def insert( chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list)) num_inserted = exec_insert(conn, keys, chunk_iter) - if num_inserted is None: - total_inserted = None - else: - total_inserted += num_inserted + # GH 46891 + if is_integer(num_inserted): + if total_inserted is None: + total_inserted = num_inserted + else: + total_inserted += num_inserted return total_inserted def _query_iterator( diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index e28901fa1a1ed..dd3464ccf9f64 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -620,7 +620,8 @@ def test_read_procedure(conn, request): @pytest.mark.db @pytest.mark.parametrize("conn", postgresql_connectable) -def test_copy_from_callable_insertion_method(conn, request): +@pytest.mark.parametrize("expected_count", [2, "Success!"]) +def test_copy_from_callable_insertion_method(conn, expected_count, request): # GH 8953 # Example in io.rst found under _io.sql.method # not available in sqlite, mysql @@ -641,10 +642,18 @@ def psql_insert_copy(table, conn, keys, data_iter): sql_query = f"COPY {table_name} ({columns}) FROM STDIN WITH CSV" cur.copy_expert(sql=sql_query, file=s_buf) + return expected_count conn = request.getfixturevalue(conn) expected = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]}) - expected.to_sql("test_frame", conn, index=False, method=psql_insert_copy) + result_count = expected.to_sql( + "test_frame", conn, index=False, method=psql_insert_copy + ) + # GH 46891 + if not isinstance(expected_count, int): + assert result_count is None + else: + assert result_count == expected_count result = sql.read_sql_table("test_frame", conn) tm.assert_frame_equal(result, expected)