Skip to content

Commit cfdf3f3

Browse files
committed
Add support for table comments
Signed-off-by: Christophe Bornet <[email protected]>
1 parent 62eb1d4 commit cfdf3f3

File tree

9 files changed

+182
-51
lines changed

9 files changed

+182
-51
lines changed

src/databricks/sqlalchemy/_ddl.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from sqlalchemy.sql import compiler
2+
from sqlalchemy.sql import compiler, sqltypes
33
import logging
44

55
logger = logging.getLogger(__name__)
@@ -16,7 +16,13 @@ def __init__(self, dialect):
1616

1717
class DatabricksDDLCompiler(compiler.DDLCompiler):
1818
def post_create_table(self, table):
19-
return " USING DELTA"
19+
post = " USING DELTA"
20+
if table.comment:
21+
comment = self.sql_compiler.render_literal_value(
22+
table.comment, sqltypes.String()
23+
)
24+
post += " COMMENT " + comment
25+
return post
2026

2127
def visit_unique_constraint(self, constraint, **kw):
2228
logger.warning("Databricks does not support unique constraints")
@@ -45,7 +51,7 @@ def get_column_specification(self, column, **kwargs):
4551
IDENTITY using this feature in the future, similar to the Microsoft SQL Server dialect.
4652
"""
4753
if column is column.table._autoincrement_column or column.autoincrement is True:
48-
logger.warn(
54+
logger.warning(
4955
"Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead."
5056
)
5157

src/databricks/sqlalchemy/_information_schema.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from sqlalchemy import MetaData, Table, Column, String, DateTime
2+
3+
metadata = MetaData()
4+
5+
tables = Table(
6+
"tables",
7+
metadata,
8+
Column("table_catalog", String),
9+
Column("table_schema", String),
10+
Column("table_name", String),
11+
Column("table_type", String),
12+
Column("is_insertable_into", String),
13+
Column("commit_action", String),
14+
Column("table_owner", String),
15+
Column("comment", String),
16+
Column("created", DateTime),
17+
Column("created_by", String),
18+
Column("last_altered", DateTime),
19+
Column("last_altered_by", String),
20+
Column("data_source_format", String),
21+
Column("storage_sub_directory", String),
22+
schema="information_schema",
23+
)

src/databricks/sqlalchemy/base.py

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import re
2-
from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple
1+
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
32

43
import databricks.sqlalchemy._ddl as dialect_ddl_impl
54
import databricks.sqlalchemy._types as dialect_type_impl
65
from databricks import sql
6+
from databricks.sqlalchemy._information_schema import tables
77
from databricks.sqlalchemy._parse import (
88
_describe_table_extended_result_to_dict_list,
99
_match_table_not_found_string,
@@ -15,15 +15,16 @@
1515
)
1616

1717
import sqlalchemy
18-
from sqlalchemy import DDL, event
18+
from sqlalchemy import DDL, event, select, bindparam, exc
1919
from sqlalchemy.engine import Connection, Engine, default, reflection
20-
from sqlalchemy.engine.reflection import ObjectKind
2120
from sqlalchemy.engine.interfaces import (
2221
ReflectedForeignKeyConstraint,
2322
ReflectedPrimaryKeyConstraint,
2423
ReflectedColumn,
24+
ReflectedTableComment,
2525
TableKey,
2626
)
27+
from sqlalchemy.engine.reflection import ReflectionDefaults, ObjectKind, ObjectScope
2728
from sqlalchemy.exc import DatabaseError, SQLAlchemyError
2829

2930
try:
@@ -285,7 +286,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs):
285286
views_result = self.get_view_names(connection=connection, schema=schema)
286287

287288
# In Databricks, SHOW TABLES FROM <schema> returns both tables and views.
288-
# Potential optimisation: rewrite this to instead query informtation_schema
289+
# Potential optimisation: rewrite this to instead query information_schema
289290
tables_minus_views = [
290291
row.tableName for row in tables_result if row.tableName not in views_result
291292
]
@@ -328,7 +329,7 @@ def get_materialized_view_names(
328329
def get_temp_view_names(
329330
self, connection: Connection, schema: Optional[str] = None, **kw: Any
330331
) -> List[str]:
331-
"""A wrapper around get_view_names taht fetches only the names of temporary views"""
332+
"""A wrapper around get_view_names that fetches only the names of temporary views"""
332333
return self.get_view_names(connection, schema, only_temp=True)
333334

334335
def do_rollback(self, dbapi_connection):
@@ -375,6 +376,87 @@ def get_schema_names(self, connection, **kw):
375376
schema_list = [row[0] for row in result]
376377
return schema_list
377378

379+
def get_multi_table_comment(
380+
self,
381+
connection,
382+
schema=None,
383+
filter_names=None,
384+
scope=ObjectScope.ANY,
385+
kind=ObjectKind.ANY,
386+
**kw,
387+
) -> Iterable[Tuple[TableKey, ReflectedTableComment]]:
388+
result = []
389+
_schema = schema or self.schema
390+
if ObjectScope.DEFAULT in scope:
391+
query = (
392+
select(tables.c.table_name, tables.c.comment)
393+
.select_from(tables)
394+
.where(
395+
tables.c.table_catalog == self.catalog,
396+
tables.c.table_schema == _schema,
397+
)
398+
)
399+
400+
if ObjectKind.ANY not in kind:
401+
where_in = set()
402+
if ObjectKind.TABLE in kind:
403+
where_in.update(
404+
["BASE TABLE", "MANAGED", "EXTERNAL", "STREAMING_TABLE"]
405+
)
406+
if ObjectKind.VIEW in kind:
407+
where_in.update(["VIEW"])
408+
if ObjectKind.MATERIALIZED_VIEW in kind:
409+
where_in.update(["MATERIALIZED_VIEW"])
410+
query = query.where(tables.c.table_type.in_(where_in))
411+
412+
if filter_names:
413+
query = query.where(tables.c.table_name.in_(bindparam("filter_names")))
414+
result = connection.execute(
415+
query, {"filter_names": [f.lower() for f in filter_names]}
416+
)
417+
else:
418+
result = connection.execute(query)
419+
420+
if ObjectScope.TEMPORARY in scope and ObjectKind.VIEW in kind:
421+
result = list(result)
422+
temp_views = self.get_view_names(connection, schema, only_temp=True)
423+
if filter_names:
424+
temp_views = set(temp_views).intersection(
425+
[f.lower() for f in filter_names]
426+
)
427+
result.extend(zip(temp_views, [None] * len(temp_views)))
428+
429+
# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
430+
return (
431+
(
432+
(schema, table),
433+
{"text": comment}
434+
if comment is not None
435+
else ReflectionDefaults.table_comment(),
436+
)
437+
for table, comment in result
438+
) # type: ignore
439+
440+
def get_table_comment(
441+
self,
442+
connection: Connection,
443+
table_name: str,
444+
schema: Optional[str] = None,
445+
**kw: Any,
446+
) -> ReflectedTableComment:
447+
data = self.get_multi_table_comment(
448+
connection,
449+
schema,
450+
[table_name],
451+
**kw,
452+
)
453+
try:
454+
return dict(data)[(schema, table_name.lower())]
455+
except KeyError:
456+
raise exc.NoSuchTableError(
457+
f"{schema}.{table_name}" if schema else table_name
458+
) from None
459+
378460

379461
@event.listens_for(Engine, "do_connect")
380462
def receive_do_connect(dialect, conn_rec, cargs, cparams):

src/databricks/sqlalchemy/pytest.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[pytest]
22
markers =
3-
reviewed: Test case has been reviewed by databricks
3+
reviewed: Test case has been reviewed by databricks

src/databricks/sqlalchemy/requirements.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ def table_reflection(self):
159159
"""target database has general support for table reflection"""
160160
return sqlalchemy.testing.exclusions.open()
161161

162+
@property
163+
def comment_reflection(self):
164+
"""Indicates if the database supports table comment reflection"""
165+
return sqlalchemy.testing.exclusions.open()
166+
162167
@property
163168
def temp_table_reflection(self):
164169
"""ComponentReflection test is intricate and simply cannot function without this exclusion being defined here.

