Skip to content

Commit 29bf762

Browse files
mathiasritterhashhar
authored andcommitted
SQLAlchemy 2.0 Compatibility
The following changes are for compatibility with SQLAlchemy 2.0, while mainting compatibility with versions >= 1.3. In the Connection.execute function, query parameters may no longer be passed as keyword arguments. Parameters can be passed in a dicionary instead. Engine.execute was removed. Statements can only be executed on the Connection object, which can be obtained via Engine.begin() or Engine.connect(). RowProxy is no longer a “proxy”; is now called Row and behaves like an enhanced named tuple. In order to access rows, just use row.column instead of row["column"]. Raw SQL statements must be wrapped into a TextClause by calling text(...), imported from sqlalchemy.
1 parent a307c98 commit 29bf762

File tree

4 files changed

+46
-37
lines changed

4 files changed

+46
-37
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ jobs:
5858
- { python: "3.11", trino: "351", sqlalchemy: "~=1.4.0" } # first Trino version
5959
# Test with sqlalchemy 1.3
6060
- { python: "3.11", trino: "latest", sqlalchemy: "~=1.3.0" }
61+
# Test with sqlalchemy 2.0
62+
- { python: "3.11", trino: "latest", sqlalchemy: "~=2.0.0rc1" }
6163
env:
6264
TRINO_VERSION: "${{ matrix.trino }}"
6365
steps:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
version = str(ast.literal_eval(trino_version.group(1)))
2727

2828
kerberos_require = ["requests_kerberos"]
29-
sqlalchemy_require = ["sqlalchemy~=1.3"]
29+
sqlalchemy_require = ["sqlalchemy >= 1.3"]
3030
external_authentication_token_cache_require = ["keyring"]
3131

3232
# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.

tests/integration/test_sqlalchemy_integration.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ def test_select_query(trino_connection):
4343
rows = result.fetchall()
4444
assert len(rows) == 25
4545
for row in rows:
46-
assert isinstance(row['nationkey'], int)
47-
assert isinstance(row['name'], str)
48-
assert isinstance(row['regionkey'], int)
49-
assert isinstance(row['comment'], str)
46+
assert isinstance(row.nationkey, int)
47+
assert isinstance(row.name, str)
48+
assert isinstance(row.regionkey, int)
49+
assert isinstance(row.comment, str)
5050

5151

5252
def assert_column(table, column_name, column_type):
@@ -70,8 +70,8 @@ def test_select_specific_columns(trino_connection):
7070
rows = result.fetchall()
7171
assert len(rows) > 0
7272
for row in rows:
73-
assert isinstance(row['node_id'], str)
74-
assert isinstance(row['state'], str)
73+
assert isinstance(row.node_id, str)
74+
assert isinstance(row.state, str)
7575

7676

