Skip to content

Make pandas/io/sql.py work with sqlalchemy 2.0 #48576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits on Feb 9, 2023
2 changes: 1 addition & 1 deletion ci/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ if [[ $(uname) == "Linux" && -z $DISPLAY ]]; then
XVFB="xvfb-run "
fi

PYTEST_CMD="${XVFB}pytest -r fEs -n $PYTEST_WORKERS --dist=loadfile $TEST_ARGS $COVERAGE $PYTEST_TARGET"
PYTEST_CMD="SQLALCHEMY_WARN_20=1 ${XVFB}pytest -r fEs -n $PYTEST_WORKERS --dist=loadfile $TEST_ARGS $COVERAGE $PYTEST_TARGET"

if [[ "$PATTERN" ]]; then
PYTEST_CMD="$PYTEST_CMD -m \"$PATTERN\""
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ Other enhancements
- Improved error message when trying to align :class:`DataFrame` objects (for example, in :func:`DataFrame.compare`) to clarify that "identically labelled" refers to both index and columns (:issue:`50083`)
- Added :meth:`DatetimeIndex.as_unit` and :meth:`TimedeltaIndex.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`50616`)
- Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`)
- Added support for SQLAlchemy 2.0 (:issue:`40686`)
-

.. ---------------------------------------------------------------------------
Expand Down
17 changes: 11 additions & 6 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2709,7 +2709,7 @@ def to_sql(
library. Legacy support is provided for sqlite3.Connection objects. The user
is responsible for engine disposal and connection closure for the SQLAlchemy
connectable. See `here \
<https://docs.sqlalchemy.org/en/14/core/connections.html>`_.
<https://docs.sqlalchemy.org/en/20/core/connections.html>`_.
If passing a sqlalchemy.engine.Connection which is already in a transaction,
the transaction will not be committed. If passing a sqlite3.Connection,
it will not be possible to roll back the record insertion.
Expand Down Expand Up @@ -2759,7 +2759,7 @@ def to_sql(
attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not
reflect the exact number of written rows as stipulated in the
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__.
`SQLAlchemy <https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.CursorResult.rowcount>`__.

.. versionadded:: 1.4.0

Expand Down Expand Up @@ -2803,7 +2803,9 @@ def to_sql(

>>> df.to_sql('users', con=engine)
3
>>> engine.execute("SELECT * FROM users").fetchall()
>>> from sqlalchemy import text
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3')]

An `sqlalchemy.engine.Connection` can also be passed to `con`:
Expand All @@ -2819,7 +2821,8 @@ def to_sql(
>>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']})
>>> df2.to_sql('users', con=engine, if_exists='append')
2
>>> engine.execute("SELECT * FROM users").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3'),
(0, 'User 4'), (1, 'User 5'), (0, 'User 6'),
(1, 'User 7')]
Expand All @@ -2829,7 +2832,8 @@ def to_sql(
>>> df2.to_sql('users', con=engine, if_exists='replace',
... index_label='id')
2
>>> engine.execute("SELECT * FROM users").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 6'), (1, 'User 7')]

Specify the dtype (especially useful for integers with missing values).
Expand All @@ -2849,7 +2853,8 @@ def to_sql(
... dtype={"A": Integer()})
3

>>> engine.execute("SELECT * FROM integers").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM integers")).fetchall()
[(1,), (None,), (2,)]
""" # noqa:E501
from pandas.io import sql
Expand Down
80 changes: 41 additions & 39 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,16 @@

if TYPE_CHECKING:
from sqlalchemy import Table
from sqlalchemy.sql.expression import (
Select,
TextClause,
)


# -----------------------------------------------------------------------------
# -- Helper functions


def _convert_params(sql, params):
"""Convert SQL and params args to DBAPI2.0 compliant format."""
args = [sql]
if params is not None:
if hasattr(params, "keys"): # test if params is a mapping
args += [params]
else:
args += [list(params)]
return args


def _process_parse_dates_argument(parse_dates):
"""Process parse_dates argument for read_sql functions"""
# handle non-list entries for parse_dates gracefully
Expand Down Expand Up @@ -224,8 +217,7 @@ def execute(sql, con, params=None):
if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Engine)):
raise TypeError("pandas.io.sql.execute requires a connection") # GH50185
with pandasSQL_builder(con, need_transaction=True) as pandas_sql:
args = _convert_params(sql, params)
return pandas_sql.execute(*args)
return pandas_sql.execute(sql, params)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -348,7 +340,7 @@ def read_sql_table(
else using_nullable_dtypes()
)

with pandasSQL_builder(con, schema=schema) as pandas_sql:
with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql:
if not pandas_sql.has_table(table_name):
raise ValueError(f"Table {table_name} not found")

Expand Down Expand Up @@ -951,7 +943,8 @@ def sql_schema(self) -> str:
def _execute_create(self) -> None:
# Inserting table into database, add to MetaData object
self.table = self.table.to_metadata(self.pd_sql.meta)
self.table.create(bind=self.pd_sql.con)
with self.pd_sql.run_transaction():
self.table.create(bind=self.pd_sql.con)

def create(self) -> None:
if self.exists():
Expand Down Expand Up @@ -1221,7 +1214,7 @@ def _create_table_setup(self):

column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type)

columns = [
columns: list[Any] = [
Column(name, typ, index=is_index)
for name, typ, is_index in column_names_and_types
]
Expand Down Expand Up @@ -1451,7 +1444,7 @@ def to_sql(
pass

@abstractmethod
def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
pass

@abstractmethod
Expand Down Expand Up @@ -1511,7 +1504,7 @@ def insert_records(

try:
return table.insert(chunksize=chunksize, method=method)
except exc.SQLAlchemyError as err:
except exc.StatementError as err:
# GH34431
# https://stackoverflow.com/a/67358288/6067848
msg = r"""(\(1054, "Unknown column 'inf(e0)?' in 'field list'"\))(?#
Expand Down Expand Up @@ -1582,10 +1575,11 @@ def __init__(
self.exit_stack = ExitStack()
if isinstance(con, str):
con = create_engine(con)
self.exit_stack.callback(con.dispose)
if isinstance(con, Engine):
con = self.exit_stack.enter_context(con.connect())
if need_transaction:
self.exit_stack.enter_context(con.begin())
if need_transaction and not con.in_transaction():
self.exit_stack.enter_context(con.begin())
self.con = con
self.meta = MetaData(schema=schema)
self.returns_generator = False
Expand All @@ -1596,11 +1590,18 @@ def __exit__(self, *args) -> None:

@contextmanager
def run_transaction(self):
yield self.con
if not self.con.in_transaction():
with self.con.begin():
yield self.con
else:
yield self.con

def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
"""Simple passthrough to SQLAlchemy connectable"""
return self.con.execute(*args, **kwargs)
args = [] if params is None else [params]
if isinstance(sql, str):
return self.con.exec_driver_sql(sql, *args)
return self.con.execute(sql, *args)

def read_table(
self,
Expand Down Expand Up @@ -1780,9 +1781,7 @@ def read_query(
read_sql

"""
args = _convert_params(sql, params)

result = self.execute(*args)
result = self.execute(sql, params)
columns = result.keys()

if chunksize is not None:
Expand Down Expand Up @@ -1838,13 +1837,14 @@ def prep_table(
else:
dtype = cast(dict, dtype)

from sqlalchemy.types import (
TypeEngine,
to_instance,
)
from sqlalchemy.types import TypeEngine

for col, my_type in dtype.items():
if not isinstance(to_instance(my_type), TypeEngine):
if isinstance(my_type, type) and issubclass(my_type, TypeEngine):
pass
elif isinstance(my_type, TypeEngine):
pass
else:
raise ValueError(f"The type of {col} is not a SQLAlchemy type")

table = SQLTable(
Expand Down Expand Up @@ -2005,7 +2005,8 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None:
schema = schema or self.meta.schema
if self.has_table(table_name, schema):
self.meta.reflect(bind=self.con, only=[table_name], schema=schema)
self.get_table(table_name, schema).drop(bind=self.con)
with self.run_transaction():
self.get_table(table_name, schema).drop(bind=self.con)
self.meta.clear()

def _create_sql_schema(
Expand Down Expand Up @@ -2238,21 +2239,24 @@ def run_transaction(self):
finally:
cur.close()

def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
if not isinstance(sql, str):
raise TypeError("Query must be a string unless using sqlalchemy.")
args = [] if params is None else [params]
cur = self.con.cursor()
try:
cur.execute(*args, **kwargs)
cur.execute(sql, *args)
return cur
except Exception as exc:
try:
self.con.rollback()
except Exception as inner_exc: # pragma: no cover
ex = DatabaseError(
f"Execution failed on sql: {args[0]}\n{exc}\nunable to rollback"
f"Execution failed on sql: {sql}\n{exc}\nunable to rollback"
)
raise ex from inner_exc

ex = DatabaseError(f"Execution failed on sql '{args[0]}': {exc}")
ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}")
raise ex from exc

@staticmethod
Expand Down Expand Up @@ -2305,9 +2309,7 @@ def read_query(
dtype: DtypeArg | None = None,
use_nullable_dtypes: bool = False,
) -> DataFrame | Iterator[DataFrame]:

args = _convert_params(sql, params)
cursor = self.execute(*args)
cursor = self.execute(sql, params)
columns = [col_desc[0] for col_desc in cursor.description]

if chunksize is not None:
Expand Down
Loading