9
9
ABC ,
10
10
abstractmethod ,
11
11
)
12
- from contextlib import contextmanager
12
+ from contextlib import (
13
+ ExitStack ,
14
+ contextmanager ,
15
+ )
13
16
from datetime import (
14
17
date ,
15
18
datetime ,
71
74
# -- Helper functions
72
75
73
76
77
+ def _cleanup_after_generator (generator , exit_stack : ExitStack ):
78
+ """Does the cleanup after iterating through the generator."""
79
+ try :
80
+ yield from generator
81
+ finally :
82
+ exit_stack .close ()
83
+
84
+
74
85
def _convert_params (sql , params ):
75
86
"""Convert SQL and params args to DBAPI2.0 compliant format."""
76
87
args = [sql ]
@@ -829,12 +840,11 @@ def has_table(table_name: str, con, schema: str | None = None) -> bool:
829
840
table_exists = has_table
830
841
831
842
832
- @contextmanager
833
843
def pandasSQL_builder (
834
844
con ,
835
845
schema : str | None = None ,
836
846
need_transaction : bool = False ,
837
- ) -> Iterator [ PandasSQL ] :
847
+ ) -> PandasSQL :
838
848
"""
839
849
Convenience function to return the correct PandasSQL subclass based on the
840
850
provided parameters. Also creates a sqlalchemy connection and transaction
@@ -843,45 +853,24 @@ def pandasSQL_builder(
843
853
import sqlite3
844
854
845
855
if isinstance (con , sqlite3 .Connection ) or con is None :
846
- yield SQLiteDatabase (con )
847
- else :
848
- sqlalchemy = import_optional_dependency ("sqlalchemy" , errors = "ignore" )
856
+ return SQLiteDatabase (con )
849
857
850
- if sqlalchemy is not None and isinstance (
851
- con , (str , sqlalchemy .engine .Connectable )
852
- ):
853
- with _sqlalchemy_con (con , need_transaction ) as con :
854
- yield SQLDatabase (con , schema = schema )
855
- elif isinstance (con , str ) and sqlalchemy is None :
856
- raise ImportError ("Using URI string without sqlalchemy installed." )
857
- else :
858
-
859
- warnings .warn (
860
- "pandas only supports SQLAlchemy connectable (engine/connection) or "
861
- "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
862
- "objects are not tested. Please consider using SQLAlchemy." ,
863
- UserWarning ,
864
- stacklevel = find_stack_level () + 2 ,
865
- )
866
- yield SQLiteDatabase (con )
858
+ sqlalchemy = import_optional_dependency ("sqlalchemy" , errors = "ignore" )
867
859
860
+ if isinstance (con , str ) and sqlalchemy is None :
861
+ raise ImportError ("Using URI string without sqlalchemy installed." )
868
862
869
- @contextmanager
870
- def _sqlalchemy_con (connectable , need_transaction : bool ):
871
- """Create a sqlalchemy connection and a transaction if necessary."""
872
- sqlalchemy = import_optional_dependency ("sqlalchemy" , errors = "raise" )
863
+ if sqlalchemy is not None and isinstance (con , (str , sqlalchemy .engine .Connectable )):
864
+ return SQLDatabase (con , schema , need_transaction )
873
865
874
- if isinstance (connectable , str ):
875
- connectable = sqlalchemy .create_engine (connectable )
876
- if isinstance (connectable , sqlalchemy .engine .Engine ):
877
- with connectable .connect () as con :
878
- if need_transaction :
879
- with con .begin ():
880
- yield con
881
- else :
882
- yield con
883
- else :
884
- yield connectable
866
+ warnings .warn (
867
+ "pandas only supports SQLAlchemy connectable (engine/connection) or "
868
+ "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
869
+ "objects are not tested. Please consider using SQLAlchemy." ,
870
+ UserWarning ,
871
+ stacklevel = find_stack_level (),
872
+ )
873
+ return SQLiteDatabase (con )
885
874
886
875
887
876
class SQLTable (PandasObject ):
@@ -1106,6 +1095,7 @@ def _query_iterator(
1106
1095
1107
1096
def read (
1108
1097
self ,
1098
+ exit_stack : ExitStack ,
1109
1099
coerce_float : bool = True ,
1110
1100
parse_dates = None ,
1111
1101
columns = None ,
@@ -1126,13 +1116,16 @@ def read(
1126
1116
column_names = result .keys ()
1127
1117
1128
1118
if chunksize is not None :
1129
- return self ._query_iterator (
1130
- result ,
1131
- chunksize ,
1132
- column_names ,
1133
- coerce_float = coerce_float ,
1134
- parse_dates = parse_dates ,
1135
- use_nullable_dtypes = use_nullable_dtypes ,
1119
+ return _cleanup_after_generator (
1120
+ self ._query_iterator (
1121
+ result ,
1122
+ chunksize ,
1123
+ column_names ,
1124
+ coerce_float = coerce_float ,
1125
+ parse_dates = parse_dates ,
1126
+ use_nullable_dtypes = use_nullable_dtypes ,
1127
+ ),
1128
+ exit_stack ,
1136
1129
)
1137
1130
else :
1138
1131
data = result .fetchall ()
@@ -1384,6 +1377,12 @@ class PandasSQL(PandasObject, ABC):
1384
1377
Subclasses Should define read_query and to_sql.
1385
1378
"""
1386
1379
1380
+ def __enter__ (self ):
1381
+ return self
1382
+
1383
+ def __exit__ (self , * args ) -> None :
1384
+ pass
1385
+
1387
1386
def read_table (
1388
1387
self ,
1389
1388
table_name : str ,
@@ -1539,20 +1538,38 @@ class SQLDatabase(PandasSQL):
1539
1538
1540
1539
Parameters
1541
1540
----------
1542
- con : SQLAlchemy Connection
1543
- Connection to connect with the database. Using SQLAlchemy makes it
1541
+ con : SQLAlchemy Connectable or URI string.
1542
+ Connectable to connect with the database. Using SQLAlchemy makes it
1544
1543
possible to use any DB supported by that library.
1545
1544
schema : string, default None
1546
1545
Name of SQL schema in database to write to (if database flavor
1547
1546
supports this). If None, use default schema (default).
1547
+ need_transaction : bool, default False
1548
+ If True, SQLDatabase will create a transaction.
1548
1549
1549
1550
"""
1550
1551
1551
- def __init__ (self , con , schema : str | None = None ) -> None :
1552
+ def __init__ (
1553
+ self , con , schema : str | None = None , need_transaction : bool = False
1554
+ ) -> None :
1555
+ from sqlalchemy import create_engine
1556
+ from sqlalchemy .engine import Engine
1552
1557
from sqlalchemy .schema import MetaData
1553
1558
1559
+ self .exit_stack = ExitStack ()
1560
+ if isinstance (con , str ):
1561
+ con = create_engine (con )
1562
+ if isinstance (con , Engine ):
1563
+ con = self .exit_stack .enter_context (con .connect ())
1564
+ if need_transaction :
1565
+ self .exit_stack .enter_context (con .begin ())
1554
1566
self .con = con
1555
1567
self .meta = MetaData (schema = schema )
1568
+ self .returns_generator = False
1569
+
1570
+ def __exit__ (self , * args ) -> None :
1571
+ if not self .returns_generator :
1572
+ self .exit_stack .close ()
1556
1573
1557
1574
@contextmanager
1558
1575
def run_transaction (self ):
@@ -1623,7 +1640,10 @@ def read_table(
1623
1640
"""
1624
1641
self .meta .reflect (bind = self .con , only = [table_name ])
1625
1642
table = SQLTable (table_name , self , index = index_col , schema = schema )
1643
+ if chunksize is not None :
1644
+ self .returns_generator = True
1626
1645
return table .read (
1646
+ self .exit_stack ,
1627
1647
coerce_float = coerce_float ,
1628
1648
parse_dates = parse_dates ,
1629
1649
columns = columns ,
@@ -1733,15 +1753,19 @@ def read_query(
1733
1753
columns = result .keys ()
1734
1754
1735
1755
if chunksize is not None :
1736
- return self ._query_iterator (
1737
- result ,
1738
- chunksize ,
1739
- columns ,
1740
- index_col = index_col ,
1741
- coerce_float = coerce_float ,
1742
- parse_dates = parse_dates ,
1743
- dtype = dtype ,
1744
- use_nullable_dtypes = use_nullable_dtypes ,
1756
+ self .returns_generator = True
1757
+ return _cleanup_after_generator (
1758
+ self ._query_iterator (
1759
+ result ,
1760
+ chunksize ,
1761
+ columns ,
1762
+ index_col = index_col ,
1763
+ coerce_float = coerce_float ,
1764
+ parse_dates = parse_dates ,
1765
+ dtype = dtype ,
1766
+ use_nullable_dtypes = use_nullable_dtypes ,
1767
+ ),
1768
+ self .exit_stack ,
1745
1769
)
1746
1770
else :
1747
1771
data = result .fetchall ()
0 commit comments