Skip to content

Commit aee6064

Browse files
laserkaplanhashhar
authored andcommitted
Add ability to specify catalog in SQLAlchemy Table objects
1 parent a0f524b commit aee6064

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

tests/unit/sqlalchemy/test_compiler.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,24 @@
1919
String,
2020
Table,
2121
)
22+
from sqlalchemy.schema import CreateTable
2223

2324
from trino.sqlalchemy.dialect import TrinoDialect
2425

2526
metadata = MetaData()
2627
table = Table(
2728
'table',
2829
metadata,
29-
Column('id', Integer, primary_key=True),
30+
Column('id', Integer),
3031
Column('name', String),
3132
)
33+
table_with_catalog = Table(
34+
'table',
35+
metadata,
36+
Column('id', Integer),
37+
schema='default',
38+
trino_catalog='other'
39+
)
3240

3341

3442
@pytest.fixture
@@ -64,3 +72,20 @@ def test_cte_insert_order(dialect):
6472
'FROM "table")\n'\
6573
' SELECT cte.id, cte.name \n'\
6674
'FROM cte'
75+
76+
77+
def test_catalogs_argument(dialect):
78+
statement = select(table_with_catalog)
79+
query = statement.compile(dialect=dialect)
80+
assert str(query) == 'SELECT default."table".id \nFROM "other".default."table"'
81+
82+
83+
def test_catalogs_create_table(dialect):
84+
statement = CreateTable(table_with_catalog)
85+
query = statement.compile(dialect=dialect)
86+
assert str(query) == \
87+
'\n'\
88+
'CREATE TABLE "other".default."table" (\n'\
89+
'\tid INTEGER\n'\
90+
')\n'\
91+
'\n'

trino/sqlalchemy/compiler.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
from sqlalchemy.sql import compiler
13+
try:
14+
from sqlalchemy.sql.expression import (
15+
Alias,
16+
CTE,
17+
Subquery,
18+
)
19+
except ImportError:
20+
# For SQLAlchemy versions < 1.4, the CTE and Subquery classes did not explicitly exist
21+
from sqlalchemy.sql.expression import Alias
22+
CTE = type(None)
23+
Subquery = type(None)
1324

1425
# https://trino.io/docs/current/language/reserved.html
1526
RESERVED_WORDS = {
@@ -102,6 +113,31 @@ def limit_clause(self, select, **kw):
102113
text += "\nLIMIT " + self.process(select._limit_clause, **kw)
103114
return text
104115

116+
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
117+
fromhints=None, use_schema=True, **kwargs):
118+
sql = super(TrinoSQLCompiler, self).visit_table(
119+
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
120+
)
121+
return self.add_catalog(sql, table)
122+
123+
@staticmethod
124+
def add_catalog(sql, table):
125+
if table is None:
126+
return sql
127+
128+
if isinstance(table, (Alias, CTE, Subquery)):
129+
return sql
130+
131+
if (
132+
'trino' not in table.dialect_options
133+
or 'catalog' not in table.dialect_options['trino']
134+
):
135+
return sql
136+
137+
catalog = table.dialect_options['trino']['catalog']
138+
sql = f'"{catalog}".{sql}'
139+
return sql
140+
105141

106142
class TrinoDDLCompiler(compiler.DDLCompiler):
107143
pass
@@ -173,3 +209,7 @@ def visit_TIME(self, type_, **kw):
173209

174210
class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
175211
reserved_words = RESERVED_WORDS
212+
213+
def format_table(self, table, use_schema=True, name=None):
214+
result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name)
215+
return TrinoSQLCompiler.add_catalog(result, table)

0 commit comments

Comments
 (0)