7777
@pytest.mark.skipif(
@@ -82,7 +82,8 @@ def test_select_specific_columns(trino_connection):
8282
def test_define_and_create_table(trino_connection):
8383
engine, conn = trino_connection
8484
if not engine.dialect.has_schema(conn, "test"):
85-
engine.execute(sqla.schema.CreateSchema("test"))
85+
with engine.begin() as connection:
86+
connection.execute(sqla.schema.CreateSchema("test"))
8687
metadata = sqla.MetaData()
8788
try:
8889
sqla.Table('users',
@@ -110,7 +111,8 @@ def test_insert(trino_connection):
110111
engine, conn = trino_connection
111112

112113
if not engine.dialect.has_schema(conn, "test"):
113-
engine.execute(sqla.schema.CreateSchema("test"))
114+
with engine.begin() as connection:
115+
connection.execute(sqla.schema.CreateSchema("test"))
114116
metadata = sqla.MetaData()
115117
try:
116118
users = sqla.Table('users',
@@ -139,7 +141,8 @@ def test_insert(trino_connection):
139141
def test_insert_multiple_statements(trino_connection):
140142
engine, conn = trino_connection
141143
if not engine.dialect.has_schema(conn, "test"):
142-
engine.execute(sqla.schema.CreateSchema("test"))
144+
with engine.begin() as connection:
145+
connection.execute(sqla.schema.CreateSchema("test"))
143146
metadata = sqla.MetaData()
144147
users = sqla.Table('users',
145148
metadata,
@@ -180,10 +183,10 @@ def test_operators(trino_connection):
180183
rows = result.fetchall()
181184
assert len(rows) == 1
182185
for row in rows:
183-
assert isinstance(row['nationkey'], int)
184-
assert isinstance(row['name'], str)
185-
assert isinstance(row['regionkey'], int)
186-
assert isinstance(row['comment'], str)
186+
assert isinstance(row.nationkey, int)
187+
assert isinstance(row.name, str)
188+
assert isinstance(row.regionkey, int)
189+
assert isinstance(row.comment, str)
187190

188191

189192
@pytest.mark.skipif(
@@ -216,14 +219,14 @@ def test_textual_sql(trino_connection):
216219
rows = result.fetchall()
217220
assert len(rows) == 3
218221
for row in rows:
219-
assert isinstance(row['custkey'], int)
220-
assert isinstance(row['name'], str)
221-
assert isinstance(row['address'], str)
222-
assert isinstance(row['nationkey'], int)
223-
assert isinstance(row['phone'], str)
224-
assert isinstance(row['acctbal'], float)
225-
assert isinstance(row['mktsegment'], str)
226-
assert isinstance(row['comment'], str)
222+
assert isinstance(row.custkey, int)
223+
assert isinstance(row.name, str)
224+
assert isinstance(row.address, str)
225+
assert isinstance(row.nationkey, int)
226+
assert isinstance(row.phone, str)
227+
assert isinstance(row.acctbal, float)
228+
assert isinstance(row.mktsegment, str)
229+
assert isinstance(row.comment, str)
227230

228231

229232
@pytest.mark.skipif(
@@ -323,7 +326,8 @@ def test_json_column(trino_connection, json_object):
323326
engine, conn = trino_connection
324327

325328
if not engine.dialect.has_schema(conn, "test"):
326-
engine.execute(sqla.schema.CreateSchema("test"))
329+
with engine.begin() as connection:
330+
connection.execute(sqla.schema.CreateSchema("test"))
327331
metadata = sqla.MetaData()
328332

329333
try:
@@ -351,7 +355,8 @@ def test_get_table_comment(trino_connection):
351355
engine, conn = trino_connection
352356

353357
if not engine.dialect.has_schema(conn, "test"):
354-
engine.execute(sqla.schema.CreateSchema("test"))
358+
with engine.begin() as connection:
359+
connection.execute(sqla.schema.CreateSchema("test"))
355360
metadata = sqla.MetaData()
356361

357362
try:
@@ -378,7 +383,8 @@ def test_get_table_names(trino_connection, schema):
378383
metadata = sqla.MetaData(schema=schema_name)
379384

380385
if not engine.dialect.has_schema(conn, schema_name):
381-
engine.execute(sqla.schema.CreateSchema(schema_name))
386+
with engine.begin() as connection:
387+
connection.execute(sqla.schema.CreateSchema(schema_name))
382388

383389
try:
384390
sqla.Table(
@@ -388,10 +394,10 @@ def test_get_table_names(trino_connection, schema):
388394
)
389395
metadata.create_all(engine)
390396
view_name = schema_name + ".test_view"
391-
conn.execute(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names")
397+
conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names"))
392398
assert sqla.inspect(engine).get_table_names(schema_name) == ['test_get_table_names']
393399
finally:
394-
conn.execute(f"DROP VIEW IF EXISTS {view_name}")
400+
conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}"))
395401
metadata.drop_all(engine)
396402

397403

@@ -411,7 +417,8 @@ def test_get_view_names(trino_connection, schema):
411417
metadata = sqla.MetaData(schema=schema_name)
412418

413419
if not engine.dialect.has_schema(conn, schema_name):
414-
engine.execute(sqla.schema.CreateSchema(schema_name))
420+
with engine.begin() as connection:
421+
connection.execute(sqla.schema.CreateSchema(schema_name))
415422

416423
try:
417424
sqla.Table(
@@ -421,10 +428,10 @@ def test_get_view_names(trino_connection, schema):
421428
)
422429
metadata.create_all(engine)
423430
view_name = schema_name + ".test_get_view_names"
424-
conn.execute(f"CREATE VIEW {view_name} AS SELECT * FROM test_table")
431+
conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_table"))
425432
assert sqla.inspect(engine).get_view_names(schema_name) == ['test_get_view_names']
426433
finally:
427-
conn.execute(f"DROP VIEW IF EXISTS {view_name}")
434+
conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}"))
428435
metadata.drop_all(engine)
429436

430437

trino/sqlalchemy/dialect.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No
156156
ORDER BY "ordinal_position" ASC
157157
"""
158158
).strip()
159-
res = connection.execute(sql.text(query), schema=schema, table=table_name)
159+
res = connection.execute(sql.text(query), {"schema": schema, "table": table_name})
160160
columns = []
161161
for record in res:
162162
column = dict(
@@ -204,7 +204,7 @@ def get_table_names(self, connection: Connection, schema: str = None, **kw) -> L
204204
AND "table_type" = 'BASE TABLE'
205205
"""
206206
).strip()
207-
res = connection.execute(sql.text(query), schema=schema)
207+
res = connection.execute(sql.text(query), {"schema": schema})
208208
return [row.table_name for row in res]
209209

210210
def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
@@ -225,7 +225,7 @@ def get_view_names(self, connection: Connection, schema: str = None, **kw) -> Li
225225
AND "table_type" = 'VIEW'
226226
"""
227227
).strip()
228-
res = connection.execute(sql.text(query), schema=schema)
228+
res = connection.execute(sql.text(query), {"schema": schema})
229229
return [row.table_name for row in res]
230230

231231
def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
@@ -244,7 +244,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
244244
AND "table_name" = :view
245245
"""
246246
).strip()
247-
res = connection.execute(sql.text(query), schema=schema, view=view_name)
247+
res = connection.execute(sql.text(query), {"schema": schema, "view": view_name})
248248
return res.scalar()
249249

250250
def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
@@ -296,7 +296,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
296296
try:
297297
res = connection.execute(
298298
sql.text(query),
299-
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
299+
{"catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name}
300300
)
301301
return dict(text=res.scalar())
302302
except error.TrinoQueryError as e:
@@ -314,7 +314,7 @@ def has_schema(self, connection: Connection, schema: str) -> bool:
314314
WHERE "schema_name" = :schema
315315
"""
316316
).strip()
317-
res = connection.execute(sql.text(query), schema=schema)
317+
res = connection.execute(sql.text(query), {"schema": schema})
318318
return res.first() is not None
319319

320320
def has_table(self, connection: Connection, table_name: str, schema: str = None, **kw) -> bool:
@@ -329,7 +329,7 @@ def has_table(self, connection: Connection, table_name: str, schema: str = None,
329329
AND "table_name" = :table
330330
"""
331331
).strip()
332-
res = connection.execute(sql.text(query), schema=schema, table=table_name)
332+
res = connection.execute(sql.text(query), {"schema": schema, "table": table_name})
333333
return res.first() is not None
334334

335335
def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool:

0 commit comments

Comments
 (0)