Skip to content

Commit c73dc7f

Browse files
authored
Make pandas/io/sql.py work with sqlalchemy 2.0 (#48576)
1 parent 7ffc0ad commit c73dc7f

13 files changed

+188
-93
lines changed

ci/deps/actions-310.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies:
4848
- pyxlsb
4949
- s3fs>=2021.08.0
5050
- scipy
51-
- sqlalchemy<1.4.46
51+
- sqlalchemy
5252
- tabulate
5353
- tzdata>=2022a
5454
- xarray

ci/deps/actions-311.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies:
4848
- pyxlsb
4949
- s3fs>=2021.08.0
5050
- scipy
51-
- sqlalchemy<1.4.46
51+
- sqlalchemy
5252
- tabulate
5353
- tzdata>=2022a
5454
- xarray

ci/deps/actions-38-downstream_compat.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies:
4848
- pyxlsb
4949
- s3fs>=2021.08.0
5050
- scipy
51-
- sqlalchemy<1.4.46
51+
- sqlalchemy
5252
- tabulate
5353
- xarray
5454
- xlrd

ci/deps/actions-38.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies:
4848
- pyxlsb
4949
- s3fs>=2021.08.0
5050
- scipy
51-
- sqlalchemy<1.4.46
51+
- sqlalchemy
5252
- tabulate
5353
- xarray
5454
- xlrd

ci/deps/actions-39.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies:
4848
- pyxlsb
4949
- s3fs>=2021.08.0
5050
- scipy
51-
- sqlalchemy<1.4.46
51+
- sqlalchemy
5252
- tabulate
5353
- tzdata>=2022a
5454
- xarray

ci/deps/circle-38-arm64.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ dependencies:
4949
- pyxlsb
5050
- s3fs>=2021.08.0
5151
- scipy
52-
- sqlalchemy<1.4.46
52+
- sqlalchemy
5353
- tabulate
5454
- xarray
5555
- xlrd

doc/source/user_guide/io.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -5868,15 +5868,15 @@ If you have an SQLAlchemy description of your database you can express where con
58685868
sa.Column("Col_3", sa.Boolean),
58695869
)
58705870
5871-
pd.read_sql(sa.select([data_table]).where(data_table.c.Col_3 is True), engine)
5871+
pd.read_sql(sa.select(data_table).where(data_table.c.Col_3 is True), engine)
58725872
58735873
You can combine SQLAlchemy expressions with parameters passed to :func:`read_sql` using :func:`sqlalchemy.bindparam`
58745874