src/databricks/sqlalchemy/test/_extra.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import datetime
55

6+
from sqlalchemy import Integer, String, schema, inspect
7+
from sqlalchemy.testing import util
8+
from sqlalchemy.testing.config import requirements
69
from sqlalchemy.testing.suite.test_types import (
710
_LiteralRoundTripFixture,
811
fixtures,
@@ -63,8 +66,39 @@ class DateTimeTZTestCustom(_DateFixture, fixtures.TablesTest):
6366

6467
@testing.requires.datetime_implicit_bound
6568
def test_select_direct(self, connection):
66-
6769
# We need to pass the TIMESTAMP type to the literal function
6870
# so that the value is processed correctly.
6971
result = connection.scalar(select(literal(self.data, TIMESTAMP)))
7072
eq_(result, self.data)
73+
74+
75+
class TableDDLTestCustom(fixtures.TestBase):
76+
"""This test confirms that a table comment can be dropped.
77+
The difference with TableDDLTest is that the comment value is '' and not None after
78+
being dropped.
79+
"""
80+
81+
__backend__ = True
82+
83+
def _simple_fixture(self, schema=None):
84+
return Table(
85+
"test_table",
86+
self.metadata,
87+
Column("id", Integer, primary_key=True, autoincrement=False),
88+
Column("data", String(50)),
89+
schema=schema,
90+
)
91+
92+
@requirements.comment_reflection
93+
@util.provide_metadata
94+
def test_drop_table_comment(self, connection):
95+
table = self._simple_fixture()
96+
table.create(connection, checkfirst=False)
97+
table.comment = "a comment"
98+
connection.execute(schema.SetTableComment(table))
99+
connection.execute(schema.DropTableComment(table))
100+
eq_(inspect(connection).get_table_comment("test_table"), {"text": ""})
101+
102+
103+
class FutureTableDDLTestCustom(fixtures.FutureEngineMixin, TableDDLTestCustom):
104+
pass

src/databricks/sqlalchemy/test/_future.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
ComponentReflectionTest,
1414
ComponentReflectionTestExtra,
1515
CTETest,
16-
FutureTableDDLTest,
1716
InsertBehaviorTest,
18-
TableDDLTest,
1917
)
2018
from sqlalchemy.testing.suite import (
2119
ArrayTest,
@@ -251,36 +249,7 @@ class FutureWeCanSetDefaultSchemaWEventsTest(FutureWeCanSetDefaultSchemaWEventsT
251249
pass
252250

253251

254-
class FutureTableDDLTest(FutureTableDDLTest):
255-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
256-
def test_add_table_comment(self):
257-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
258-
pass
259-
260-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
261-
def test_drop_table_comment(self):
262-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
263-
pass
264-
265-
266-
class TableDDLTest(TableDDLTest):
267-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
268-
def test_add_table_comment(self, connection):
269-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
270-
pass
271-
272-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
273-
def test_drop_table_comment(self, connection):
274-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
275-
pass
276-
277-
278252
class ComponentReflectionTest(ComponentReflectionTest):
279-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
280-
def test_get_multi_table_comment(self):
281-
"""There are 84 permutations of this test that are skipped."""
282-
pass
283-
284253
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_OPTS, True))
285254
def test_multi_get_table_options_tables(self):
286255
"""It's not clear what the expected output from this method would even _be_. Requires research."""
@@ -302,14 +271,6 @@ def test_get_multi_pk_constraint(self):
302271
def test_get_multi_check_constraints(self):
303272
pass
304273

