Skip to content

BUG: Allow read_sql to work with chunksize. #49967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 81 additions & 57 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
ABC,
abstractmethod,
)
from contextlib import contextmanager
from contextlib import (
ExitStack,
contextmanager,
)
from datetime import (
date,
datetime,
Expand Down Expand Up @@ -69,6 +72,14 @@
# -- Helper functions


def _cleanup_after_generator(generator, exit_stack: ExitStack):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might have slightly preferred if _query_iterator used self.exit_stack directly in the functions, but could be a follow up e.g.

def _query_iterator(...):
    with self.exit_stack():
        ...

"""Does the cleanup after iterating through the generator."""
try:
yield from generator
finally:
exit_stack.close()


def _convert_params(sql, params):
"""Convert SQL and params args to DBAPI2.0 compliant format."""
args = [sql]
Expand Down Expand Up @@ -772,12 +783,11 @@ def has_table(table_name: str, con, schema: str | None = None) -> bool:
table_exists = has_table


def pandasSQL_builder(
    con,
    schema: str | None = None,
    need_transaction: bool = False,
) -> PandasSQL:
    """
    Convenience function to return the correct PandasSQL subclass based on the
    provided parameters.  Also creates a sqlalchemy connection and transaction
    if necessary.

    Parameters
    ----------
    con : sqlite3.Connection, SQLAlchemy Connectable, URI string, or None
        The connection (or connection spec) to wrap.
    schema : str, optional
        SQL schema to use (SQLAlchemy backends only).
    need_transaction : bool, default False
        If True, the SQLDatabase opens a transaction on the connection.

    Returns
    -------
    PandasSQL
        SQLiteDatabase for sqlite3/None/untested DBAPI2 objects,
        SQLDatabase for SQLAlchemy connectables or URI strings.

    Raises
    ------
    ImportError
        If *con* is a URI string but sqlalchemy is not installed.
    """
    import sqlite3

    # Builtin sqlite3 connections (and None) use the fallback implementation.
    if isinstance(con, sqlite3.Connection) or con is None:
        return SQLiteDatabase(con)

    sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore")

    if isinstance(con, str) and sqlalchemy is None:
        raise ImportError("Using URI string without sqlalchemy installed.")

    if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):
        return SQLDatabase(con, schema, need_transaction)

    # Unknown DBAPI2 object: warn, then try the sqlite-style fallback.
    warnings.warn(
        "pandas only supports SQLAlchemy connectable (engine/connection) or "
        "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
        "objects are not tested. Please consider using SQLAlchemy.",
        UserWarning,
        stacklevel=find_stack_level(),
    )
    return SQLiteDatabase(con)