58755875
.. ipython:: python
58765876
58775877
import datetime as dt
58785878
5879-
expr = sa.select([data_table]).where(data_table.c.Date > sa.bindparam("date"))
5879+
expr = sa.select(data_table).where(data_table.c.Date > sa.bindparam("date"))
58805880
pd.read_sql(expr, engine, params={"date": dt.datetime(2010, 10, 18)})
58815881
58825882

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ Other enhancements
294294
- Added :meth:`DatetimeIndex.as_unit` and :meth:`TimedeltaIndex.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`50616`)
295295
- Added :meth:`Series.dt.unit` and :meth:`Series.dt.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`51223`)
296296
- Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`)
297+
- Added support for SQLAlchemy 2.0 (:issue:`40686`)
297298
-
298299

299300
.. ---------------------------------------------------------------------------

environment.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ dependencies:
5151
- pyxlsb
5252
- s3fs>=2021.08.0
5353
- scipy
54-
- sqlalchemy<1.4.46
54+
- sqlalchemy
5555
- tabulate
5656
- tzdata>=2022a
5757
- xarray

pandas/core/generic.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -2711,7 +2711,7 @@ def to_sql(
27112711
library. Legacy support is provided for sqlite3.Connection objects. The user
27122712
is responsible for engine disposal and connection closure for the SQLAlchemy
27132713
connectable. See `here \
2714-
<https://docs.sqlalchemy.org/en/14/core/connections.html>`_.
2714+
<https://docs.sqlalchemy.org/en/20/core/connections.html>`_.
27152715
If passing a sqlalchemy.engine.Connection which is already in a transaction,
27162716
the transaction will not be committed. If passing a sqlite3.Connection,
27172717
it will not be possible to roll back the record insertion.
@@ -2761,7 +2761,7 @@ def to_sql(
27612761
attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not
27622762
reflect the exact number of written rows as stipulated in the
27632763
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
2764-
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__.
2764+
`SQLAlchemy <https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.CursorResult.rowcount>`__.
27652765
27662766
.. versionadded:: 1.4.0
27672767
@@ -2805,7 +2805,9 @@ def to_sql(
28052805
28062806
>>> df.to_sql('users', con=engine)
28072807
3
2808-
>>> engine.execute("SELECT * FROM users").fetchall()
2808+
>>> from sqlalchemy import text
2809+
>>> with engine.connect() as conn:
2810+
... conn.execute(text("SELECT * FROM users")).fetchall()
28092811
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3')]
28102812
28112813
An `sqlalchemy.engine.Connection` can also be passed to `con`:
@@ -2821,7 +2823,8 @@ def to_sql(
28212823
>>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']})
28222824
>>> df2.to_sql('users', con=engine, if_exists='append')
28232825
2
2824-
>>> engine.execute("SELECT * FROM users").fetchall()
2826+
>>> with engine.connect() as conn:
2827+
... conn.execute(text("SELECT * FROM users")).fetchall()
28252828
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3'),
28262829
(0, 'User 4'), (1, 'User 5'), (0, 'User 6'),
28272830
(1, 'User 7')]
@@ -2831,7 +2834,8 @@ def to_sql(
28312834
>>> df2.to_sql('users', con=engine, if_exists='replace',
28322835
... index_label='id')
28332836
2
2834-
>>> engine.execute("SELECT * FROM users").fetchall()
2837+
>>> with engine.connect() as conn:
2838+
... conn.execute(text("SELECT * FROM users")).fetchall()
28352839
[(0, 'User 6'), (1, 'User 7')]
28362840
28372841
Specify the dtype (especially useful for integers with missing values).
@@ -2851,7 +2855,8 @@ def to_sql(
28512855
... dtype={"A": Integer()})
28522856
3
28532857
2854-
>>> engine.execute("SELECT * FROM integers").fetchall()
2858+
>>> with engine.connect() as conn:
2859+
... conn.execute(text("SELECT * FROM integers")).fetchall()
28552860
[(1,), (None,), (2,)]
28562861
""" # noqa:E501
28572862
from pandas.io import sql

pandas/io/sql.py

+45-39
Original file line numberDiff line numberDiff line change
@@ -69,23 +69,16 @@
6969

7070
if TYPE_CHECKING:
7171
from sqlalchemy import Table
72+
from sqlalchemy.sql.expression import (
73+
Select,
74+
TextClause,
75+
)
7276

7377

7478
# -----------------------------------------------------------------------------
7579
# -- Helper functions
7680

7781

78-
def _convert_params(sql, params):
79-
"""Convert SQL and params args to DBAPI2.0 compliant format."""
80-
args = [sql]
81-
if params is not None:
82-
if hasattr(params, "keys"): # test if params is a mapping
83-
args += [params]
84-
else:
85-
args += [list(params)]
86-
return args
87-
88-
8982
def _process_parse_dates_argument(parse_dates):
9083
"""Process parse_dates argument for read_sql functions"""
9184
# handle non-list entries for parse_dates gracefully
@@ -224,8 +217,7 @@ def execute(sql, con, params=None):
224217
if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Engine)):
225218
raise TypeError("pandas.io.sql.execute requires a connection") # GH50185
226219
with pandasSQL_builder(con, need_transaction=True) as pandas_sql:
227-
args = _convert_params(sql, params)
228-
return pandas_sql.execute(*args)
220+
return pandas_sql.execute(sql, params)
229221

230222

231223
# -----------------------------------------------------------------------------
@@ -348,7 +340,7 @@ def read_sql_table(
348340
else using_nullable_dtypes()
349341
)
350342

351-
with pandasSQL_builder(con, schema=schema) as pandas_sql:
343+
with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql:
352344
if not pandas_sql.has_table(table_name):
353345
raise ValueError(f"Table {table_name} not found")
354346

@@ -951,7 +943,8 @@ def sql_schema(self) -> str:
951943
def _execute_create(self) -> None:
952944
# Inserting table into database, add to MetaData object
953945
self.table = self.table.to_metadata(self.pd_sql.meta)
954-
self.table.create(bind=self.pd_sql.con)
946+
with self.pd_sql.run_transaction():
947+
self.table.create(bind=self.pd_sql.con)
955948

956949
def create(self) -> None:
957950
if self.exists():
@@ -1221,7 +1214,7 @@ def _create_table_setup(self):
12211214

12221215
column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type)
12231216

1224-
columns = [
1217+
columns: list[Any] = [
12251218
Column(name, typ, index=is_index)
12261219
for name, typ, is_index in column_names_and_types
12271220
]
@@ -1451,7 +1444,7 @@ def to_sql(
14511444
pass
14521445

14531446
@abstractmethod
1454-
def execute(self, *args, **kwargs):
1447+
def execute(self, sql: str | Select | TextClause, params=None):
14551448
pass
14561449

14571450
@abstractmethod
@@ -1511,7 +1504,7 @@ def insert_records(
15111504

15121505
try:
15131506
return table.insert(chunksize=chunksize, method=method)
1514-
except exc.SQLAlchemyError as err:
1507+
except exc.StatementError as err:
15151508
# GH34431
15161509
# https://stackoverflow.com/a/67358288/6067848
15171510
msg = r"""(\(1054, "Unknown column 'inf(e0)?' in 'field list'"\))(?#
@@ -1579,13 +1572,18 @@ def __init__(
15791572
from sqlalchemy.engine import Engine
15801573
from sqlalchemy.schema import MetaData
15811574

1575+
# self.exit_stack cleans up the Engine and Connection and commits the
1576+
# transaction if any of those objects was created below.
1577+
# Cleanup happens either in self.__exit__ or at the end of the iterator
1578+
# returned by read_sql when chunksize is not None.
15821579
self.exit_stack = ExitStack()
15831580
if isinstance(con, str):
15841581
con = create_engine(con)
1582+
self.exit_stack.callback(con.dispose)
15851583
if isinstance(con, Engine):
15861584
con = self.exit_stack.enter_context(con.connect())
1587-
if need_transaction:
1588-
self.exit_stack.enter_context(con.begin())
1585+
if need_transaction and not con.in_transaction():
1586+
self.exit_stack.enter_context(con.begin())
15891587
self.con = con
15901588
self.meta = MetaData(schema=schema)
15911589
self.returns_generator = False
@@ -1596,11 +1594,18 @@ def __exit__(self, *args) -> None:
15961594

15971595
@contextmanager
15981596
def run_transaction(self):
1599-
yield self.con
1597+
if not self.con.in_transaction():
1598+
with self.con.begin():
1599+
yield self.con
1600+
else:
1601+
yield self.con
16001602

1601-
def execute(self, *args, **kwargs):
1603+
def execute(self, sql: str | Select | TextClause, params=None):
16021604
"""Simple passthrough to SQLAlchemy connectable"""
1603-
return self.con.execute(*args, **kwargs)
1605+
args = [] if params is None else [params]
1606+
if isinstance(sql, str):
1607+
return self.con.exec_driver_sql(sql, *args)
1608+
return self.con.execute(sql, *args)
16041609

16051610
def read_table(
16061611
self,
@@ -1780,9 +1785,7 @@ def read_query(
17801785
read_sql
17811786
17821787
"""
1783-
args = _convert_params(sql, params)
1784-
1785-
result = self.execute(*args)
1788+
result = self.execute(sql, params)
17861789
columns = result.keys()
17871790

17881791
if chunksize is not None:
@@ -1838,13 +1841,14 @@ def prep_table(
18381841
else:
18391842
dtype = cast(dict, dtype)
18401843

1841-
from sqlalchemy.types import (
1842-
TypeEngine,
1843-
to_instance,
1844-
)
1844+
from sqlalchemy.types import TypeEngine
18451845

18461846
for col, my_type in dtype.items():
1847-
if not isinstance(to_instance(my_type), TypeEngine):
1847+
if isinstance(my_type, type) and issubclass(my_type, TypeEngine):
1848+
pass
1849+
elif isinstance(my_type, TypeEngine):
1850+
pass
1851+
else:
18481852
raise ValueError(f"The type of {col} is not a SQLAlchemy type")
18491853

18501854
table = SQLTable(
@@ -2005,7 +2009,8 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None:
20052009
schema = schema or self.meta.schema
20062010
if self.has_table(table_name, schema):
20072011
self.meta.reflect(bind=self.con, only=[table_name], schema=schema)
2008-
self.get_table(table_name, schema).drop(bind=self.con)
2012+
with self.run_transaction():
2013+
self.get_table(table_name, schema).drop(bind=self.con)
20092014
self.meta.clear()
20102015

20112016
def _create_sql_schema(
@@ -2238,21 +2243,24 @@ def run_transaction(self):
22382243
finally:
22392244
cur.close()
22402245

2241-
def execute(self, *args, **kwargs):
2246+
def execute(self, sql: str | Select | TextClause, params=None):
2247+
if not isinstance(sql, str):
2248+
raise TypeError("Query must be a string unless using sqlalchemy.")
2249+
args = [] if params is None else [params]
22422250
cur = self.con.cursor()
22432251
try:
2244-
cur.execute(*args, **kwargs)
2252+
cur.execute(sql, *args)
22452253
return cur
22462254
except Exception as exc:
22472255
try:
22482256
self.con.rollback()
22492257
except Exception as inner_exc: # pragma: no cover
22502258
ex = DatabaseError(
2251-
f"Execution failed on sql: {args[0]}\n{exc}\nunable to rollback"
2259+
f"Execution failed on sql: {sql}\n{exc}\nunable to rollback"
22522260
)
22532261
raise ex from inner_exc
22542262

2255-
ex = DatabaseError(f"Execution failed on sql '{args[0]}': {exc}")
2263+
ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}")
22562264
raise ex from exc
22572265

22582266
@staticmethod
@@ -2305,9 +2313,7 @@ def read_query(
23052313
dtype: DtypeArg | None = None,
23062314
use_nullable_dtypes: bool = False,
23072315
) -> DataFrame | Iterator[DataFrame]:
2308-
2309-
args = _convert_params(sql, params)
2310-
cursor = self.execute(*args)
2316+
cursor = self.execute(sql, params)
23112317
columns = [col_desc[0] for col_desc in cursor.description]
23122318

23132319
if chunksize is not None:

0 commit comments

Comments
 (0)