Skip to content

Commit 9c8cbed

Browse files
committed
Add support for table comments
Signed-off-by: Christophe Bornet <[email protected]>
1 parent 62eb1d4 commit 9c8cbed

File tree

5 files changed

+138
-48
lines changed

5 files changed

+138
-48
lines changed

src/databricks/sqlalchemy/_ddl.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from sqlalchemy.sql import compiler
2+
from sqlalchemy.sql import compiler, sqltypes
33
import logging
44

55
logger = logging.getLogger(__name__)
@@ -16,7 +16,13 @@ def __init__(self, dialect):
1616

1717
class DatabricksDDLCompiler(compiler.DDLCompiler):
1818
def post_create_table(self, table):
    """Return the text appended to CREATE TABLE after the column list.

    Every table is created ``USING DELTA``. When ``table.comment`` is
    truthy, it is rendered via the SQL compiler's ``render_literal_value``
    (as a ``String``) and appended as a ``COMMENT`` clause.
    """
    clauses = [" USING DELTA"]
    if table.comment:
        rendered_comment = self.sql_compiler.render_literal_value(
            table.comment, sqltypes.String()
        )
        clauses.append(" COMMENT " + rendered_comment)
    return "".join(clauses)
2026

2127
def visit_unique_constraint(self, constraint, **kw):
2228
logger.warning("Databricks does not support unique constraints")
@@ -39,13 +45,26 @@ def visit_identity_column(self, identity, **kw):
3945
)
4046
return text
4147

48+
def visit_set_table_comment(self, create, **kw):
    """Compile the statement that sets the comment of ``create.element``.

    The table name is formatted by the dialect's identifier preparer and
    the comment is rendered via ``render_literal_value`` as a ``String``.
    """
    table = create.element
    table_name = self.preparer.format_table(table)
    comment_literal = self.sql_compiler.render_literal_value(
        table.comment, sqltypes.String()
    )
    return "ALTER TABLE %s COMMENT %s" % (table_name, comment_literal)
55+
56+
def visit_drop_table_comment(self, create, **kw):
    """Compile the statement that clears the comment of ``create.element``.

    The emitted statement sets the comment to an empty string literal.
    """
    table_name = self.preparer.format_table(create.element)
    return "ALTER TABLE %s COMMENT ''" % table_name
60+
4261
def get_column_specification(self, column, **kwargs):
4362
"""Currently we override this method only to emit a log message if a user attempts to set
4463
autoincrement=True on a column. See comments in test_suite.py. We may implement implicit
4564
IDENTITY using this feature in the future, similar to the Microsoft SQL Server dialect.
4665
"""
4766
if column is column.table._autoincrement_column or column.autoincrement is True:
48-
logger.warn(
67+
logger.warning(
4968
"Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead."
5069
)
5170

src/databricks/sqlalchemy/_information_schema.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from sqlalchemy import MetaData, Table, Column, String, DateTime
2+
3+
metadata = MetaData()

# (column name, SQLAlchemy type) pairs, in declaration order, for the
# ``information_schema.tables`` columns modelled by this module.
_TABLES_COLUMN_SPECS = (
    ("table_catalog", String),
    ("table_schema", String),
    ("table_name", String),
    ("table_type", String),
    ("is_insertable_into", String),
    ("commit_action", String),
    ("table_owner", String),
    ("comment", String),
    ("created", DateTime),
    ("created_by", String),
    ("last_altered", DateTime),
    ("last_altered_by", String),
    ("data_source_format", String),
    ("storage_sub_directory", String),
)

# Table object describing ``information_schema.tables``; used to build
# reflection queries with the SQLAlchemy expression language.
tables = Table(
    "tables",
    metadata,
    *(
        Column(column_name, column_type)
        for column_name, column_type in _TABLES_COLUMN_SPECS
    ),
    schema="information_schema",
)

src/databricks/sqlalchemy/base.py

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import re
2-
from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple
1+
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
32

43
import databricks.sqlalchemy._ddl as dialect_ddl_impl
54
import databricks.sqlalchemy._types as dialect_type_impl
65
from databricks import sql
6+
from databricks.sqlalchemy._information_schema import tables
77
from databricks.sqlalchemy._parse import (
88
_describe_table_extended_result_to_dict_list,
99
_match_table_not_found_string,
@@ -15,15 +15,16 @@
1515
)
1616

1717
import sqlalchemy
18-
from sqlalchemy import DDL, event
18+
from sqlalchemy import DDL, event, select, bindparam, exc
1919
from sqlalchemy.engine import Connection, Engine, default, reflection
20-
from sqlalchemy.engine.reflection import ObjectKind
2120
from sqlalchemy.engine.interfaces import (
2221
ReflectedForeignKeyConstraint,
2322
ReflectedPrimaryKeyConstraint,
2423
ReflectedColumn,
24+
ReflectedTableComment,
2525
TableKey,
2626
)
27+
from sqlalchemy.engine.reflection import ReflectionDefaults, ObjectKind, ObjectScope
2728
from sqlalchemy.exc import DatabaseError, SQLAlchemyError
2829

2930
try:
@@ -285,7 +286,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs):
285286
views_result = self.get_view_names(connection=connection, schema=schema)
286287

