Skip to content

Commit 4033156

Browse files
cdcadmanChuck Cadman
and
Chuck Cadman
authored
BUG: Allow read_sql to work with chunksize. (#49967)
Co-authored-by: Chuck Cadman <[email protected]>
1 parent 07c3186 commit 4033156

File tree

2 files changed

+160
-69
lines changed

2 files changed

+160
-69
lines changed

pandas/io/sql.py

+81-57
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
ABC,
1010
abstractmethod,
1111
)
12-
from contextlib import contextmanager
12+
from contextlib import (
13+
ExitStack,
14+
contextmanager,
15+
)
1316
from datetime import (
1417
date,
1518
datetime,
@@ -71,6 +74,14 @@
7174
# -- Helper functions
7275

7376

77+
def _cleanup_after_generator(generator, exit_stack: ExitStack):
78+
"""Does the cleanup after iterating through the generator."""
79+
try:
80+
yield from generator
81+
finally:
82+
exit_stack.close()
83+
84+
7485
def _convert_params(sql, params):
7586
"""Convert SQL and params args to DBAPI2.0 compliant format."""
7687
args = [sql]
@@ -829,12 +840,11 @@ def has_table(table_name: str, con, schema: str | None = None) -> bool:
829840
table_exists = has_table
830841

831842

832-
@contextmanager
833843
def pandasSQL_builder(
834844
con,
835845
schema: str | None = None,
836846
need_transaction: bool = False,
837-
) -> Iterator[PandasSQL]:
847+
) -> PandasSQL:
838848
"""
839849
Convenience function to return the correct PandasSQL subclass based on the
840850
provided parameters. Also creates a sqlalchemy connection and transaction
@@ -843,45 +853,24 @@ def pandasSQL_builder(
843853
import sqlite3
844854

845855
if isinstance(con, sqlite3.Connection) or con is None:
846-
yield SQLiteDatabase(con)
847-
else:
848-
sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore")
856+
return SQLiteDatabase(con)
849857

850-
if sqlalchemy is not None and isinstance(
851-
con, (str, sqlalchemy.engine.Connectable)
852-
):
853-
with _sqlalchemy_con(con, need_transaction) as con:
854-
yield SQLDatabase(con, schema=schema)
855-
elif isinstance(con, str) and sqlalchemy is None:
856-
raise ImportError("Using URI string without sqlalchemy installed.")
857-
else:
858-
859-
warnings.warn(
860-
"pandas only supports SQLAlchemy connectable (engine/connection) or "
861-
"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
862-
"objects are not tested. Please consider using SQLAlchemy.",
863-
UserWarning,
864-
stacklevel=find_stack_level() + 2,
865-
)
866-
yield SQLiteDatabase(con)
858+
sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore")
867859

860+
if isinstance(con, str) and sqlalchemy is None:
861+
raise ImportError("Using URI string without sqlalchemy installed.")
868862

869-
@contextmanager
870-
def _sqlalchemy_con(connectable, need_transaction: bool):
871-
"""Create a sqlalchemy connection and a transaction if necessary."""
872-
sqlalchemy = import_optional_dependency("sqlalchemy", errors="raise")
863+
if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):
864+
return SQLDatabase(con, schema, need_transaction)
873865

874-
if isinstance(connectable, str):
875-
connectable = sqlalchemy.create_engine(connectable)
876-
if isinstance(connectable, sqlalchemy.engine.Engine):
877-
with connectable.connect() as con:
878-
if need_transaction:
879-
with con.begin():
880-
yield con
881-
else:
882-
yield con
883-
else:
884-
yield connectable
866+
warnings.warn(
867+
"pandas only supports SQLAlchemy connectable (engine/connection) or "
868+
"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
869+
"objects are not tested. Please consider using SQLAlchemy.",
870+
UserWarning,
871+
stacklevel=find_stack_level(),
872+
)
873+
return SQLiteDatabase(con)
885874

886875

887876
class SQLTable(PandasObject):
@@ -1106,6 +1095,7 @@ def _query_iterator(
11061095

11071096
def read(
11081097
self,
1098+
exit_stack: ExitStack,
11091099
coerce_float: bool = True,
11101100
parse_dates=None,
11111101
columns=None,
@@ -1126,13 +1116,16 @@ def read(
11261116
column_names = result.keys()
11271117

11281118
if chunksize is not None:
1129-
return self._query_iterator(
1130-
result,
1131-
chunksize,
1132-
column_names,
1133-
coerce_float=coerce_float,
1134-
parse_dates=parse_dates,
1135-
use_nullable_dtypes=use_nullable_dtypes,
1119+
return _cleanup_after_generator(
1120+
self._query_iterator(
1121+
result,
1122+
chunksize,
1123+
column_names,
1124+
coerce_float=coerce_float,
1125+
parse_dates=parse_dates,
1126+
use_nullable_dtypes=use_nullable_dtypes,
1127+
),
1128+
exit_stack,
11361129
)
11371130
else:
11381131
data = result.fetchall()
@@ -1384,6 +1377,12 @@ class PandasSQL(PandasObject, ABC):
13841377
Subclasses Should define read_query and to_sql.
13851378
"""
13861379