class SQLTable(PandasObject):
Expand Down Expand Up @@ -1049,6 +1038,7 @@ def _query_iterator(

def read(
self,
exit_stack: ExitStack,
coerce_float: bool = True,
parse_dates=None,
columns=None,
Expand All @@ -1069,13 +1059,16 @@ def read(
column_names = result.keys()

if chunksize is not None:
return self._query_iterator(
result,
chunksize,
column_names,
coerce_float=coerce_float,
parse_dates=parse_dates,
use_nullable_dtypes=use_nullable_dtypes,
return _cleanup_after_generator(
self._query_iterator(
result,
chunksize,
column_names,
coerce_float=coerce_float,
parse_dates=parse_dates,
use_nullable_dtypes=use_nullable_dtypes,
),
exit_stack,
)
else:
data = result.fetchall()
Expand Down Expand Up @@ -1327,6 +1320,12 @@ class PandasSQL(PandasObject, ABC):
Subclasses Should define read_query and to_sql.
"""

def __enter__(self):
    """Enter the runtime context; returns self so the instance is usable in ``with``."""
    return self

def __exit__(self, *args) -> None:
    """Exit hook; the base class holds no resources, so nothing to release (subclasses may override)."""
    pass

def read_table(
self,
table_name: str,
Expand Down Expand Up @@ -1482,20 +1481,38 @@ class SQLDatabase(PandasSQL):

Parameters
----------
con : SQLAlchemy Connection
Connection to connect with the database. Using SQLAlchemy makes it
con : SQLAlchemy Connectable or URI string.
Connectable to connect with the database. Using SQLAlchemy makes it
possible to use any DB supported by that library.
schema : string, default None
Name of SQL schema in database to write to (if database flavor
supports this). If None, use default schema (default).
need_transaction : bool, default False
If True, SQLDatabase will create a transaction.

"""

def __init__(
    self, con, schema: str | None = None, need_transaction: bool = False
) -> None:
    """
    Set up the connection, opening an engine connection and/or a
    transaction as needed, and registering everything that must be
    closed on an ExitStack.

    Parameters
    ----------
    con : SQLAlchemy Connectable or URI string
        Connectable to connect with the database.
    schema : str, optional
        Name of SQL schema in database to write to.
    need_transaction : bool, default False
        If True, begin a transaction on the connection.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.engine import Engine
    from sqlalchemy.schema import MetaData

    # Resources opened here are released by __exit__ / the chunk iterator.
    self.exit_stack = ExitStack()
    if isinstance(con, str):
        con = create_engine(con)
    if isinstance(con, Engine):
        con = self.exit_stack.enter_context(con.connect())
    if need_transaction:
        # NOTE(review): begins unconditionally — if the caller passed a
        # connection already inside a transaction, begin() may be called
        # a second time; confirm against SQLAlchemy version in use.
        self.exit_stack.enter_context(con.begin())
    self.con = con
    self.meta = MetaData(schema=schema)
    # When a chunked read hands out a generator, cleanup is deferred to
    # _cleanup_after_generator instead of __exit__.
    self.returns_generator = False

def __exit__(self, *args) -> None:
    """Close held resources unless a chunk generator now owns the cleanup."""
    if not self.returns_generator:
        self.exit_stack.close()

@contextmanager
def run_transaction(self):
Expand Down Expand Up @@ -1566,7 +1583,10 @@ def read_table(
"""
self.meta.reflect(bind=self.con, only=[table_name])
table = SQLTable(table_name, self, index=index_col, schema=schema)
if chunksize is not None:
self.returns_generator = True
return table.read(
self.exit_stack,
coerce_float=coerce_float,
parse_dates=parse_dates,
columns=columns,
Expand Down Expand Up @@ -1675,15 +1695,19 @@ def read_query(
columns = result.keys()

if chunksize is not None:
return self._query_iterator(
result,
chunksize,
columns,
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
use_nullable_dtypes=use_nullable_dtypes,
self.returns_generator = True
return _cleanup_after_generator(
self._query_iterator(
result,
chunksize,
columns,
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
use_nullable_dtypes=use_nullable_dtypes,
),
self.exit_stack,
)
else:
data = result.fetchall()
Expand Down
91 changes: 79 additions & 12 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,24 +260,34 @@ def check_iris_frame(frame: DataFrame):
row = frame.iloc[0]
assert issubclass(pytype, np.floating)
tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])
assert frame.shape in ((150, 5), (8, 5))


def count_rows(conn, table_name: str):
    """
    Return the number of rows in *table_name*.

    Accepts a sqlite3 connection, a SQLAlchemy URI string, a SQLAlchemy
    Engine, or an open SQLAlchemy connection.
    """
    stmt = f"SELECT count(*) AS count_1 FROM {table_name}"
    if isinstance(conn, sqlite3.Connection):
        cur = conn.cursor()
        return cur.execute(stmt).fetchone()[0]
    else:
        from sqlalchemy import (
            create_engine,
            text,
        )
        from sqlalchemy.engine import Engine

        stmt = text(stmt)
        if isinstance(conn, str):
            # Create the engine *before* the try block: if create_engine
            # raised inside it, the finally clause would hit a NameError
            # on the undefined `engine`.
            engine = create_engine(conn)
            try:
                with engine.connect() as conn:
                    return conn.execute(stmt).scalar_one()
            finally:
                engine.dispose()
        elif isinstance(conn, Engine):
            with conn.connect() as conn:
                return conn.execute(stmt).scalar_one()
        else:
            # Already an open connection.
            return conn.execute(stmt).scalar_one()


@pytest.fixture
Expand Down Expand Up @@ -388,6 +398,7 @@ def mysql_pymysql_engine(iris_path, types_data):
engine = sqlalchemy.create_engine(
"mysql+pymysql://root@localhost:3306/pandas",
connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS},
poolclass=sqlalchemy.pool.NullPool,
)
insp = sqlalchemy.inspect(engine)
if not insp.has_table("iris"):
Expand All @@ -414,7 +425,8 @@ def postgresql_psycopg2_engine(iris_path, types_data):
sqlalchemy = pytest.importorskip("sqlalchemy")
pytest.importorskip("psycopg2")
engine = sqlalchemy.create_engine(
"postgresql+psycopg2://postgres:postgres@localhost:5432/pandas"
"postgresql+psycopg2://postgres:postgres@localhost:5432/pandas",
poolclass=sqlalchemy.pool.NullPool,
)
insp = sqlalchemy.inspect(engine)
if not insp.has_table("iris"):
Expand All @@ -435,9 +447,16 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine):


@pytest.fixture
def sqlite_str():
    """Yield a sqlite URI string backed by a temporary file (removed on teardown)."""
    pytest.importorskip("sqlalchemy")
    with tm.ensure_clean() as name:
        yield "sqlite:///" + name


@pytest.fixture
def sqlite_engine(sqlite_str):
    """Yield a SQLAlchemy engine over the sqlite_str database, disposed on teardown."""
    sqlalchemy = pytest.importorskip("sqlalchemy")
    # NullPool closes connections immediately so the temp db file can be removed.
    engine = sqlalchemy.create_engine(sqlite_str, poolclass=sqlalchemy.pool.NullPool)
    yield engine
    engine.dispose()

Expand All @@ -447,6 +466,15 @@ def sqlite_conn(sqlite_engine):
yield sqlite_engine.connect()


@pytest.fixture
def sqlite_iris_str(sqlite_str, iris_path):
    """Return the sqlite URI string after loading the iris dataset into that database."""
    sqlalchemy = pytest.importorskip("sqlalchemy")
    # Use a short-lived engine purely to load the data, then hand back the URI.
    engine = sqlalchemy.create_engine(sqlite_str)
    create_and_load_iris(engine, iris_path, "sqlite")
    engine.dispose()
    return sqlite_str


@pytest.fixture
def sqlite_iris_engine(sqlite_engine, iris_path):
create_and_load_iris(sqlite_engine, iris_path, "sqlite")
Expand Down Expand Up @@ -485,11 +513,13 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path):
# Fixture names for sqlite-backed connectables: engine, connection, and URI string.
sqlite_connectable = [
    "sqlite_engine",
    "sqlite_conn",
    "sqlite_str",
]

# Same trio, but with the iris dataset pre-loaded.
sqlite_iris_connectable = [
    "sqlite_iris_engine",
    "sqlite_iris_conn",
    "sqlite_iris_str",
]

sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable
Expand Down Expand Up @@ -541,10 +571,47 @@ def test_to_sql_exist_fail(conn, test_frame1, request):

@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable_iris)
def test_read_iris_query(conn, request):
    """read_sql_query/read_sql should return the iris frame for every connectable kind."""
    conn = request.getfixturevalue(conn)
    iris_frame = read_sql_query("SELECT * FROM iris", conn)
    check_iris_frame(iris_frame)
    iris_frame = pd.read_sql("SELECT * FROM iris", conn)
    check_iris_frame(iris_frame)
    # Empty result set still yields the full column schema with zero rows.
    iris_frame = pd.read_sql("SELECT * FROM iris where 0=1", conn)
    assert iris_frame.shape == (0, 5)
    assert "SepalWidth" in iris_frame.columns


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable_iris)
def test_read_iris_query_chunksize(conn, request):
    """Chunked read_sql_query/read_sql should concat back to the full iris frame."""
    conn = request.getfixturevalue(conn)
    iris_frame = concat(read_sql_query("SELECT * FROM iris", conn, chunksize=7))
    check_iris_frame(iris_frame)
    iris_frame = concat(pd.read_sql("SELECT * FROM iris", conn, chunksize=7))
    check_iris_frame(iris_frame)
    # Empty result: concatenated chunks still carry the column schema.
    iris_frame = concat(pd.read_sql("SELECT * FROM iris where 0=1", conn, chunksize=7))
    assert iris_frame.shape == (0, 5)
    assert "SepalWidth" in iris_frame.columns


@pytest.mark.db
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
def test_read_iris_table(conn, request):
    """read_sql_table (and read_sql given a bare table name) should return the iris frame."""
    conn = request.getfixturevalue(conn)
    iris_frame = read_sql_table("iris", conn)
    check_iris_frame(iris_frame)
    iris_frame = pd.read_sql("iris", conn)
    check_iris_frame(iris_frame)


@pytest.mark.db
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
def test_read_iris_table_chunksize(conn, request):
    """Chunked read_sql_table/read_sql should concat back to the full iris frame."""
    conn = request.getfixturevalue(conn)
    iris_frame = concat(read_sql_table("iris", conn, chunksize=7))
    check_iris_frame(iris_frame)
    iris_frame = concat(pd.read_sql("iris", conn, chunksize=7))
    check_iris_frame(iris_frame)


Expand Down