From f7731e21978d6ce79d90677d29440309ed8babf8 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 16 Mar 2021 13:30:53 -0700 Subject: [PATCH] Backport PR #40471: compat: sqlalchemy deprecations --- pandas/io/sql.py | 42 ++++++++++++++++++++++++++++++------- pandas/tests/io/test_sql.py | 25 ++++++++++++++++------ 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 02b06b164a2a1..b93f097b93441 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from datetime import date, datetime, time +from distutils.version import LooseVersion from functools import partial import re from typing import Iterator, List, Optional, Union, overload @@ -55,6 +56,16 @@ def _is_sqlalchemy_connectable(con): return False +def _gt14() -> bool: + """ + Check if sqlalchemy.__version__ is at least 1.4.0, when several + deprecations were made. + """ + import sqlalchemy + + return LooseVersion(sqlalchemy.__version__) >= LooseVersion("1.4.0") + + def _convert_params(sql, params): """Convert SQL and params args to DBAPI2.0 compliant format.""" args = [sql] @@ -715,7 +726,10 @@ def sql_schema(self): def _execute_create(self): # Inserting table into database, add to MetaData object - self.table = self.table.tometadata(self.pd_sql.meta) + if _gt14(): + self.table = self.table.to_metadata(self.pd_sql.meta) + else: + self.table = self.table.tometadata(self.pd_sql.meta) self.table.create() def create(self): @@ -1409,9 +1423,17 @@ def to_sql( # Only check when name is not a number and name is not lower case engine = self.connectable.engine with self.connectable.connect() as conn: - table_names = engine.table_names( - schema=schema or self.meta.schema, connection=conn - ) + if _gt14(): + from sqlalchemy import inspect + + insp = inspect(conn) + table_names = insp.get_table_names( + schema=schema or self.meta.schema + ) + else: + table_names = engine.table_names( + schema=schema or self.meta.schema, connection=conn + ) if name not in table_names: msg = ( f"The provided table name '{name}' is not found exactly as " @@ -1426,9 +1448,15 @@ def tables(self): return self.meta.tables def has_table(self, name, schema=None): - return self.connectable.run_callable( - self.connectable.dialect.has_table, name, schema or self.meta.schema - ) + if _gt14(): + import sqlalchemy as sa + + insp = sa.inspect(self.connectable) + return insp.has_table(name, schema or self.meta.schema) + else: + return self.connectable.run_callable( + self.connectable.dialect.has_table, name, schema or self.meta.schema + ) def get_table(self, table_name, schema=None): schema = schema or self.meta.schema diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 16d4bc65094f8..f8a8b662f2652 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -44,10 +44,11 @@ import pandas._testing as tm import pandas.io.sql as sql -from pandas.io.sql import read_sql_query, read_sql_table +from pandas.io.sql import _gt14, read_sql_query, read_sql_table try: import sqlalchemy + from sqlalchemy import inspect from sqlalchemy.ext import declarative from sqlalchemy.orm import session as sa_session import sqlalchemy.schema @@ -1331,7 +1332,11 @@ def test_create_table(self): pandasSQL = sql.SQLDatabase(temp_conn) pandasSQL.to_sql(temp_frame, "temp_frame") - assert temp_conn.has_table("temp_frame") + if _gt14(): + insp = inspect(temp_conn) + assert insp.has_table("temp_frame") + else: + assert temp_conn.has_table("temp_frame") def test_drop_table(self): temp_conn = self.connect() @@ -1343,11 +1348,18 @@ def test_drop_table(self): pandasSQL = sql.SQLDatabase(temp_conn) pandasSQL.to_sql(temp_frame, "temp_frame") - assert temp_conn.has_table("temp_frame") + if _gt14(): + insp = inspect(temp_conn) + assert insp.has_table("temp_frame") + else: + assert temp_conn.has_table("temp_frame") pandasSQL.drop_table("temp_frame") - assert not temp_conn.has_table("temp_frame") + if _gt14(): + assert not insp.has_table("temp_frame") + else: + assert not temp_conn.has_table("temp_frame") def test_roundtrip(self): self._roundtrip() @@ -1689,9 +1701,10 @@ def test_nan_string(self): tm.assert_frame_equal(result, df) def _get_index_columns(self, tbl_name): - from sqlalchemy.engine import reflection + from sqlalchemy import inspect + + insp = inspect(self.conn) - insp = reflection.Inspector.from_engine(self.conn) ixs = insp.get_indexes(tbl_name) ixs = [i["column_names"] for i in ixs] return ixs