1380+
def __enter__(self):
1381+
return self
1382+
1383+
def __exit__(self, *args) -> None:
1384+
pass
1385+
13871386
def read_table(
13881387
self,
13891388
table_name: str,
@@ -1539,20 +1538,38 @@ class SQLDatabase(PandasSQL):
15391538
15401539
Parameters
15411540
----------
1542-
con : SQLAlchemy Connection
1543-
Connection to connect with the database. Using SQLAlchemy makes it
1541+
con : SQLAlchemy Connectable or URI string.
1542+
Connectable to connect with the database. Using SQLAlchemy makes it
15441543
possible to use any DB supported by that library.
15451544
schema : string, default None
15461545
Name of SQL schema in database to write to (if database flavor
15471546
supports this). If None, use default schema (default).
1547+
need_transaction : bool, default False
1548+
If True, SQLDatabase will create a transaction.
15481549
15491550
"""
15501551

1551-
def __init__(self, con, schema: str | None = None) -> None:
1552+
def __init__(
1553+
self, con, schema: str | None = None, need_transaction: bool = False
1554+
) -> None:
1555+
from sqlalchemy import create_engine
1556+
from sqlalchemy.engine import Engine
15521557
from sqlalchemy.schema import MetaData
15531558

1559+
self.exit_stack = ExitStack()
1560+
if isinstance(con, str):
1561+
con = create_engine(con)
1562+
if isinstance(con, Engine):
1563+
con = self.exit_stack.enter_context(con.connect())
1564+
if need_transaction:
1565+
self.exit_stack.enter_context(con.begin())
15541566
self.con = con
15551567
self.meta = MetaData(schema=schema)
1568+
self.returns_generator = False
1569+
1570+
def __exit__(self, *args) -> None:
1571+
if not self.returns_generator:
1572+
self.exit_stack.close()
15561573

15571574
@contextmanager
15581575
def run_transaction(self):
@@ -1623,7 +1640,10 @@ def read_table(
16231640
"""
16241641
self.meta.reflect(bind=self.con, only=[table_name])
16251642
table = SQLTable(table_name, self, index=index_col, schema=schema)
1643+
if chunksize is not None:
1644+
self.returns_generator = True
16261645
return table.read(
1646+
self.exit_stack,
16271647
coerce_float=coerce_float,
16281648
parse_dates=parse_dates,
16291649
columns=columns,
@@ -1733,15 +1753,19 @@ def read_query(
17331753
columns = result.keys()
17341754

17351755
if chunksize is not None:
1736-
return self._query_iterator(
1737-
result,
1738-
chunksize,
1739-
columns,
1740-
index_col=index_col,
1741-
coerce_float=coerce_float,
1742-
parse_dates=parse_dates,
1743-
dtype=dtype,
1744-
use_nullable_dtypes=use_nullable_dtypes,
1756+
self.returns_generator = True
1757+
return _cleanup_after_generator(
1758+
self._query_iterator(
1759+
result,
1760+
chunksize,
1761+
columns,
1762+
index_col=index_col,
1763+
coerce_float=coerce_float,
1764+
parse_dates=parse_dates,
1765+
dtype=dtype,
1766+
use_nullable_dtypes=use_nullable_dtypes,
1767+
),
1768+
self.exit_stack,
17451769
)
17461770
else:
17471771
data = result.fetchall()

pandas/tests/io/test_sql.py

+79-12
Original file line numberDiff line numberDiff line change
@@ -260,24 +260,34 @@ def check_iris_frame(frame: DataFrame):
260260
row = frame.iloc[0]
261261
assert issubclass(pytype, np.floating)
262262
tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])
263+
assert frame.shape in ((150, 5), (8, 5))
263264

264265

265266
def count_rows(conn, table_name: str):
266267
stmt = f"SELECT count(*) AS count_1 FROM {table_name}"
267268
if isinstance(conn, sqlite3.Connection):
268269
cur = conn.cursor()
269-
result = cur.execute(stmt)
270+
return cur.execute(stmt).fetchone()[0]
270271
else:
271-
from sqlalchemy import text
272+
from sqlalchemy import (
273+
create_engine,
274+
text,
275+
)
272276
from sqlalchemy.engine import Engine
273277

274278
stmt = text(stmt)
275-
if isinstance(conn, Engine):
279+
if isinstance(conn, str):
280+
try:
281+
engine = create_engine(conn)
282+
with engine.connect() as conn:
283+
return conn.execute(stmt).scalar_one()
284+
finally:
285+
engine.dispose()
286+
elif isinstance(conn, Engine):
276287
with conn.connect() as conn:
277-
result = conn.execute(stmt)
288+
return conn.execute(stmt).scalar_one()
278289
else:
279-
result = conn.execute(stmt)
280-
return result.fetchone()[0]
290+
return conn.execute(stmt).scalar_one()
281291

282292

283293
@pytest.fixture
@@ -388,6 +398,7 @@ def mysql_pymysql_engine(iris_path, types_data):
388398
engine = sqlalchemy.create_engine(
389399
"mysql+pymysql://root@localhost:3306/pandas",
390400
connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS},
401+
poolclass=sqlalchemy.pool.NullPool,
391402
)
392403
insp = sqlalchemy.inspect(engine)
393404
if not insp.has_table("iris"):
@@ -414,7 +425,8 @@ def postgresql_psycopg2_engine(iris_path, types_data):
414425
sqlalchemy = pytest.importorskip("sqlalchemy")
415426
pytest.importorskip("psycopg2")
416427
engine = sqlalchemy.create_engine(
417-
"postgresql+psycopg2://postgres:postgres@localhost:5432/pandas"
428+
"postgresql+psycopg2://postgres:postgres@localhost:5432/pandas",
429+
poolclass=sqlalchemy.pool.NullPool,
418430
)
419431
insp = sqlalchemy.inspect(engine)
420432
if not insp.has_table("iris"):
@@ -435,9 +447,16 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine):
435447

