9
9
ABC ,
10
10
abstractmethod ,
11
11
)
12
- from contextlib import contextmanager
12
+ from contextlib import (
13
+ ExitStack ,
14
+ contextmanager ,
15
+ )
13
16
from datetime import (
14
17
date ,
15
18
datetime ,
69
72
# -- Helper functions
70
73
71
74
75
+ def _cleanup_after_generator (generator , exit_stack : ExitStack ):
76
+ """Does the cleanup after iterating through the generator."""
77
+ try :
78
+ yield from generator
79
+ finally :
80
+ exit_stack .close ()
81
+
82
+
72
83
def _convert_params (sql , params ):
73
84
"""Convert SQL and params args to DBAPI2.0 compliant format."""
74
85
args = [sql ]
@@ -792,12 +803,11 @@ def has_table(table_name: str, con, schema: str | None = None) -> bool:
792
803
table_exists = has_table
793
804
794
805
795
- @contextmanager
796
806
def pandasSQL_builder (
797
807
con ,
798
808
schema : str | None = None ,
799
809
need_transaction : bool = False ,
800
- ) -> Iterator [ PandasSQL ] :
810
+ ) -> PandasSQL :
801
811
"""
802
812
Convenience function to return the correct PandasSQL subclass based on the
803
813
provided parameters. Also creates a sqlalchemy connection and transaction
@@ -806,45 +816,24 @@ def pandasSQL_builder(
806
816
import sqlite3
807
817
808
818
if isinstance (con , sqlite3 .Connection ) or con is None :
809
- yield SQLiteDatabase (con )
810
- else :
811
- sqlalchemy = import_optional_dependency ("sqlalchemy" , errors = "ignore" )
819
+ return SQLiteDatabase (con )
812
820
813
- if sqlalchemy is not None and isinstance (
814
- con , (str , sqlalchemy .engine .Connectable )
815
- ):
816
- with _sqlalchemy_con (con , need_transaction ) as con :
817
- yield SQLDatabase (con , schema = schema )
818
- elif isinstance (con , str ) and sqlalchemy is None :
819
- raise ImportError ("Using URI string without sqlalchemy installed." )
820
- else :
821
-
822
- warnings .warn (
823
- "pandas only supports SQLAlchemy connectable (engine/connection) or "
824
- "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
825
- "objects are not tested. Please consider using SQLAlchemy." ,
826
- UserWarning ,
827
- stacklevel = find_stack_level () + 2 ,
828
- )
829
- yield SQLiteDatabase (con )
821
+ sqlalchemy = import_optional_dependency ("sqlalchemy" , errors = "ignore" )
830
822
823
+ if isinstance (con , str ) and sqlalchemy is None :
824
+ raise ImportError ("Using URI string without sqlalchemy installed." )
831
825
832
- @contextmanager
833
- def _sqlalchemy_con (connectable , need_transaction : bool ):
834
- """Create a sqlalchemy connection and a transaction if necessary."""
835
- sqlalchemy = import_optional_dependency ("sqlalchemy" , errors = "raise" )
826
+ if sqlalchemy is not None and isinstance (con , (str , sqlalchemy .engine .Connectable )):
827
+ return SQLDatabase (con , schema , need_transaction )
836
828
837
- if isinstance (connectable , str ):
838
- connectable = sqlalchemy .create_engine (connectable )
839
- if isinstance (connectable , sqlalchemy .engine .Engine ):
840
- with connectable .connect () as con :
841
- if need_transaction :
842
- with con .begin ():
843
- yield con
844
- else :
845
- yield con
846
- else :
847
- yield connectable
829
+ warnings .warn (
830
+ "pandas only supports SQLAlchemy connectable (engine/connection) or "
831
+ "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 "
832
+ "objects are not tested. Please consider using SQLAlchemy." ,
833
+ UserWarning ,
834
+ stacklevel = find_stack_level (),
835
+ )
836
+ return SQLiteDatabase (con )
848
837
849
838
850
839
class SQLTable (PandasObject ):
@@ -1069,6 +1058,7 @@ def _query_iterator(
1069
1058
1070
1059
def read (
1071
1060
self ,
1061
+ exit_stack : ExitStack ,
1072
1062
coerce_float : bool = True ,
1073
1063
parse_dates = None ,
1074
1064
columns = None ,
@@ -1089,13 +1079,16 @@ def read(
1089
1079
column_names = result .keys ()
1090
1080
1091
1081
if chunksize is not None :
1092
- return self ._query_iterator (
1093
- result ,
1094
- chunksize ,
1095
- column_names ,
1096
- coerce_float = coerce_float ,
1097
- parse_dates = parse_dates ,
1098
- use_nullable_dtypes = use_nullable_dtypes ,
1082
+ return _cleanup_after_generator (
1083
+ self ._query_iterator (
1084
+ result ,
1085
+ chunksize ,
1086
+ column_names ,
1087
+ coerce_float = coerce_float ,
1088
+ parse_dates = parse_dates ,
1089
+ use_nullable_dtypes = use_nullable_dtypes ,
1090
+ ),
1091
+ exit_stack ,
1099
1092
)
1100
1093
else :
1101
1094
data = result .fetchall ()
@@ -1347,6 +1340,12 @@ class PandasSQL(PandasObject, ABC):
1347
1340
Subclasses Should define read_query and to_sql.
1348
1341
"""
1349
1342
1343
+ def __enter__ (self ):
1344
+ return self
1345
+
1346
+ def __exit__ (self , * args ) -> None :
1347
+ pass
1348
+
1350
1349
def read_table (
1351
1350
self ,
1352
1351
table_name : str ,
@@ -1502,20 +1501,38 @@ class SQLDatabase(PandasSQL):
1502
1501
1503
1502
Parameters
1504
1503
----------
1505
- con : SQLAlchemy Connection
1506
- Connection to connect with the database. Using SQLAlchemy makes it
1504
+ con : SQLAlchemy Connectable or URI string.
1505
+ Connectable to connect with the database. Using SQLAlchemy makes it
1507
1506
possible to use any DB supported by that library.
1508
1507
schema : string, default None
1509
1508
Name of SQL schema in database to write to (if database flavor
1510
1509
supports this). If None, use default schema (default).
1510
+ need_transaction : bool, default False
1511
+ If True, SQLDatabase will create a transaction.
1511
1512
1512
1513
"""
1513
1514
1514
- def __init__ (self , con , schema : str | None = None ) -> None :
1515
+ def __init__ (
1516
+ self , con , schema : str | None = None , need_transaction : bool = False
1517
+ ) -> None :
1518
+ from sqlalchemy import create_engine
1519
+ from sqlalchemy .engine import Engine
1515
1520
from sqlalchemy .schema import MetaData
1516
1521
1522
+ self .exit_stack = ExitStack ()
1523
+ if isinstance (con , str ):
1524
+ con = create_engine (con )
1525
+ if isinstance (con , Engine ):
1526
+ con = self .exit_stack .enter_context (con .connect ())
1527
+ if need_transaction :
1528
+ self .exit_stack .enter_context (con .begin ())
1517
1529
self .con = con
1518
1530
self .meta = MetaData (schema = schema )
1531
+ self .returns_generator = False
1532
+
1533
+ def __exit__ (self , * args ) -> None :
1534
+ if not self .returns_generator :
1535
+ self .exit_stack .close ()
1519
1536
1520
1537
@contextmanager
1521
1538
def run_transaction (self ):
@@ -1586,7 +1603,10 @@ def read_table(
1586
1603
"""
1587
1604
self .meta .reflect (bind = self .con , only = [table_name ])
1588
1605
table = SQLTable (table_name , self , index = index_col , schema = schema )
1606
+ if chunksize is not None :
1607
+ self .returns_generator = True
1589
1608
return table .read (
1609
+ self .exit_stack ,
1590
1610
coerce_float = coerce_float ,
1591
1611
parse_dates = parse_dates ,
1592
1612
columns = columns ,
@@ -1696,15 +1716,19 @@ def read_query(
1696
1716
columns = result .keys ()
1697
1717
1698
1718
if chunksize is not None :
1699
- return self ._query_iterator (
1700
- result ,
1701
- chunksize ,
1702
- columns ,
1703
- index_col = index_col ,
1704
- coerce_float = coerce_float ,
1705
- parse_dates = parse_dates ,
1706
- dtype = dtype ,
1707
- use_nullable_dtypes = use_nullable_dtypes ,
1719
+ self .returns_generator = True
1720
+ return _cleanup_after_generator (
1721
+ self ._query_iterator (
1722
+ result ,
1723
+ chunksize ,
1724
+ columns ,
1725
+ index_col = index_col ,
1726
+ coerce_float = coerce_float ,
1727
+ parse_dates = parse_dates ,
1728
+ dtype = dtype ,
1729
+ use_nullable_dtypes = use_nullable_dtypes ,
1730
+ ),
1731
+ self .exit_stack ,
1708
1732
)
1709
1733
else :
1710
1734
data = result .fetchall ()
0 commit comments