69
69
70
70
if TYPE_CHECKING :
71
71
from sqlalchemy import Table
72
+ from sqlalchemy .sql .expression import (
73
+ Select ,
74
+ TextClause ,
75
+ )
72
76
73
77
74
78
# -----------------------------------------------------------------------------
75
79
# -- Helper functions
76
80
77
81
78
- def _convert_params (sql , params ):
79
- """Convert SQL and params args to DBAPI2.0 compliant format."""
80
- args = [sql ]
81
- if params is not None :
82
- if hasattr (params , "keys" ): # test if params is a mapping
83
- args += [params ]
84
- else :
85
- args += [list (params )]
86
- return args
87
-
88
-
89
82
def _process_parse_dates_argument (parse_dates ):
90
83
"""Process parse_dates argument for read_sql functions"""
91
84
# handle non-list entries for parse_dates gracefully
@@ -224,8 +217,7 @@ def execute(sql, con, params=None):
224
217
if sqlalchemy is not None and isinstance (con , (str , sqlalchemy .engine .Engine )):
225
218
raise TypeError ("pandas.io.sql.execute requires a connection" ) # GH50185
226
219
with pandasSQL_builder (con , need_transaction = True ) as pandas_sql :
227
- args = _convert_params (sql , params )
228
- return pandas_sql .execute (* args )
220
+ return pandas_sql .execute (sql , params )
229
221
230
222
231
223
# -----------------------------------------------------------------------------
@@ -348,7 +340,7 @@ def read_sql_table(
348
340
else using_nullable_dtypes ()
349
341
)
350
342
351
- with pandasSQL_builder (con , schema = schema ) as pandas_sql :
343
+ with pandasSQL_builder (con , schema = schema , need_transaction = True ) as pandas_sql :
352
344
if not pandas_sql .has_table (table_name ):
353
345
raise ValueError (f"Table { table_name } not found" )
354
346
@@ -951,7 +943,8 @@ def sql_schema(self) -> str:
951
943
def _execute_create (self ) -> None :
952
944
# Inserting table into database, add to MetaData object
953
945
self .table = self .table .to_metadata (self .pd_sql .meta )
954
- self .table .create (bind = self .pd_sql .con )
946
+ with self .pd_sql .run_transaction ():
947
+ self .table .create (bind = self .pd_sql .con )
955
948
956
949
def create (self ) -> None :
957
950
if self .exists ():
@@ -1221,7 +1214,7 @@ def _create_table_setup(self):
1221
1214
1222
1215
column_names_and_types = self ._get_column_names_and_types (self ._sqlalchemy_type )
1223
1216
1224
- columns = [
1217
+ columns : list [ Any ] = [
1225
1218
Column (name , typ , index = is_index )
1226
1219
for name , typ , is_index in column_names_and_types
1227
1220
]
@@ -1451,7 +1444,7 @@ def to_sql(
1451
1444
pass
1452
1445
1453
1446
@abstractmethod
1454
- def execute (self , * args , ** kwargs ):
1447
+ def execute (self , sql : str | Select | TextClause , params = None ):
1455
1448
pass
1456
1449
1457
1450
@abstractmethod
@@ -1511,7 +1504,7 @@ def insert_records(
1511
1504
1512
1505
try :
1513
1506
return table .insert (chunksize = chunksize , method = method )
1514
- except exc .SQLAlchemyError as err :
1507
+ except exc .StatementError as err :
1515
1508
# GH34431
1516
1509
# https://stackoverflow.com/a/67358288/6067848
1517
1510
msg = r"""(\(1054, "Unknown column 'inf(e0)?' in 'field list'"\))(?#
@@ -1579,13 +1572,18 @@ def __init__(
1579
1572
from sqlalchemy .engine import Engine
1580
1573
from sqlalchemy .schema import MetaData
1581
1574
1575
+ # self.exit_stack cleans up the Engine and Connection and commits the
1576
+ # transaction if any of those objects was created below.
1577
+ # Cleanup happens either in self.__exit__ or at the end of the iterator
1578
+ # returned by read_sql when chunksize is not None.
1582
1579
self .exit_stack = ExitStack ()
1583
1580
if isinstance (con , str ):
1584
1581
con = create_engine (con )
1582
+ self .exit_stack .callback (con .dispose )
1585
1583
if isinstance (con , Engine ):
1586
1584
con = self .exit_stack .enter_context (con .connect ())
1587
- if need_transaction :
1588
- self .exit_stack .enter_context (con .begin ())
1585
+ if need_transaction and not con . in_transaction () :
1586
+ self .exit_stack .enter_context (con .begin ())
1589
1587
self .con = con
1590
1588
self .meta = MetaData (schema = schema )
1591
1589
self .returns_generator = False
@@ -1596,11 +1594,18 @@ def __exit__(self, *args) -> None:
1596
1594
1597
1595
@contextmanager
1598
1596
def run_transaction (self ):
1599
- yield self .con
1597
+ if not self .con .in_transaction ():
1598
+ with self .con .begin ():
1599
+ yield self .con
1600
+ else :
1601
+ yield self .con
1600
1602
1601
- def execute (self , * args , ** kwargs ):
1603
+ def execute (self , sql : str | Select | TextClause , params = None ):
1602
1604
"""Simple passthrough to SQLAlchemy connectable"""
1603
- return self .con .execute (* args , ** kwargs )
1605
+ args = [] if params is None else [params ]
1606
+ if isinstance (sql , str ):
1607
+ return self .con .exec_driver_sql (sql , * args )
1608
+ return self .con .execute (sql , * args )
1604
1609
1605
1610
def read_table (
1606
1611
self ,
@@ -1780,9 +1785,7 @@ def read_query(
1780
1785
read_sql
1781
1786
1782
1787
"""
1783
- args = _convert_params (sql , params )
1784
-
1785
- result = self .execute (* args )
1788
+ result = self .execute (sql , params )
1786
1789
columns = result .keys ()
1787
1790
1788
1791
if chunksize is not None :
@@ -1838,13 +1841,14 @@ def prep_table(
1838
1841
else :
1839
1842
dtype = cast (dict , dtype )
1840
1843
1841
- from sqlalchemy .types import (
1842
- TypeEngine ,
1843
- to_instance ,
1844
- )
1844
+ from sqlalchemy .types import TypeEngine
1845
1845
1846
1846
for col , my_type in dtype .items ():
1847
- if not isinstance (to_instance (my_type ), TypeEngine ):
1847
+ if isinstance (my_type , type ) and issubclass (my_type , TypeEngine ):
1848
+ pass
1849
+ elif isinstance (my_type , TypeEngine ):
1850
+ pass
1851
+ else :
1848
1852
raise ValueError (f"The type of { col } is not a SQLAlchemy type" )
1849
1853
1850
1854
table = SQLTable (
@@ -2005,7 +2009,8 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None:
2005
2009
schema = schema or self .meta .schema
2006
2010
if self .has_table (table_name , schema ):
2007
2011
self .meta .reflect (bind = self .con , only = [table_name ], schema = schema )
2008
- self .get_table (table_name , schema ).drop (bind = self .con )
2012
+ with self .run_transaction ():
2013
+ self .get_table (table_name , schema ).drop (bind = self .con )
2009
2014
self .meta .clear ()
2010
2015
2011
2016
def _create_sql_schema (
@@ -2238,21 +2243,24 @@ def run_transaction(self):
2238
2243
finally :
2239
2244
cur .close ()
2240
2245
2241
- def execute (self , * args , ** kwargs ):
2246
+ def execute (self , sql : str | Select | TextClause , params = None ):
2247
+ if not isinstance (sql , str ):
2248
+ raise TypeError ("Query must be a string unless using sqlalchemy." )
2249
+ args = [] if params is None else [params ]
2242
2250
cur = self .con .cursor ()
2243
2251
try :
2244
- cur .execute (* args , ** kwargs )
2252
+ cur .execute (sql , * args )
2245
2253
return cur
2246
2254
except Exception as exc :
2247
2255
try :
2248
2256
self .con .rollback ()
2249
2257
except Exception as inner_exc : # pragma: no cover
2250
2258
ex = DatabaseError (
2251
- f"Execution failed on sql: { args [ 0 ] } \n { exc } \n unable to rollback"
2259
+ f"Execution failed on sql: { sql } \n { exc } \n unable to rollback"
2252
2260
)
2253
2261
raise ex from inner_exc
2254
2262
2255
- ex = DatabaseError (f"Execution failed on sql '{ args [ 0 ] } ': { exc } " )
2263
+ ex = DatabaseError (f"Execution failed on sql '{ sql } ': { exc } " )
2256
2264
raise ex from exc
2257
2265
2258
2266
@staticmethod
@@ -2305,9 +2313,7 @@ def read_query(
2305
2313
dtype : DtypeArg | None = None ,
2306
2314
use_nullable_dtypes : bool = False ,
2307
2315
) -> DataFrame | Iterator [DataFrame ]:
2308
-
2309
- args = _convert_params (sql , params )
2310
- cursor = self .execute (* args )
2316
+ cursor = self .execute (sql , params )
2311
2317
columns = [col_desc [0 ] for col_desc in cursor .description ]
2312
2318
2313
2319
if chunksize is not None :
0 commit comments