436448

437449
@pytest.fixture
438-
def sqlite_engine():
450+
def sqlite_str():
451+
pytest.importorskip("sqlalchemy")
452+
with tm.ensure_clean() as name:
453+
yield "sqlite:///" + name
454+
455+
456+
@pytest.fixture
457+
def sqlite_engine(sqlite_str):
439458
sqlalchemy = pytest.importorskip("sqlalchemy")
440-
engine = sqlalchemy.create_engine("sqlite://")
459+
engine = sqlalchemy.create_engine(sqlite_str, poolclass=sqlalchemy.pool.NullPool)
441460
yield engine
442461
engine.dispose()
443462

@@ -447,6 +466,15 @@ def sqlite_conn(sqlite_engine):
447466
yield sqlite_engine.connect()
448467

449468

469+
@pytest.fixture
470+
def sqlite_iris_str(sqlite_str, iris_path):
471+
sqlalchemy = pytest.importorskip("sqlalchemy")
472+
engine = sqlalchemy.create_engine(sqlite_str)
473+
create_and_load_iris(engine, iris_path, "sqlite")
474+
engine.dispose()
475+
return sqlite_str
476+
477+
450478
@pytest.fixture
451479
def sqlite_iris_engine(sqlite_engine, iris_path):
452480
create_and_load_iris(sqlite_engine, iris_path, "sqlite")
@@ -485,11 +513,13 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path):
485513
sqlite_connectable = [
486514
"sqlite_engine",
487515
"sqlite_conn",
516+
"sqlite_str",
488517
]
489518

