Skip to content

Add support for table comments #308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 23, 2024
10 changes: 8 additions & 2 deletions src/databricks/sqlalchemy/_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ def __init__(self, dialect):

class DatabricksDDLCompiler(compiler.DDLCompiler):
def post_create_table(self, table):
return " USING DELTA"
post = " USING DELTA"
if table.comment:
comment = self.sql_compiler.render_literal_value(
table.comment, sqltypes.String()
)
post += " COMMENT " + comment
return post

def visit_unique_constraint(self, constraint, **kw):
logger.warning("Databricks does not support unique constraints")
Expand Down Expand Up @@ -61,7 +67,7 @@ def get_column_specification(self, column, **kwargs):
feature in the future, similar to the Microsoft SQL Server dialect.
"""
if column is column.table._autoincrement_column or column.autoincrement is True:
logger.warn(
logger.warning(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch 🙏

"Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead."
)

Expand Down
23 changes: 23 additions & 0 deletions src/databricks/sqlalchemy/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,20 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis
return output_rows


def match_dte_rows_by_key(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
"""Return a list of dictionaries containing only the col_name:data_type pairs where the `col_name`
value contains the match argument.
"""

output_rows = []

for row_dict in dte_output:
if match in row_dict["col_name"]:
output_rows.append(row_dict)

return output_rows


def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]:
"""If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
one dictionary per defined constraint
Expand All @@ -275,6 +289,15 @@ def get_pk_strings_from_dte_output(
return output


def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]:
"""Returns the value of the first "Comment" col_name data in dte_output"""
output = match_dte_rows_by_key(dte_output, "Comment")
if not output:
return None
else:
return output[0]["data_type"]


# The keys of this dictionary are the values we expect to see in a
# TGetColumnsRequest's .TYPE_NAME attribute.
# These are enumerated in ttypes.py as class TTypeId.
Expand Down
36 changes: 30 additions & 6 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple
from typing import Any, List, Optional, Dict, Union

import databricks.sqlalchemy._ddl as dialect_ddl_impl
import databricks.sqlalchemy._types as dialect_type_impl
Expand All @@ -11,19 +10,20 @@
build_pk_dict,
get_fk_strings_from_dte_output,
get_pk_strings_from_dte_output,
get_comment_from_dte_output,
parse_column_info_from_tgetcolumnsresponse,
)

import sqlalchemy
from sqlalchemy import DDL, event
from sqlalchemy.engine import Connection, Engine, default, reflection
from sqlalchemy.engine.reflection import ObjectKind
from sqlalchemy.engine.interfaces import (
ReflectedForeignKeyConstraint,
ReflectedPrimaryKeyConstraint,
ReflectedColumn,
TableKey,
ReflectedTableComment,
)
from sqlalchemy.engine.reflection import ReflectionDefaults
from sqlalchemy.exc import DatabaseError, SQLAlchemyError

try:
Expand Down Expand Up @@ -285,7 +285,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs):
views_result = self.get_view_names(connection=connection, schema=schema)

