Skip to content

Commit f64903c

Browse files
author
Jesse
committed
Improve sqlalchemy backward compatibility with 1.3.24 (databricks#173)
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent 39d9469 commit f64903c

File tree

4 files changed

+166
-38
lines changed

4 files changed

+166
-38
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## 2.7.x (Unreleased)
44

55
- Add support for Cloud Fetch
6+
- Fix: Revised SQLAlchemy dialect and examples for compatibility with SQLAlchemy==1.3.x
67

78
## 2.7.0 (2023-06-26)
89

examples/sqlalchemy.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,15 @@
4242
"""
4343

4444
import os
45-
from sqlalchemy.orm import declarative_base, Session
45+
import sqlalchemy
46+
from sqlalchemy.orm import Session
4647
from sqlalchemy import Column, String, Integer, BOOLEAN, create_engine, select
4748

49+
try:
50+
from sqlalchemy.orm import declarative_base
51+
except ImportError:
52+
from sqlalchemy.ext.declarative import declarative_base
53+
4854
host = os.getenv("DATABRICKS_SERVER_HOSTNAME")
4955
http_path = os.getenv("DATABRICKS_HTTP_PATH")
5056
access_token = os.getenv("DATABRICKS_TOKEN")
@@ -59,10 +65,20 @@
5965
"_user_agent_entry": "PySQL Example Script",
6066
}
6167

62-
engine = create_engine(
63-
f"databricks://token:{access_token}@{host}?http_path={http_path}&catalog={catalog}&schema={schema}",
64-
connect_args=extra_connect_args,
65-
)
68+
if sqlalchemy.__version__.startswith("1.3"):
69+
# SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string
70+
# Pass these in as connect_args instead
71+
72+
conn_string = f"databricks://token:{access_token}@{host}"
73+
connect_args = dict(catalog=catalog, schema=schema, http_path=http_path)
74+
all_connect_args = {**extra_connect_args, **connect_args}
75+
engine = create_engine(conn_string, connect_args=all_connect_args)
76+
else:
77+
engine = create_engine(
78+
f"databricks://token:{access_token}@{host}?http_path={http_path}&catalog={catalog}&schema={schema}",
79+
connect_args=extra_connect_args,
80+
)
81+
6682
session = Session(bind=engine)
6783
base = declarative_base(bind=engine)
6884

@@ -73,7 +89,7 @@ class SampleObject(base):
7389

7490
name = Column(String(255), primary_key=True)
7591
episodes = Column(Integer)
76-
some_bool = Column(BOOLEAN)
92+
some_bool = Column(BOOLEAN(create_constraint=False))
7793

7894

7995
base.metadata.create_all()
@@ -86,9 +102,14 @@ class SampleObject(base):
86102

87103
session.commit()
88104

89-
stmt = select(SampleObject).where(SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"]))
105+
# SQLAlchemy 1.3 has slightly different methods
106+
if sqlalchemy.__version__.startswith("1.3"):
107+
stmt = select([SampleObject]).where(SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"]))
108+
output = [i for i in session.execute(stmt)]
109+
else:
110+
stmt = select(SampleObject).where(SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"]))
111+
output = [i for i in session.scalars(stmt)]
90112

91-
output = [i for i in session.scalars(stmt)]
92113
assert len(output) == 2
93114

94115
base.metadata.drop_all()

src/databricks/sqlalchemy/dialect/__init__.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import decimal, re, datetime
55
from dateutil.parser import parse
66

7+
import sqlalchemy
78
from sqlalchemy import types, processors, event
89
from sqlalchemy.engine import default, Engine
9-
from sqlalchemy.exc import DatabaseError
10+
from sqlalchemy.exc import DatabaseError, SQLAlchemyError
1011
from sqlalchemy.engine import reflection
1112

1213
from databricks import sql
@@ -153,9 +154,7 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
153154
"date": DatabricksDate,
154155
}
155156

156-
with self.get_driver_connection(
157-
connection
158-
)._dbapi_connection.dbapi_connection.cursor() as cur:
157+
with self.get_connection_cursor(connection) as cur:
159158
resp = cur.columns(
160159
catalog_name=self.catalog,
161160
schema_name=schema or self.schema,
@@ -244,9 +243,7 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
244243

245244
def get_table_names(self, connection, schema=None, **kwargs):
246245
TABLE_NAME = 1
247-
with self.get_driver_connection(
248-
connection
249-
)._dbapi_connection.dbapi_connection.cursor() as cur:
246+
with self.get_connection_cursor(connection) as cur:
250247
sql_str = "SHOW TABLES FROM {}".format(
251248
".".join([self.catalog, schema or self.schema])
252249
)
@@ -257,9 +254,7 @@ def get_table_names(self, connection, schema=None, **kwargs):
257254

258255
def get_view_names(self, connection, schema=None, **kwargs):
259256
VIEW_NAME = 1
260-
with self.get_driver_connection(
261-
connection
262-
)._dbapi_connection.dbapi_connection.cursor() as cur:
257+
with self.get_connection_cursor(connection) as cur:
263258
sql_str = "SHOW VIEWS FROM {}".format(
264259
".".join([self.catalog, schema or self.schema])
265260
)
@@ -292,6 +287,19 @@ def has_table(self, connection, table_name, schema=None, **kwargs) -> bool:
292287
else:
293288
raise e
294289

290+
def get_connection_cursor(self, connection):
291+
"""Added for backwards compatibility with 1.3.x"""
292+
if hasattr(connection, "_dbapi_connection"):
293+
return connection._dbapi_connection.dbapi_connection.cursor()
294+
elif hasattr(connection, "raw_connection"):
295+
return connection.raw_connection().cursor()
296+
elif hasattr(connection, "connection"):
297+
return connection.connection.cursor()
298+
299+
raise SQLAlchemyError(
300+
"Databricks dialect can't obtain a cursor context manager from the dbapi"
301+
)
302+
295303
@reflection.cache
296304
def get_schema_names(self, connection, **kw):
297305
# Equivalent to SHOW DATABASES
@@ -314,3 +322,13 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams):
314322
new_user_agent = "sqlalchemy"
315323

316324
cparams["_user_agent_entry"] = new_user_agent
325+
326+
if sqlalchemy.__version__.startswith("1.3"):
327+
# SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string
328+
# These should be passed in as connect_args when building the Engine
329+
330+
if "schema" in cparams:
331+
dialect.schema = cparams["schema"]
332+
333+
if "catalog" in cparams:
334+
dialect.catalog = cparams["catalog"]

tests/e2e/sqlalchemy/test_basic.py

+108-20
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,81 @@
22
import pytest
33
from unittest import skipIf
44
from sqlalchemy import create_engine, select, insert, Column, MetaData, Table
5-
from sqlalchemy.orm import declarative_base, Session
5+
from sqlalchemy.orm import Session
66
from sqlalchemy.types import SMALLINT, Integer, BOOLEAN, String, DECIMAL, Date
7+
from sqlalchemy.engine import Engine
8+
9+
from typing import Tuple
10+
11+
try:
12+
from sqlalchemy.orm import declarative_base
13+
except ImportError:
14+
from sqlalchemy.ext.declarative import declarative_base
715

816

917
USER_AGENT_TOKEN = "PySQL e2e Tests"
1018

1119

12-
@pytest.fixture
13-
def db_engine():
20+
def sqlalchemy_1_3():
21+
import sqlalchemy
22+
23+
return sqlalchemy.__version__.startswith("1.3")
24+
25+
26+
def version_agnostic_select(object_to_select, *args, **kwargs):
27+
"""
28+
SQLAlchemy==1.3.x requires arguments to select() to be a Python list
29+
30+
https://docs.sqlalchemy.org/en/20/changelog/migration_14.html#orm-query-is-internally-unified-with-select-update-delete-2-0-style-execution-available
31+
"""
32+
33+
if sqlalchemy_1_3():
34+
return select([object_to_select], *args, **kwargs)
35+
else:
36+
return select(object_to_select, *args, **kwargs)
37+
38+
39+
def version_agnostic_connect_arguments(catalog=None, schema=None) -> Tuple[str, dict]:
1440

1541
HOST = os.environ.get("host")
1642
HTTP_PATH = os.environ.get("http_path")
1743
ACCESS_TOKEN = os.environ.get("access_token")
18-
CATALOG = os.environ.get("catalog")
19-
SCHEMA = os.environ.get("schema")
44+
CATALOG = catalog or os.environ.get("catalog")
45+
SCHEMA = schema or os.environ.get("schema")
46+
47+
ua_connect_args = {"_user_agent_entry": USER_AGENT_TOKEN}
48+
49+
if sqlalchemy_1_3():
50+
conn_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}"
51+
connect_args = {
52+
**ua_connect_args,
53+
"http_path": HTTP_PATH,
54+
"server_hostname": HOST,
55+
"catalog": CATALOG,
56+
"schema": SCHEMA,
57+
}
58+
59+
return conn_string, connect_args
60+
else:
61+
return (
62+
f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}",
63+
ua_connect_args,
64+
)
65+
66+
67+
@pytest.fixture
68+
def db_engine() -> Engine:
69+
conn_string, connect_args = version_agnostic_connect_arguments()
70+
return create_engine(conn_string, connect_args=connect_args)
2071

21-
connect_args = {"_user_agent_entry": USER_AGENT_TOKEN}
2272

23-
engine = create_engine(
24-
f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}",
25-
connect_args=connect_args,
73+
@pytest.fixture
74+
def samples_engine() -> Engine:
75+
76+
conn_string, connect_args = version_agnostic_connect_arguments(
77+
catalog="samples", schema="nyctaxi"
2678
)
27-
return engine
79+
return create_engine(conn_string, connect_args=connect_args)
2880

2981

3082
@pytest.fixture()
@@ -62,6 +114,7 @@ def test_connect_args(db_engine):
62114
assert expected in user_agent
63115

64116

117+
@pytest.mark.skipif(sqlalchemy_1_3(), reason="Pandas requires SQLAlchemy >= 1.4")
65118
def test_pandas_upload(db_engine, metadata_obj):
66119

67120
import pandas as pd
@@ -86,7 +139,7 @@ def test_pandas_upload(db_engine, metadata_obj):
86139
db_engine.execute("DROP TABLE mock_data")
87140

88141

89-
def test_create_table_not_null(db_engine, metadata_obj):
142+
def test_create_table_not_null(db_engine, metadata_obj: MetaData):
90143

91144
table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s"))
92145

@@ -95,7 +148,7 @@ def test_create_table_not_null(db_engine, metadata_obj):
95148
metadata_obj,
96149
Column("name", String(255)),
97150
Column("episodes", Integer),
98-
Column("some_bool", BOOLEAN, nullable=False),
151+
Column("some_bool", BOOLEAN(create_constraint=False), nullable=False),
99152
)
100153

101154
metadata_obj.create_all()
@@ -135,7 +188,7 @@ def test_bulk_insert_with_core(db_engine, metadata_obj, session):
135188
metadata_obj.create_all()
136189
db_engine.execute(insert(SampleTable).values(rows))
137190

138-
rows = db_engine.execute(select(SampleTable)).fetchall()
191+
rows = db_engine.execute(version_agnostic_select(SampleTable)).fetchall()
139192

140193
assert len(rows) == num_to_insert
141194

@@ -148,7 +201,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
148201
metadata_obj,
149202
Column("name", String(255)),
150203
Column("episodes", Integer),
151-
Column("some_bool", BOOLEAN),
204+
Column("some_bool", BOOLEAN(create_constraint=False)),
152205
Column("dollars", DECIMAL(10, 2)),
153206
)
154207

@@ -161,7 +214,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
161214
with db_engine.connect() as conn:
162215
conn.execute(insert_stmt)
163216

164-
select_stmt = select(SampleTable)
217+
select_stmt = version_agnostic_select(SampleTable)
165218
resp = db_engine.execute(select_stmt)
166219

167220
result = resp.fetchall()
@@ -187,7 +240,7 @@ class SampleObject(base):
187240

188241
name = Column(String(255), primary_key=True)
189242
episodes = Column(Integer)
190-
some_bool = Column(BOOLEAN)
243+
some_bool = Column(BOOLEAN(create_constraint=False))
191244

192245
base.metadata.create_all()
193246

@@ -197,11 +250,15 @@ class SampleObject(base):
197250
session.add(sample_object_2)
198251
session.commit()
199252

200-
stmt = select(SampleObject).where(
253+
stmt = version_agnostic_select(SampleObject).where(
201254
SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"])
202255
)
203256

204-
output = [i for i in session.scalars(stmt)]
257+
if sqlalchemy_1_3():
258+
output = [i for i in session.execute(stmt)]
259+
else:
260+
output = [i for i in session.scalars(stmt)]
261+
205262
assert len(output) == 2
206263

207264
base.metadata.drop_all()
@@ -215,7 +272,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
215272
metadata_obj,
216273
Column("string_example", String(255)),
217274
Column("integer_example", Integer),
218-
Column("boolean_example", BOOLEAN),
275+
Column("boolean_example", BOOLEAN(create_constraint=False)),
219276
Column("decimal_example", DECIMAL(10, 2)),
220277
Column("date_example", Date),
221278
)
@@ -239,7 +296,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
239296
with db_engine.connect() as conn:
240297
conn.execute(insert_stmt)
241298

242-
select_stmt = select(SampleTable)
299+
select_stmt = version_agnostic_select(SampleTable)
243300
resp = db_engine.execute(select_stmt)
244301

245302
result = resp.fetchall()
@@ -252,3 +309,34 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
252309
assert this_row["date_example"] == date_example
253310

254311
metadata_obj.drop_all()
312+
313+
314+
def test_inspector_smoke_test(samples_engine: Engine):
315+
"""It does not appear that 3L namespace is supported here"""
316+
317+
from sqlalchemy.engine.reflection import Inspector
318+
319+
schema, table = "nyctaxi", "trips"
320+
321+
try:
322+
inspector = Inspector.from_engine(samples_engine)
323+
except Exception as e:
324+
assert False, f"Could not build inspector: {e}"
325+
326+
# Expect six columns
327+
columns = inspector.get_columns(table, schema=schema)
328+
329+
# Expect zero views, but the method should return
330+
views = inspector.get_view_names(schema=schema)
331+
332+
assert (
333+
len(columns) == 6
334+
), "Dialect did not find the expected number of columns in samples.nyctaxi.trips"
335+
assert len(views) == 0, "Views could not be fetched"
336+
337+
338+
def test_get_table_names_smoke_test(samples_engine: Engine):
339+
340+
with samples_engine.connect() as conn:
341+
_names = samples_engine.table_names(schema="nyctaxi", connection=conn)
342+
_names is not None, "get_table_names did not succeed"

0 commit comments

Comments
 (0)