From 654ca5fc7300519cc1eb1e6a0431840633f1b577 Mon Sep 17 00:00:00 2001 From: Chuck Cadman Date: Thu, 15 Dec 2022 20:32:18 -0800 Subject: [PATCH] BUG: Allow read_sql to work with chunksize. --- pandas/io/sql.py | 138 +++++++++++++++++++++--------------- pandas/tests/io/test_sql.py | 91 ++++++++++++++++++++---- 2 files changed, 160 insertions(+), 69 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 2c98ff61cbef6..62a54a548a990 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -9,7 +9,10 @@ ABC, abstractmethod, ) -from contextlib import contextmanager +from contextlib import ( + ExitStack, + contextmanager, +) from datetime import ( date, datetime, @@ -69,6 +72,14 @@ # -- Helper functions +def _cleanup_after_generator(generator, exit_stack: ExitStack): + """Does the cleanup after iterating through the generator.""" + try: + yield from generator + finally: + exit_stack.close() + + def _convert_params(sql, params): """Convert SQL and params args to DBAPI2.0 compliant format.""" args = [sql] @@ -772,12 +783,11 @@ def has_table(table_name: str, con, schema: str | None = None) -> bool: table_exists = has_table -@contextmanager def pandasSQL_builder( con, schema: str | None = None, need_transaction: bool = False, -) -> Iterator[PandasSQL]: +) -> PandasSQL: """ Convenience function to return the correct PandasSQL subclass based on the provided parameters. Also creates a sqlalchemy connection and transaction @@ -786,45 +796,24 @@ def pandasSQL_builder( import sqlite3 if isinstance(con, sqlite3.Connection) or con is None: - yield SQLiteDatabase(con) - else: - sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore") + return SQLiteDatabase(con) - if sqlalchemy is not None and isinstance( - con, (str, sqlalchemy.engine.Connectable) - ): - with _sqlalchemy_con(con, need_transaction) as con: - yield SQLDatabase(con, schema=schema) - elif isinstance(con, str) and sqlalchemy is None: - raise ImportError("Using URI string without sqlalchemy installed.") - else: - - warnings.warn( - "pandas only supports SQLAlchemy connectable (engine/connection) or " - "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 " - "objects are not tested. Please consider using SQLAlchemy.", - UserWarning, - stacklevel=find_stack_level() + 2, - ) - yield SQLiteDatabase(con) + sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore") + if isinstance(con, str) and sqlalchemy is None: + raise ImportError("Using URI string without sqlalchemy installed.") -@contextmanager -def _sqlalchemy_con(connectable, need_transaction: bool): - """Create a sqlalchemy connection and a transaction if necessary.""" - sqlalchemy = import_optional_dependency("sqlalchemy", errors="raise") + if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)): + return SQLDatabase(con, schema, need_transaction) - if isinstance(connectable, str): - connectable = sqlalchemy.create_engine(connectable) - if isinstance(connectable, sqlalchemy.engine.Engine): - with connectable.connect() as con: - if need_transaction: - with con.begin(): - yield con - else: - yield con - else: - yield connectable + warnings.warn( + "pandas only supports SQLAlchemy connectable (engine/connection) or " + "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 " + "objects are not tested. 
Please consider using SQLAlchemy.", + UserWarning, + stacklevel=find_stack_level(), + ) + return SQLiteDatabase(con) class SQLTable(PandasObject): @@ -1049,6 +1038,7 @@ def _query_iterator( def read( self, + exit_stack: ExitStack, coerce_float: bool = True, parse_dates=None, columns=None, @@ -1069,13 +1059,16 @@ def read( column_names = result.keys() if chunksize is not None: - return self._query_iterator( - result, - chunksize, - column_names, - coerce_float=coerce_float, - parse_dates=parse_dates, - use_nullable_dtypes=use_nullable_dtypes, + return _cleanup_after_generator( + self._query_iterator( + result, + chunksize, + column_names, + coerce_float=coerce_float, + parse_dates=parse_dates, + use_nullable_dtypes=use_nullable_dtypes, + ), + exit_stack, ) else: data = result.fetchall() @@ -1327,6 +1320,12 @@ class PandasSQL(PandasObject, ABC): Subclasses Should define read_query and to_sql. """ + def __enter__(self): + return self + + def __exit__(self, *args) -> None: + pass + def read_table( self, table_name: str, @@ -1482,20 +1481,38 @@ class SQLDatabase(PandasSQL): Parameters ---------- - con : SQLAlchemy Connection - Connection to connect with the database. Using SQLAlchemy makes it + con : SQLAlchemy Connectable or URI string. + Connectable to connect with the database. Using SQLAlchemy makes it possible to use any DB supported by that library. schema : string, default None Name of SQL schema in database to write to (if database flavor supports this). If None, use default schema (default). + need_transaction : bool, default False + If True, SQLDatabase will create a transaction. """ - def __init__(self, con, schema: str | None = None) -> None: + def __init__( + self, con, schema: str | None = None, need_transaction: bool = False + ) -> None: + from sqlalchemy import create_engine + from sqlalchemy.engine import Engine from sqlalchemy.schema import MetaData + self.exit_stack = ExitStack() + if isinstance(con, str): + con = create_engine(con) + if isinstance(con, Engine): + con = self.exit_stack.enter_context(con.connect()) + if need_transaction: + self.exit_stack.enter_context(con.begin()) self.con = con self.meta = MetaData(schema=schema) + self.returns_generator = False + + def __exit__(self, *args) -> None: + if not self.returns_generator: + self.exit_stack.close() @contextmanager def run_transaction(self): @@ -1566,7 +1583,10 @@ def read_table( """ self.meta.reflect(bind=self.con, only=[table_name]) table = SQLTable(table_name, self, index=index_col, schema=schema) + if chunksize is not None: + self.returns_generator = True return table.read( + self.exit_stack, coerce_float=coerce_float, parse_dates=parse_dates, columns=columns, @@ -1675,15 +1695,19 @@ def read_query( columns = result.keys() if chunksize is not None: - return self._query_iterator( - result, - chunksize, - columns, - index_col=index_col, - coerce_float=coerce_float, - parse_dates=parse_dates, - dtype=dtype, - use_nullable_dtypes=use_nullable_dtypes, + self.returns_generator = True + return _cleanup_after_generator( + self._query_iterator( + result, + chunksize, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + use_nullable_dtypes=use_nullable_dtypes, + ), + self.exit_stack, ) else: data = result.fetchall() diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 490b425ee52bf..b7cff1627a81f 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -260,24 +260,34 @@ def check_iris_frame(frame: DataFrame): row = 
frame.iloc[0] assert issubclass(pytype, np.floating) tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"]) + assert frame.shape in ((150, 5), (8, 5)) def count_rows(conn, table_name: str): stmt = f"SELECT count(*) AS count_1 FROM {table_name}" if isinstance(conn, sqlite3.Connection): cur = conn.cursor() - result = cur.execute(stmt) + return cur.execute(stmt).fetchone()[0] else: - from sqlalchemy import text + from sqlalchemy import ( + create_engine, + text, + ) from sqlalchemy.engine import Engine stmt = text(stmt) - if isinstance(conn, Engine): + if isinstance(conn, str): + try: + engine = create_engine(conn) + with engine.connect() as conn: + return conn.execute(stmt).scalar_one() + finally: + engine.dispose() + elif isinstance(conn, Engine): with conn.connect() as conn: - result = conn.execute(stmt) + return conn.execute(stmt).scalar_one() else: - result = conn.execute(stmt) - return result.fetchone()[0] + return conn.execute(stmt).scalar_one() @pytest.fixture @@ -388,6 +398,7 @@ def mysql_pymysql_engine(iris_path, types_data): engine = sqlalchemy.create_engine( "mysql+pymysql://root@localhost:3306/pandas", connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS}, + poolclass=sqlalchemy.pool.NullPool, ) insp = sqlalchemy.inspect(engine) if not insp.has_table("iris"): @@ -414,7 +425,8 @@ def postgresql_psycopg2_engine(iris_path, types_data): sqlalchemy = pytest.importorskip("sqlalchemy") pytest.importorskip("psycopg2") engine = sqlalchemy.create_engine( - "postgresql+psycopg2://postgres:postgres@localhost:5432/pandas" + "postgresql+psycopg2://postgres:postgres@localhost:5432/pandas", + poolclass=sqlalchemy.pool.NullPool, ) insp = sqlalchemy.inspect(engine) if not insp.has_table("iris"): @@ -435,9 +447,16 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine): @pytest.fixture -def sqlite_engine(): +def sqlite_str(): + pytest.importorskip("sqlalchemy") + with tm.ensure_clean() as name: + yield "sqlite:///" + name + + +@pytest.fixture +def sqlite_engine(sqlite_str): sqlalchemy = pytest.importorskip("sqlalchemy") - engine = sqlalchemy.create_engine("sqlite://") + engine = sqlalchemy.create_engine(sqlite_str, poolclass=sqlalchemy.pool.NullPool) yield engine engine.dispose() @@ -447,6 +466,15 @@ def sqlite_conn(sqlite_engine): yield sqlite_engine.connect() +@pytest.fixture +def sqlite_iris_str(sqlite_str, iris_path): + sqlalchemy = pytest.importorskip("sqlalchemy") + engine = sqlalchemy.create_engine(sqlite_str) + create_and_load_iris(engine, iris_path, "sqlite") + engine.dispose() + return sqlite_str + + @pytest.fixture def sqlite_iris_engine(sqlite_engine, iris_path): create_and_load_iris(sqlite_engine, iris_path, "sqlite") @@ -485,11 +513,13 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path): sqlite_connectable = [ "sqlite_engine", "sqlite_conn", + "sqlite_str", ] sqlite_iris_connectable = [ "sqlite_iris_engine", "sqlite_iris_conn", + "sqlite_iris_str", ] sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable @@ -541,10 +571,47 @@ def test_to_sql_exist_fail(conn, test_frame1, request): @pytest.mark.db @pytest.mark.parametrize("conn", all_connectable_iris) -def test_read_iris(conn, request): +def test_read_iris_query(conn, request): conn = request.getfixturevalue(conn) - with pandasSQL_builder(conn) as pandasSQL: - iris_frame = pandasSQL.read_query("SELECT * FROM iris") + iris_frame = read_sql_query("SELECT * FROM iris", conn) + check_iris_frame(iris_frame) + iris_frame = pd.read_sql("SELECT * FROM iris", conn) + 
check_iris_frame(iris_frame) + iris_frame = pd.read_sql("SELECT * FROM iris where 0=1", conn) + assert iris_frame.shape == (0, 5) + assert "SepalWidth" in iris_frame.columns + + +@pytest.mark.db +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_iris_query_chunksize(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = concat(read_sql_query("SELECT * FROM iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("SELECT * FROM iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("SELECT * FROM iris where 0=1", conn, chunksize=7)) + assert iris_frame.shape == (0, 5) + assert "SepalWidth" in iris_frame.columns + + +@pytest.mark.db +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_table(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = read_sql_table("iris", conn) + check_iris_frame(iris_frame) + iris_frame = pd.read_sql("iris", conn) + check_iris_frame(iris_frame) + + +@pytest.mark.db +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_table_chunksize(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = concat(read_sql_table("iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("iris", conn, chunksize=7)) check_iris_frame(iris_frame)
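
Illustrative note, not part of the patch: the tests above exercise chunked reads through read_sql, read_sql_query, and read_sql_table. Below is a minimal sketch of the behavior this change enables, assuming SQLAlchemy is installed for the URI-string path; the file name, table name, and data are made up for the example.

import sqlite3

import pandas as pd

# Seed a throwaway sqlite database through the DBAPI2 path (hypothetical data).
con = sqlite3.connect("chunks_demo.db")
pd.DataFrame({"a": range(20), "b": list("ab") * 10}).to_sql(
    "demo", con, index=False, if_exists="replace"
)
con.commit()
con.close()

# With the ExitStack-based cleanup, the SQLAlchemy connection that
# pandasSQL_builder opens for a URI string stays alive until the generator
# returned for chunksize is exhausted, and only then is closed.
uri = "sqlite:///chunks_demo.db"
frame = pd.concat(pd.read_sql("SELECT * FROM demo", uri, chunksize=7))
assert frame.shape == (20, 2)

# read_sql_table goes through the same machinery via SQLTable.read(exit_stack, ...).
frame = pd.concat(pd.read_sql_table("demo", uri, chunksize=7))
assert frame.shape == (20, 2)

The design point the tests check: because a chunked result is a generator, SQLDatabase sets returns_generator and defers exit_stack.close() to _cleanup_after_generator, so the connection is not torn down by __exit__ before the caller has consumed the chunks.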