Skip to content

Commit b87dc7c

Browse files
author
Chuck Cadman
committed
BUG: Allow read_sql to work with chunksize.
1 parent 53c1425 commit b87dc7c

File tree

2 files changed

+160
-69
lines changed

2 files changed

+160
-69
lines changed

pandas/io/sql.py

+81-57
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
ABC,
1010
abstractmethod,
1111
)
12-
from contextlib import contextmanager
12+
from contextlib import (
13+
ExitStack,
14+
contextmanager,
15+
)
1316
from datetime import (
1417
date,
1518
datetime,
@@ -69,6 +72,14 @@
6972
# -- Helper functions
7073

7174

75+
def _cleanup_after_generator(generator, exit_stack: ExitStack):
76+
"""Yield from the generator and close exit_stack once iteration finishes or fails."""
77+
try:
78+
yield from generator
79+
finally:
80+
exit_stack.close()
81+
82+
7283
def _convert_params(sql, params):
7384
"""Convert SQL and params args to DBAPI2.0 compliant format."""
7485
args = [sql]
@@ -792,12 +803,11 @@ def has_table(table_name: str, con, schema: str | None = None) -> bool:
792803
table_exists = has_table
793804

794805

795-
@contextmanager
796806
def pandasSQL_builder(
797807
con,
798808
schema: str | None = None,
799809
need_transaction: bool = False,
800-
) -> Iterator[PandasSQL]:
810+
) -> PandasSQL:
801811
"""
802812
Convenience function to return the correct PandasSQL subclass based on the
803813
provided parameters. Also creates a sqlalchemy connection and transaction
@@ -806,45 +816,24 @@ def pandasSQL_builder(
806816
import sqlite3
807817

808818
if isinstance(con, sqlite3.Connection) or con is None:
809-
yield SQLiteDatabase(con)
810-
else:
811-
sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore")
819+
return SQLiteDatabase(con)
812820

813-
if sqlalchemy is not None and isinstance(
814-
con, (str, sqlalchemy.engine.Connectable)
815-
):
816-
with _sqlalchemy_con(con, need_transaction) as con:
817-
yield SQLDatabase(con, schema=schema)
818-
elif isinstance(con, str) and sqlalchemy is None:
819-
raise ImportError("Using URI string without sqlalchemy installed.")
820-
else:
821-
822-
warnings.warn(
823-
"pandas only supports SQLAlchemy connectable (engine/connection) or "
824-
"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
825-
"objects are not tested. Please consider using SQLAlchemy.",
826-
UserWarning,
827-
stacklevel=find_stack_level() + 2,
828-
)
829-
yield SQLiteDatabase(con)
821+
sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore")
830822

823+
if isinstance(con, str) and sqlalchemy is None:
824+
raise ImportError("Using URI string without sqlalchemy installed.")
831825

832-
@contextmanager
833-
def _sqlalchemy_con(connectable, need_transaction: bool):
834-
"""Create a sqlalchemy connection and a transaction if necessary."""
835-
sqlalchemy = import_optional_dependency("sqlalchemy", errors="raise")
826+
if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):
827+
return SQLDatabase(con, schema, need_transaction)
836828

837-
if isinstance(connectable, str):
838-
connectable = sqlalchemy.create_engine(connectable)
839-
if isinstance(connectable, sqlalchemy.engine.Engine):
840-
with connectable.connect() as con:
841-
if need_transaction:
842-
with con.begin():
843-
yield con
844-
else:
845-
yield con
846-
else:
847-
yield connectable
829+
warnings.warn(
830+
"pandas only supports SQLAlchemy connectable (engine/connection) or "
831+
"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
832+
"objects are not tested. Please consider using SQLAlchemy.",
833+
UserWarning,
834+
stacklevel=find_stack_level(),
835+
)
836+
return SQLiteDatabase(con)
848837

849838

850839
class SQLTable(PandasObject):
@@ -1069,6 +1058,7 @@ def _query_iterator(
10691058

10701059
def read(
10711060
self,
1061+
exit_stack: ExitStack,
10721062
coerce_float: bool = True,
10731063
parse_dates=None,
10741064
columns=None,
@@ -1089,13 +1079,16 @@ def read(
10891079
column_names = result.keys()
10901080

10911081
if chunksize is not None:
1092-
return self._query_iterator(
1093-
result,
1094-
chunksize,
1095-
column_names,
1096-
coerce_float=coerce_float,
1097-
parse_dates=parse_dates,
1098-
use_nullable_dtypes=use_nullable_dtypes,
1082+
return _cleanup_after_generator(
1083+
self._query_iterator(
1084+
result,
1085+
chunksize,
1086+
column_names,
1087+
coerce_float=coerce_float,
1088+
parse_dates=parse_dates,
1089+
use_nullable_dtypes=use_nullable_dtypes,
1090+
),
1091+
exit_stack,
10991092
)
11001093
else:
11011094
data = result.fetchall()
@@ -1347,6 +1340,12 @@ class PandasSQL(PandasObject, ABC):
13471340
Subclasses should define read_query and to_sql.
13481341
"""
13491342

