Skip to content

Commit 52f3c5e

Browse files
author
Jesse Whitehouse
committed
(3/x) SQLAlchemy 1.3 exposes the underlying connection / cursor with
different property names than SQLAlchemy > 1.4. So here I wrapped the logic in a public method. Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent b5c1dd9 commit 52f3c5e

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

src/databricks/sqlalchemy/dialect/__init__.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sqlalchemy
88
from sqlalchemy import types, processors, event
99
from sqlalchemy.engine import default, Engine
10-
from sqlalchemy.exc import DatabaseError
10+
from sqlalchemy.exc import DatabaseError, SQLAlchemyError
1111
from sqlalchemy.engine import reflection
1212

1313
from databricks import sql
@@ -154,9 +154,7 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
154154
"date": DatabricksDate,
155155
}
156156

157-
with self.get_driver_connection(
158-
connection
159-
)._dbapi_connection.dbapi_connection.cursor() as cur:
157+
with self.get_connection_cursor(connection) as cur:
160158
resp = cur.columns(
161159
catalog_name=self.catalog,
162160
schema_name=schema or self.schema,
@@ -245,9 +243,7 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
245243

246244
def get_table_names(self, connection, schema=None, **kwargs):
247245
TABLE_NAME = 1
248-
with self.get_driver_connection(
249-
connection
250-
)._dbapi_connection.dbapi_connection.cursor() as cur:
246+
with self.get_connection_cursor(connection) as cur:
251247
sql_str = "SHOW TABLES FROM {}".format(
252248
".".join([self.catalog, schema or self.schema])
253249
)
@@ -258,9 +254,7 @@ def get_table_names(self, connection, schema=None, **kwargs):
258254

259255
def get_view_names(self, connection, schema=None, **kwargs):
260256
VIEW_NAME = 1
261-
with self.get_driver_connection(
262-
connection
263-
)._dbapi_connection.dbapi_connection.cursor() as cur:
257+
with self.get_connection_cursor(connection) as cur:
264258
sql_str = "SHOW VIEWS FROM {}".format(
265259
".".join([self.catalog, schema or self.schema])
266260
)
@@ -292,6 +286,21 @@ def has_table(self, connection, table_name, schema=None, **kwargs) -> bool:
292286
return False
293287
else:
294288
raise e
289+
290+
291+
def get_connection_cursor(self, connection):
292+
"""Added for backwards compatibility with 1.3.x
293+
"""
294+
if hasattr(connection, "_dbapi_connection"):
295+
return connection._dbapi_connection.dbapi_connection.cursor()
296+
elif hasattr(connection, "raw_connection"):
297+
return connection.raw_connection().cursor()
298+
elif hasattr(connection, "connection"):
299+
return connection.connection.cursor()
300+
301+
raise SQLAlchemyError("Databricks dialect can't obtain a cursor context manager from the dbapi")
302+
303+
295304

296305
@reflection.cache
297306
def get_schema_names(self, connection, **kw):

0 commit comments

Comments
 (0)