Skip to content

Commit 1849b96

Browse files
john-bodleyebyhr
authored andcommitted
Ensure SQLAlchemy get_table_names only returns tables and not views
Per the [SQLAlchemy documentation](https://docs.sqlalchemy.org/en/14/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_table_names) the get_table_names method is intended to return only table names, i.e., exclude views.
1 parent 5041972 commit 1849b96

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"pytest",
4040
"pytest-runner",
4141
"click",
42+
"sqlalchemy_utils",
4243
]
4344

4445
setup(

tests/integration/test_sqlalchemy_integration.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pytest
1313
import sqlalchemy as sqla
1414
from sqlalchemy.sql import and_, or_, not_
15+
from sqlalchemy_utils import create_view
1516

1617
from tests.unit.conftest import sqlalchemy_version
1718
from trino.sqlalchemy.datatype import JSON
@@ -368,3 +369,43 @@ def test_get_table_comment(trino_connection):
368369
assert actual['text'] is None
369370
finally:
370371
metadata.drop_all(engine)
372+
373+
374+
@pytest.mark.parametrize('trino_connection', ['memory/test'], indirect=True)
375+
@pytest.mark.parametrize('schema', [None, 'test'])
376+
def test_get_table_names(trino_connection, schema):
377+
engine, conn = trino_connection
378+
name = schema or engine.dialect._get_default_schema_name(conn)
379+
metadata = sqla.MetaData(schema=name)
380+
381+
if not engine.dialect.has_schema(conn, name):
382+
engine.execute(sqla.schema.CreateSchema(name))
383+
384+
try:
385+
create_view(
386+
'my_view',
387+
sqla.select(
388+
[
389+
sqla.Table(
390+
'my_table',
391+
metadata,
392+
sqla.Column('id', sqla.Integer),
393+
),
394+
],
395+
),
396+
metadata,
397+
cascade_on_drop=False,
398+
)
399+
400+
metadata.create_all(engine)
401+
assert sqla.inspect(engine).get_table_names(schema) == ['my_table']
402+
finally:
403+
metadata.drop_all(engine)
404+
405+
406+
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
407+
def test_get_table_names_raises(trino_connection):
408+
engine, _ = trino_connection
409+
410+
with pytest.raises(sqla.exc.NoSuchTableError):
411+
sqla.inspect(engine).get_table_names(None)

trino/sqlalchemy/dialect.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def get_table_names(self, connection: Connection, schema: str = None, **kw) -> L
200200
SELECT "table_name"
201201
FROM "information_schema"."tables"
202202
WHERE "table_schema" = :schema
203+
AND "table_type" = 'BASE TABLE'
203204
"""
204205
).strip()
205206
res = connection.execute(sql.text(query), schema=schema)

0 commit comments

Comments
 (0)