1343+
def __enter__(self):
1344+
return self
1345+
1346+
def __exit__(self, *args) -> None:
1347+
pass
1348+
13501349
def read_table(
13511350
self,
13521351
table_name: str,
@@ -1502,20 +1501,38 @@ class SQLDatabase(PandasSQL):
15021501
15031502
Parameters
15041503
----------
1505-
con : SQLAlchemy Connection
1506-
Connection to connect with the database. Using SQLAlchemy makes it
1504+
con : SQLAlchemy Connectable or URI string
1505+
Connectable to connect with the database. Using SQLAlchemy makes it
15071506
possible to use any DB supported by that library.
15081507
schema : string, default None
15091508
Name of SQL schema in database to write to (if database flavor
15101509
supports this). If None, use default schema (default).
1510+
need_transaction : bool, default False
1511+
If True, SQLDatabase will create a transaction.
15111512
15121513
"""
15131514

1514-
def __init__(self, con, schema: str | None = None) -> None:
1515+
def __init__(
1516+
self, con, schema: str | None = None, need_transaction: bool = False
1517+
) -> None:
1518+
from sqlalchemy import create_engine
1519+
from sqlalchemy.engine import Engine
15151520
from sqlalchemy.schema import MetaData
15161521

1522+
self.exit_stack = ExitStack()
1523+
if isinstance(con, str):
1524+
con = create_engine(con)
1525+
if isinstance(con, Engine):
1526+
con = self.exit_stack.enter_context(con.connect())
1527+
if need_transaction:
1528+
self.exit_stack.enter_context(con.begin())
15171529
self.con = con
15181530
self.meta = MetaData(schema=schema)
1531+
self.returns_generator = False
1532+
1533+
def __exit__(self, *args) -> None:
1534+
if not self.returns_generator:
1535+
self.exit_stack.close()
15191536

15201537
@contextmanager
15211538
def run_transaction(self):
@@ -1586,7 +1603,10 @@ def read_table(
15861603
"""
15871604
self.meta.reflect(bind=self.con, only=[table_name])
15881605
table = SQLTable(table_name, self, index=index_col, schema=schema)
1606+
if chunksize is not None:
1607+
self.returns_generator = True
15891608
return table.read(
1609+
self.exit_stack,
15901610
coerce_float=coerce_float,
15911611
parse_dates=parse_dates,
15921612
columns=columns,
@@ -1696,15 +1716,19 @@ def read_query(
16961716
columns = result.keys()
16971717

16981718
if chunksize is not None:
1699-
return self._query_iterator(
1700-
result,
1701-
chunksize,
1702-
columns,
1703-
index_col=index_col,
1704-
coerce_float=coerce_float,
1705-
parse_dates=parse_dates,
1706-
dtype=dtype,
1707-
use_nullable_dtypes=use_nullable_dtypes,
1719+
self.returns_generator = True
1720+
return _cleanup_after_generator(
1721+
self._query_iterator(
1722+
result,
1723+
chunksize,
1724+
columns,
1725+
index_col=index_col,
1726+
coerce_float=coerce_float,
1727+
parse_dates=parse_dates,
1728+
dtype=dtype,
1729+
use_nullable_dtypes=use_nullable_dtypes,
1730+
),
1731+
self.exit_stack,
17081732
)
17091733
else:
17101734
data = result.fetchall()

pandas/tests/io/test_sql.py

+79-12
Original file line numberDiff line numberDiff line change
@@ -260,24 +260,34 @@ def check_iris_frame(frame: DataFrame):
260260
row = frame.iloc[0]
261261
assert issubclass(pytype, np.floating)
262262
tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])
263+
assert frame.shape in ((150, 5), (8, 5))
263264

264265

265266
def count_rows(conn, table_name: str):
266267
stmt = f"SELECT count(*) AS count_1 FROM {table_name}"
267268
if isinstance(conn, sqlite3.Connection):
268269
cur = conn.cursor()
269-
result = cur.execute(stmt)
270+
return cur.execute(stmt).fetchone()[0]
270271
else:
271-
from sqlalchemy import text
272+
from sqlalchemy import (
273+
create_engine,
274+
text,
275+
)
272276
from sqlalchemy.engine import Engine
273277

274278
stmt = text(stmt)
275-
if isinstance(conn, Engine):
279+
if isinstance(conn, str):
280+
try:
281+
engine = create_engine(conn)
282+
with engine.connect() as conn:
283+
return conn.execute(stmt).scalar_one()
284+
finally:
285+
engine.dispose()
286+
elif isinstance(conn, Engine):
276287
with conn.connect() as conn:
277-
result = conn.execute(stmt)
288+
return conn.execute(stmt).scalar_one()
278289
else:
279-
result = conn.execute(stmt)
280-
return result.fetchone()[0]
290+
return conn.execute(stmt).scalar_one()
281291

282292

283293
@pytest.fixture
@@ -388,6 +398,7 @@ def mysql_pymysql_engine(iris_path, types_data):
388398
engine = sqlalchemy.create_engine(
389399
"mysql+pymysql://root@localhost:3306/pandas",
390400
connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS},
401+
poolclass=sqlalchemy.pool.NullPool,
391402
)
392403
insp = sqlalchemy.inspect(engine)
393404
if not insp.has_table("iris"):
@@ -414,7 +425,8 @@ def postgresql_psycopg2_engine(iris_path, types_data):
414425
sqlalchemy = pytest.importorskip("sqlalchemy")
415426
pytest.importorskip("psycopg2")
416427
engine = sqlalchemy.create_engine(
417-
"postgresql+psycopg2://postgres:postgres@localhost:5432/pandas"
428+
"postgresql+psycopg2://postgres:postgres@localhost:5432/pandas",
429+
poolclass=sqlalchemy.pool.NullPool,
418430
)
419431
insp = sqlalchemy.inspect(engine)
420432
if not insp.has_table("iris"):
@@ -435,9 +447,16 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine):
435447