287288
# In Databricks, SHOW TABLES FROM <schema> returns both tables and views.
288-
# Potential optimisation: rewrite this to instead query informtation_schema
289+
# Potential optimisation: rewrite this to instead query information_schema
289290
tables_minus_views = [
290291
row.tableName for row in tables_result if row.tableName not in views_result
291292
]
@@ -328,7 +329,7 @@ def get_materialized_view_names(
328329
def get_temp_view_names(
329330
self, connection: Connection, schema: Optional[str] = None, **kw: Any
330331
) -> List[str]:
331-
"""A wrapper around get_view_names taht fetches only the names of temporary views"""
332+
"""A wrapper around get_view_names that fetches only the names of temporary views"""
332333
return self.get_view_names(connection, schema, only_temp=True)
333334

334335
def do_rollback(self, dbapi_connection):
@@ -375,6 +376,87 @@ def get_schema_names(self, connection, **kw):
375376
schema_list = [row[0] for row in result]
376377
return schema_list
377378

379+
def get_multi_table_comment(
    self,
    connection,
    schema=None,
    filter_names=None,
    scope=ObjectScope.ANY,
    kind=ObjectKind.ANY,
    **kw,
) -> Iterable[Tuple[TableKey, ReflectedTableComment]]:
    """Return the comments of several tables/views in one pass.

    Args:
        connection: connection used to run the reflection queries.
        schema: schema to inspect; falls back to ``self.schema`` for the
            ``information_schema`` lookup when falsy.
        filter_names: optional iterable of object names to restrict the
            result to. Names are compared in lower case.
        scope: ``ObjectScope`` flags. ``DEFAULT`` objects are read from
            ``information_schema.tables``; ``TEMPORARY`` views are listed
            through ``get_view_names(..., only_temp=True)``.
        kind: ``ObjectKind`` flags selecting tables, views and/or
            materialized views.
        **kw: accepted for interface compatibility; not used here.

    Returns:
        A lazy iterable of ``((schema, table_name), comment_dict)`` pairs.
        ``comment_dict`` is ``{"text": comment}`` when a comment exists,
        otherwise ``ReflectionDefaults.table_comment()``. Temporary views
        are always reported with the default (no comment).
    """
    # Normalise the requested names exactly once so that a one-shot
    # iterable is not exhausted before the temporary-view filter runs.
    wanted_names = (
        [name.lower() for name in filter_names] if filter_names else None
    )

    rows = []
    if ObjectScope.DEFAULT in scope:
        query = (
            select(tables.c.table_name, tables.c.comment)
            .select_from(tables)
            .where(
                tables.c.table_catalog == self.catalog,
                tables.c.table_schema == (schema or self.schema),
            )
        )

        # Only restrict on table_type when a subset of kinds was requested.
        if ObjectKind.ANY not in kind:
            table_types = set()
            if ObjectKind.TABLE in kind:
                table_types.update(
                    ["BASE TABLE", "MANAGED", "EXTERNAL", "STREAMING_TABLE"]
                )
            if ObjectKind.VIEW in kind:
                table_types.add("VIEW")
            if ObjectKind.MATERIALIZED_VIEW in kind:
                table_types.add("MATERIALIZED_VIEW")
            query = query.where(tables.c.table_type.in_(table_types))

        if wanted_names is not None:
            query = query.where(
                tables.c.table_name.in_(bindparam("filter_names"))
            )
            rows = connection.execute(query, {"filter_names": wanted_names})
        else:
            rows = connection.execute(query)

    if ObjectScope.TEMPORARY in scope and ObjectKind.VIEW in kind:
        # Materialise the query result so temporary views can be appended.
        rows = list(rows)
        temp_views = self.get_view_names(connection, schema, only_temp=True)
        if wanted_names is not None:
            temp_views = set(temp_views).intersection(wanted_names)
        rows.extend((view_name, None) for view_name in temp_views)

    # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
    return (
        (
            (schema, table_name),
            {"text": comment}
            if comment is not None
            else ReflectionDefaults.table_comment(),
        )
        for table_name, comment in rows
    )  # type: ignore
440+
def get_table_comment(
    self,
    connection: Connection,
    table_name: str,
    schema: Optional[str] = None,
    **kw: Any,
) -> ReflectedTableComment:
    """Return the reflected comment of a single table.

    Delegates to ``get_multi_table_comment`` filtered on ``table_name`` and
    looks the result up under ``(schema, table_name.lower())``.

    Raises:
        sqlalchemy.exc.NoSuchTableError: if no entry for the table is
            returned. The message is ``schema.table_name`` when a schema
            was given, otherwise just ``table_name``.
    """
    comment_rows = self.get_multi_table_comment(
        connection,
        schema,
        [table_name],
        **kw,
    )
    try:
        comments_by_key = dict(comment_rows)
        return comments_by_key[(schema, table_name.lower())]
    except KeyError:
        qualified_name = f"{schema}.{table_name}" if schema else table_name
        raise exc.NoSuchTableError(qualified_name) from None
459+
378460

379461
@event.listens_for(Engine, "do_connect")
380462
def receive_do_connect(dialect, conn_rec, cargs, cparams):

src/databricks/sqlalchemy/requirements.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ def table_reflection(self):
159159
"""target database has general support for table reflection"""
160160
return sqlalchemy.testing.exclusions.open()
161161

162+
@property
163+
def comment_reflection(self):
164+
"""Indicates if the database supports table comment reflection"""
165+
return sqlalchemy.testing.exclusions.open()
166+
162167
@property
163168
def temp_table_reflection(self):
164169
"""ComponentReflection test is intricate and simply cannot function without this exclusion being defined here.

src/databricks/sqlalchemy/test/_future.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
ComponentReflectionTest,
1414
ComponentReflectionTestExtra,
1515
CTETest,
16-
FutureTableDDLTest,
1716
InsertBehaviorTest,
18-
TableDDLTest,
1917
)
2018
from sqlalchemy.testing.suite import (
2119
ArrayTest,
@@ -251,36 +249,7 @@ class FutureWeCanSetDefaultSchemaWEventsTest(FutureWeCanSetDefaultSchemaWEventsT
251249
pass
252250

253251

254-
class FutureTableDDLTest(FutureTableDDLTest):
255-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
256-
def test_add_table_comment(self):
257-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
258-
pass
259-
260-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
261-
def test_drop_table_comment(self):
262-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
263-
pass
264-
265-
266-
class TableDDLTest(TableDDLTest):
267-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
268-
def test_add_table_comment(self, connection):
269-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
270-
pass
271-
272-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
273-
def test_drop_table_comment(self, connection):
274-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
275-
pass
276-
277-
278252
class ComponentReflectionTest(ComponentReflectionTest):
279-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
280-
def test_get_multi_table_comment(self):
281-
"""There are 84 permutations of this test that are skipped."""
282-
pass
283-
284253
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_OPTS, True))
285254
def test_multi_get_table_options_tables(self):
286255
"""It's not clear what the expected output from this method would even _be_. Requires research."""
@@ -302,14 +271,6 @@ def test_get_multi_pk_constraint(self):
302271
def test_get_multi_check_constraints(self):
303272
pass
304273

305-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
306-
def test_get_comments(self):
307-
pass
308-
309-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
310-
def test_get_comments_with_schema(self):
311-
pass
312-
313274
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
314275
def test_comments_unicode(self):
315276
pass

0 commit comments

Comments
 (0)