305-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
306-
def test_get_comments(self):
307-
pass
308-
309-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
310-
def test_get_comments_with_schema(self):
311-
pass
312-
313274
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
314275
def test_comments_unicode(self):
315276
pass

src/databricks/sqlalchemy/test/_unsupported.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class SkipReason(Enum):
5555
TIMEZONE_OPT = "timezone-optional TIMESTAMP fields"
5656
TRANSACTIONS = "transactions"
5757
UNIQUE = "UNIQUE constraints"
58+
DROP_TBL = "drop table comment"
5859

5960

6061
def render_skip_reason(rsn: SkipReason, setup_error=False, extra=False) -> str:
@@ -222,6 +223,13 @@ def test_uuid_returning(self):
222223

223224

224225
class FutureTableDDLTest(FutureTableDDLTest):
226+
@pytest.mark.skip(reason=render_skip_reason(SkipReason.DROP_TBL))
227+
def test_drop_table_comment(self, connection):
228+
"""The DropTableComment statement is supported but it sets the comment to ''
229+
instead of None so this test can't pass.
230+
"""
231+
pass
232+
225233
@pytest.mark.skip(render_skip_reason(SkipReason.INDEXES))
226234
def test_create_index_if_not_exists(self):
227235
"""We could use requirements.index_reflection and requirements.index_ddl_if_exists
@@ -238,6 +246,13 @@ def test_drop_index_if_exists(self):
238246

239247

240248
class TableDDLTest(TableDDLTest):
249+
@pytest.mark.skip(reason=render_skip_reason(SkipReason.DROP_TBL))
250+
def test_drop_table_comment(self, connection):
251+
"""The DropTableComment statement is supported but it sets the comment to ''
252+
instead of None so this test can't pass.
253+
"""
254+
pass
255+
241256
@pytest.mark.skip(reason=render_skip_reason(SkipReason.INDEXES))
242257
def test_create_index_if_not_exists(self, connection):
243258
"""We could use requirements.index_reflection and requirements.index_ddl_if_exists

src/databricks/sqlalchemy/test/test_suite.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,9 @@
1010
from databricks.sqlalchemy.test._regression import *
1111
from databricks.sqlalchemy.test._unsupported import *
1212
from databricks.sqlalchemy.test._future import *
13-
from databricks.sqlalchemy.test._extra import TinyIntegerTest, DateTimeTZTestCustom
13+
from databricks.sqlalchemy.test._extra import (
14+
TinyIntegerTest,
15+
DateTimeTZTestCustom,
16+
TableDDLTestCustom,
17+
FutureTableDDLTestCustom
18+
)

0 commit comments

Comments
 (0)