436448

437449
@pytest.fixture
438-
def sqlite_engine():
450+
def sqlite_str():
451+
pytest.importorskip("sqlalchemy")
452+
with tm.ensure_clean() as name:
453+
yield "sqlite:///" + name
454+
455+
456+
@pytest.fixture
457+
def sqlite_engine(sqlite_str):
439458
sqlalchemy = pytest.importorskip("sqlalchemy")
440-
engine = sqlalchemy.create_engine("sqlite://")
459+
engine = sqlalchemy.create_engine(sqlite_str, poolclass=sqlalchemy.pool.NullPool)
441460
yield engine
442461
engine.dispose()
443462

@@ -447,6 +466,15 @@ def sqlite_conn(sqlite_engine):
447466
yield sqlite_engine.connect()
448467

449468

469+
@pytest.fixture
470+
def sqlite_iris_str(sqlite_str, iris_path):
471+
sqlalchemy = pytest.importorskip("sqlalchemy")
472+
engine = sqlalchemy.create_engine(sqlite_str)
473+
create_and_load_iris(engine, iris_path, "sqlite")
474+
engine.dispose()
475+
return sqlite_str
476+
477+
450478
@pytest.fixture
451479
def sqlite_iris_engine(sqlite_engine, iris_path):
452480
create_and_load_iris(sqlite_engine, iris_path, "sqlite")
@@ -485,11 +513,13 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path):
485513
sqlite_connectable = [
486514
"sqlite_engine",
487515
"sqlite_conn",
516+
"sqlite_str",
488517
]
489518

490519
sqlite_iris_connectable = [
491520
"sqlite_iris_engine",
492521
"sqlite_iris_conn",
522+
"sqlite_iris_str",
493523
]
494524

495525
sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable
@@ -541,10 +571,47 @@ def test_to_sql_exist_fail(conn, test_frame1, request):
541571

542572
@pytest.mark.db
543573
@pytest.mark.parametrize("conn", all_connectable_iris)
544-
def test_read_iris(conn, request):
574+
def test_read_iris_query(conn, request):
545575
conn = request.getfixturevalue(conn)
546-
with pandasSQL_builder(conn) as pandasSQL:
547-
iris_frame = pandasSQL.read_query("SELECT * FROM iris")
576+
iris_frame = read_sql_query("SELECT * FROM iris", conn)
577+
check_iris_frame(iris_frame)
578+
iris_frame = pd.read_sql("SELECT * FROM iris", conn)
579+
check_iris_frame(iris_frame)
580+
iris_frame = pd.read_sql("SELECT * FROM iris where 0=1", conn)
581+
assert iris_frame.shape == (0, 5)
582+
assert "SepalWidth" in iris_frame.columns
583+
584+
585+
@pytest.mark.db
586+
@pytest.mark.parametrize("conn", all_connectable_iris)
587+
def test_read_iris_query_chunksize(conn, request):
588+
conn = request.getfixturevalue(conn)
589+
iris_frame = concat(read_sql_query("SELECT * FROM iris", conn, chunksize=7))
590+
check_iris_frame(iris_frame)
591+
iris_frame = concat(pd.read_sql("SELECT * FROM iris", conn, chunksize=7))
592+
check_iris_frame(iris_frame)
593+
iris_frame = concat(pd.read_sql("SELECT * FROM iris where 0=1", conn, chunksize=7))
594+
assert iris_frame.shape == (0, 5)
595+
assert "SepalWidth" in iris_frame.columns
596+
597+
598+
@pytest.mark.db
599+
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
600+
def test_read_iris_table(conn, request):
601+
conn = request.getfixturevalue(conn)
602+
iris_frame = read_sql_table("iris", conn)
603+
check_iris_frame(iris_frame)
604+
iris_frame = pd.read_sql("iris", conn)
605+
check_iris_frame(iris_frame)
606+
607+
608+
@pytest.mark.db
609+
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
610+
def test_read_iris_table_chunksize(conn, request):
611+
conn = request.getfixturevalue(conn)
612+
iris_frame = concat(read_sql_table("iris", conn, chunksize=7))
613+
check_iris_frame(iris_frame)
614+
iris_frame = concat(pd.read_sql("iris", conn, chunksize=7))
548615
check_iris_frame(iris_frame)
549616

550617

0 commit comments

Comments
 (0)