# In Databricks, SHOW TABLES FROM <schema> returns both tables and views.
# Potential optimisation: rewrite this to instead query informtation_schema
# Potential optimisation: rewrite this to instead query information_schema
tables_minus_views = [
row.tableName for row in tables_result if row.tableName not in views_result
]
Expand Down Expand Up @@ -328,7 +328,7 @@ def get_materialized_view_names(
def get_temp_view_names(
self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""A wrapper around get_view_names taht fetches only the names of temporary views"""
"""A wrapper around get_view_names that fetches only the names of temporary views"""
return self.get_view_names(connection, schema, only_temp=True)

def do_rollback(self, dbapi_connection):
Expand Down Expand Up @@ -375,6 +375,30 @@ def get_schema_names(self, connection, **kw):
schema_list = [row[0] for row in result]
return schema_list

@reflection.cache
def get_table_comment(
self,
connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
) -> ReflectedTableComment:
result = self._describe_table_extended(
connection=connection,
table_name=table_name,
schema_name=schema,
)

if result is None:
return ReflectionDefaults.table_comment()

comment = get_comment_from_dte_output(result)

if comment:
return dict(text=comment)
else:
return ReflectionDefaults.table_comment()


@event.listens_for(Engine, "do_connect")
def receive_do_connect(dialect, conn_rec, cargs, cparams):
Expand Down
12 changes: 12 additions & 0 deletions src/databricks/sqlalchemy/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ def table_reflection(self):
"""target database has general support for table reflection"""
return sqlalchemy.testing.exclusions.open()

@property
def comment_reflection(self):
"""Indicates if the database support table comment reflection"""
return sqlalchemy.testing.exclusions.open()

@property
def comment_reflection_full_unicode(self):
"""Indicates if the database support table comment reflection in the
full unicode range, including emoji etc.
"""
return sqlalchemy.testing.exclusions.open()

@property
def temp_table_reflection(self):
"""ComponentReflection test is intricate and simply cannot function without this exclusion being defined here.
Expand Down
48 changes: 0 additions & 48 deletions src/databricks/sqlalchemy/test/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
ComponentReflectionTest,
ComponentReflectionTestExtra,
CTETest,
FutureTableDDLTest,
InsertBehaviorTest,
TableDDLTest,
)
from sqlalchemy.testing.suite import (
ArrayTest,
Expand Down Expand Up @@ -53,7 +51,6 @@ class FutureFeature(Enum):
PROVISION = "event-driven engine configuration"
REGEXP = "_visit_regexp"
SANE_ROWCOUNT = "sane_rowcount support"
TBL_COMMENTS = "table comment reflection"
TBL_OPTS = "get_table_options method"
TEST_DESIGN = "required test-fixture overrides"
TUPLE_LITERAL = "tuple-like IN markers completely"
Expand Down Expand Up @@ -251,36 +248,7 @@ class FutureWeCanSetDefaultSchemaWEventsTest(FutureWeCanSetDefaultSchemaWEventsT
pass


class FutureTableDDLTest(FutureTableDDLTest):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect!

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_add_table_comment(self):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_drop_table_comment(self):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass


class TableDDLTest(TableDDLTest):
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_add_table_comment(self, connection):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_drop_table_comment(self, connection):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass


class ComponentReflectionTest(ComponentReflectionTest):
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_multi_table_comment(self):
"""There are 84 permutations of this test that are skipped."""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_OPTS, True))
def test_multi_get_table_options_tables(self):
"""It's not clear what the expected ouput from this method would even _be_. Requires research."""
Expand All @@ -302,22 +270,6 @@ def test_get_multi_pk_constraint(self):
def test_get_multi_check_constraints(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_comments(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_comments_with_schema(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_comments_unicode(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_comments_unicode_full(self):
pass


class ComponentReflectionTestExtra(ComponentReflectionTestExtra):
@pytest.mark.skip(render_future_feature(FutureFeature.CHECK))
Expand Down
48 changes: 45 additions & 3 deletions src/databricks/sqlalchemy/test_local/e2e/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import decimal
import os
from typing import Tuple, Union
from typing import Tuple, Union, List
from unittest import skipIf

import pytest
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_column_comment(db_engine, metadata_obj: MetaData):
connection=connection, table_name=table_name
)

assert columns[0].get("comment") == ""
assert columns[0].get("comment") == None

metadata_obj.drop_all(db_engine)

Expand Down Expand Up @@ -477,7 +477,7 @@ def sample_table(metadata_obj: MetaData, db_engine: Engine):

table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s"))

args = [
args: List[Column] = [
Column(colname, coltype) for colname, coltype in GET_COLUMNS_TYPE_MAP.items()
]

Expand All @@ -499,3 +499,45 @@ def test_get_columns(db_engine, sample_table: str):
columns = inspector.get_columns(sample_table)

assert True


class TestCommentReflection:
@pytest.fixture(scope="class")
def engine(self):
HOST = os.environ.get("host")
HTTP_PATH = os.environ.get("http_path")
ACCESS_TOKEN = os.environ.get("access_token")
CATALOG = os.environ.get("catalog")
SCHEMA = os.environ.get("schema")

connection_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}"
connect_args = {"_user_agent_entry": USER_AGENT_TOKEN}

engine = create_engine(connection_string, connect_args=connect_args)
return engine

@pytest.fixture
def inspector(self, engine: Engine) -> Inspector:
return Inspector.from_engine(engine)

@pytest.fixture
def table(self, engine):
md = MetaData()
tbl = Table(
"foo",
md,
Column("bar", String, comment="column comment"),
comment="table comment",
)
md.create_all(bind=engine)

yield tbl

md.drop_all(bind=engine)

def test_table_comment_reflection(self, inspector: Inspector, table: Table):
tbl_name = table.name

comment = inspector.get_table_comment(tbl_name)

assert comment == {"text": "table comment"}
52 changes: 50 additions & 2 deletions src/databricks/sqlalchemy/test_local/test_ddl.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import pytest
from sqlalchemy import Column, MetaData, String, Table, create_engine
from sqlalchemy.schema import CreateTable, DropColumnComment, SetColumnComment
from sqlalchemy.schema import (
CreateTable,
DropColumnComment,
DropTableComment,
SetColumnComment,
SetTableComment,
)


class TestTableCommentDDL:
class DDLTestBase:
engine = create_engine(
"databricks://token:****@****?http_path=****&catalog=****&schema=****"
)

def compile(self, stmt):
return str(stmt.compile(bind=self.engine))


class TestColumnCommentDDL(DDLTestBase):
@pytest.fixture
def metadata(self) -> MetaData:
"""Assemble a metadata object with one table containing one column."""
Expand Down Expand Up @@ -45,3 +53,43 @@ def test_alter_table_drop_column_comment(self, column):
stmt = DropColumnComment(column)
output = self.compile(stmt)
assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT ''"


class TestTableCommentDDL(DDLTestBase):
@pytest.fixture
def metadata(self) -> MetaData:
"""Assemble a metadata object with one table containing one column."""
metadata = MetaData()

col1 = Column("foo", String)
col2 = Column("foo", String)
tbl_w_comment = Table("martin", metadata, col1, comment="foobar")
tbl_wo_comment = Table("prs", metadata, col2)

return metadata

@pytest.fixture
def table_with_comment(self, metadata) -> Table:
return metadata.tables.get("martin")

@pytest.fixture
def table_without_comment(self, metadata) -> Table:
return metadata.tables.get("prs")

def test_create_table_with_comment(self, table_with_comment):
stmt = CreateTable(table_with_comment)
output = self.compile(stmt)
assert "USING DELTA COMMENT 'foobar'" in output

def test_alter_table_add_comment(self, table_without_comment: Table):
table_without_comment.comment = "wireless mechanical keyboard"
stmt = SetTableComment(table_without_comment)
output = self.compile(stmt)

assert output == "COMMENT ON TABLE prs IS 'wireless mechanical keyboard'"

def test_alter_table_drop_comment(self, table_with_comment):
"""The syntax for COMMENT ON is here: https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-comment.html"""
stmt = DropTableComment(table_with_comment)
output = self.compile(stmt)
assert output == "COMMENT ON TABLE martin IS NULL"
Loading