490519
sqlite_iris_connectable = [
491520
"sqlite_iris_engine",
492521
"sqlite_iris_conn",
522+
"sqlite_iris_str",
493523
]
494524

495525
sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable
@@ -541,10 +571,47 @@ def test_to_sql_exist_fail(conn, test_frame1, request):
541571

542572
@pytest.mark.db
543573
@pytest.mark.parametrize("conn", all_connectable_iris)
544-
def test_read_iris(conn, request):
574+
def test_read_iris_query(conn, request):
545575
conn = request.getfixturevalue(conn)
546-
with pandasSQL_builder(conn) as pandasSQL:
547-
iris_frame = pandasSQL.read_query("SELECT * FROM iris")
576+
iris_frame = read_sql_query("SELECT * FROM iris", conn)
577+
check_iris_frame(iris_frame)
578+
iris_frame = pd.read_sql("SELECT * FROM iris", conn)
579+
check_iris_frame(iris_frame)
580+
iris_frame = pd.read_sql("SELECT * FROM iris where 0=1", conn)
581+
assert iris_frame.shape == (0, 5)
582+
assert "SepalWidth" in iris_frame.columns
583+
584+
585+
@pytest.mark.db
586+
@pytest.mark.parametrize("conn", all_connectable_iris)
587+
def test_read_iris_query_chunksize(conn, request):
588+
conn = request.getfixturevalue(conn)
589+
iris_frame = concat(read_sql_query("SELECT * FROM iris", conn, chunksize=7))
590+
check_iris_frame(iris_frame)
591+
iris_frame = concat(pd.read_sql("SELECT * FROM iris", conn, chunksize=7))
592+
check_iris_frame(iris_frame)
593+
iris_frame = concat(pd.read_sql("SELECT * FROM iris where 0=1", conn, chunksize=7))
594+
assert iris_frame.shape == (0, 5)
595+
assert "SepalWidth" in iris_frame.columns
596+
597+
598+
@pytest.mark.db
599+
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
600+
def test_read_iris_table(conn, request):
601+
conn = request.getfixturevalue(conn)
602+
iris_frame = read_sql_table("iris", conn)
603+
check_iris_frame(iris_frame)
604+
iris_frame = pd.read_sql("iris", conn)
605+
check_iris_frame(iris_frame)
606+
607+
608+
@pytest.mark.db
609+
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
610+
def test_read_iris_table_chunksize(conn, request):
611+
conn = request.getfixturevalue(conn)
612+
iris_frame = concat(read_sql_table("iris", conn, chunksize=7))
613+
check_iris_frame(iris_frame)
614+
iris_frame = concat(pd.read_sql("iris", conn, chunksize=7))
548615
check_iris_frame(iris_frame)
549616

550617

0 commit comments

Comments
 (0)