7
7
import sqlalchemy
8
8
from sqlalchemy import types , processors , event
9
9
from sqlalchemy .engine import default , Engine
10
- from sqlalchemy .exc import DatabaseError
10
+ from sqlalchemy .exc import DatabaseError , SQLAlchemyError
11
11
from sqlalchemy .engine import reflection
12
12
13
13
from databricks import sql
@@ -154,9 +154,7 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
154
154
"date" : DatabricksDate ,
155
155
}
156
156
157
- with self .get_driver_connection (
158
- connection
159
- )._dbapi_connection .dbapi_connection .cursor () as cur :
157
+ with self .get_connection_cursor (connection ) as cur :
160
158
resp = cur .columns (
161
159
catalog_name = self .catalog ,
162
160
schema_name = schema or self .schema ,
@@ -245,9 +243,7 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
245
243
246
244
def get_table_names (self , connection , schema = None , ** kwargs ):
247
245
TABLE_NAME = 1
248
- with self .get_driver_connection (
249
- connection
250
- )._dbapi_connection .dbapi_connection .cursor () as cur :
246
+ with self .get_connection_cursor (connection ) as cur :
251
247
sql_str = "SHOW TABLES FROM {}" .format (
252
248
"." .join ([self .catalog , schema or self .schema ])
253
249
)
@@ -258,9 +254,7 @@ def get_table_names(self, connection, schema=None, **kwargs):
258
254
259
255
def get_view_names (self , connection , schema = None , ** kwargs ):
260
256
VIEW_NAME = 1
261
- with self .get_driver_connection (
262
- connection
263
- )._dbapi_connection .dbapi_connection .cursor () as cur :
257
+ with self .get_connection_cursor (connection ) as cur :
264
258
sql_str = "SHOW VIEWS FROM {}" .format (
265
259
"." .join ([self .catalog , schema or self .schema ])
266
260
)
@@ -292,6 +286,21 @@ def has_table(self, connection, table_name, schema=None, **kwargs) -> bool:
292
286
return False
293
287
else :
294
288
raise e
289
+
290
+
291
+ def get_connection_cursor (self , connection ):
292
+ """Added for backwards compatibility with 1.3.x
293
+ """
294
+ if hasattr (connection , "_dbapi_connection" ):
295
+ return connection ._dbapi_connection .dbapi_connection .cursor ()
296
+ elif hasattr (connection , "raw_connection" ):
297
+ return connection .raw_connection ().cursor ()
298
+ elif hasattr (connection , "connection" ):
299
+ return connection .connection .cursor ()
300
+
301
+ raise SQLAlchemyError ("Databricks dialect can't obtain a cursor context manager from the dbapi" )
302
+
303
+
295
304
296
305
@reflection .cache
297
306
def get_schema_names (self , connection , ** kw ):
0 commit comments