Skip to content

[sqlalchemy] Add table and column comment support #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# 3.1.0 (TBD)

- SQLAlchemy: Added support for table and column comments (thanks @cbornet!)
- Fix: `server_hostname` URIs that included `https://` would raise an exception

## 3.0.1 (2023-12-01)
Expand Down
43 changes: 36 additions & 7 deletions src/databricks/sqlalchemy/_ddl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, sqltypes
import logging

logger = logging.getLogger(__name__)
Expand All @@ -16,7 +16,13 @@ def __init__(self, dialect):

class DatabricksDDLCompiler(compiler.DDLCompiler):
def post_create_table(self, table):
return " USING DELTA"
post = " USING DELTA"
if table.comment:
comment = self.sql_compiler.render_literal_value(
table.comment, sqltypes.String()
)
post += " COMMENT " + comment
return post

def visit_unique_constraint(self, constraint, **kw):
logger.warning("Databricks does not support unique constraints")
Expand All @@ -39,17 +45,40 @@ def visit_identity_column(self, identity, **kw):
)
return text

def visit_set_column_comment(self, create, **kw):
    """Render the DDL statement that sets a comment on an existing column.

    `create.element` is the column being commented. Its table and name are
    formatted by the identifier preparer, and its comment text is rendered as
    a string literal by the statement compiler.
    """
    column = create.element
    table_name = self.preparer.format_table(column.table)
    column_name = self.preparer.format_column(column)
    comment_literal = self.sql_compiler.render_literal_value(
        column.comment, sqltypes.String()
    )
    return "ALTER TABLE %s ALTER COLUMN %s COMMENT %s" % (
        table_name,
        column_name,
        comment_literal,
    )

def visit_drop_column_comment(self, create, **kw):
    """Render the DDL statement that clears the comment of an existing column.

    The comment is not removed with a dedicated clause; the statement sets it
    to an empty string literal on the column held in `create.element`.
    """
    column = create.element
    table_name = self.preparer.format_table(column.table)
    column_name = self.preparer.format_column(column)
    return "ALTER TABLE %s ALTER COLUMN %s COMMENT ''" % (table_name, column_name)

def get_column_specification(self, column, **kwargs):
"""Currently we override this method only to emit a log message if a user attempts to set
autoincrement=True on a column. See comments in test_suite.py. We may implement implicit
IDENTITY using this feature in the future, similar to the Microsoft SQL Server dialect.
"""
Emit a log message if a user attempts to set autoincrement=True on a column.
See comments in test_suite.py. We may implement implicit IDENTITY using this
feature in the future, similar to the Microsoft SQL Server dialect.
"""
if column is column.table._autoincrement_column or column.autoincrement is True:
logger.warn(
logger.warning(
"Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead."
)

return super().get_column_specification(column, **kwargs)
colspec = super().get_column_specification(column, **kwargs)
if column.comment is not None:
literal = self.sql_compiler.render_literal_value(
column.comment, sqltypes.STRINGTYPE
)
colspec += " COMMENT " + literal

return colspec


class DatabricksStatementCompiler(compiler.SQLCompiler):
Expand Down
24 changes: 24 additions & 0 deletions src/databricks/sqlalchemy/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,20 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis
return output_rows


def match_dte_rows_by_key(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
    """Filter DESCRIBE TABLE EXTENDED rows by their `col_name` value.

    Returns, in their original order, the row dictionaries of `dte_output`
    whose `col_name` entry contains `match` as a substring. The rows are
    returned whole and are the same objects that were passed in.
    """
    return [row for row in dte_output if match in row["col_name"]]


def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]:
"""If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
one dictionary per defined constraint
Expand All @@ -275,6 +289,15 @@ def get_pk_strings_from_dte_output(
return output


def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]:
    """Return the table comment found in DESCRIBE TABLE EXTENDED output, or None.

    The `data_type` value of the first row whose `col_name` contains "Comment"
    is returned. When the output contains a "# Detailed Table Information"
    row, only the rows after it are searched. The column listing comes before
    that row, so a column whose name merely contains "Comment" (for example
    "UserComment") would otherwise be picked up and its data type returned as
    the table comment. When that row is absent the whole output is searched.
    """
    # Find where the table-level metadata starts; default to the beginning so
    # output without the section header is still searched in full.
    start = 0
    for index, row in enumerate(dte_output):
        if row["col_name"] == "# Detailed Table Information":
            start = index + 1
            break

    for row in dte_output[start:]:
        if "Comment" in row["col_name"]:
            return row["data_type"]
    return None


# The keys of this dictionary are the values we expect to see in a
# TGetColumnsRequest's .TYPE_NAME attribute.
# These are enumerated in ttypes.py as class TTypeId.
Expand Down Expand Up @@ -354,6 +377,7 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu
"type": final_col_type,
"nullable": bool(thrift_resp_row.NULLABLE),
"default": thrift_resp_row.COLUMN_DEF,
"comment": thrift_resp_row.REMARKS or None,
}

# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
Expand Down
36 changes: 30 additions & 6 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple
from typing import Any, List, Optional, Dict, Union

import databricks.sqlalchemy._ddl as dialect_ddl_impl
import databricks.sqlalchemy._types as dialect_type_impl
Expand All @@ -11,19 +10,20 @@
build_pk_dict,
get_fk_strings_from_dte_output,
get_pk_strings_from_dte_output,
get_comment_from_dte_output,
parse_column_info_from_tgetcolumnsresponse,
)

import sqlalchemy
from sqlalchemy import DDL, event
from sqlalchemy.engine import Connection, Engine, default, reflection
from sqlalchemy.engine.reflection import ObjectKind
from sqlalchemy.engine.interfaces import (
ReflectedForeignKeyConstraint,
ReflectedPrimaryKeyConstraint,
ReflectedColumn,
TableKey,
ReflectedTableComment,
)
from sqlalchemy.engine.reflection import ReflectionDefaults
from sqlalchemy.exc import DatabaseError, SQLAlchemyError

try:
Expand Down Expand Up @@ -285,7 +285,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs):
views_result = self.get_view_names(connection=connection, schema=schema)

# In Databricks, SHOW TABLES FROM <schema> returns both tables and views.
# Potential optimisation: rewrite this to instead query informtation_schema
# Potential optimisation: rewrite this to instead query information_schema
tables_minus_views = [
row.tableName for row in tables_result if row.tableName not in views_result
]
Expand Down Expand Up @@ -328,7 +328,7 @@ def get_materialized_view_names(
def get_temp_view_names(
self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""A wrapper around get_view_names taht fetches only the names of temporary views"""
"""A wrapper around get_view_names that fetches only the names of temporary views"""
return self.get_view_names(connection, schema, only_temp=True)

def do_rollback(self, dbapi_connection):
Expand Down Expand Up @@ -375,6 +375,30 @@ def get_schema_names(self, connection, **kw):
schema_list = [row[0] for row in result]
return schema_list

@reflection.cache
def get_table_comment(
    self,
    connection: Connection,
    table_name: str,
    schema: Optional[str] = None,
    **kw: Any,
) -> ReflectedTableComment:
    """Reflect the comment of `table_name`.

    Runs the dialect's DESCRIBE TABLE EXTENDED helper for the table and pulls
    the comment out of its output. When the helper returns None, or no
    non-empty comment is found, the SQLAlchemy default for a missing table
    comment is returned instead.
    """
    dte_rows = self._describe_table_extended(
        connection=connection,
        table_name=table_name,
        schema_name=schema,
    )

    comment = None if dte_rows is None else get_comment_from_dte_output(dte_rows)
    if not comment:
        return ReflectionDefaults.table_comment()
    return {"text": comment}


@event.listens_for(Engine, "do_connect")
def receive_do_connect(dialect, conn_rec, cargs, cparams):
Expand Down
12 changes: 12 additions & 0 deletions src/databricks/sqlalchemy/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ def table_reflection(self):
"""target database has general support for table reflection"""
return sqlalchemy.testing.exclusions.open()

@property
def comment_reflection(self):
"""Indicates if the database supports table comment reflection"""
return sqlalchemy.testing.exclusions.open()

@property
def comment_reflection_full_unicode(self):
"""Indicates if the database supports table comment reflection in the
full unicode range, including emoji etc.
"""
return sqlalchemy.testing.exclusions.open()

@property
def temp_table_reflection(self):
"""ComponentReflection test is intricate and simply cannot function without this exclusion being defined here.
Expand Down
48 changes: 0 additions & 48 deletions src/databricks/sqlalchemy/test/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
ComponentReflectionTest,
ComponentReflectionTestExtra,
CTETest,
FutureTableDDLTest,
InsertBehaviorTest,
TableDDLTest,
)
from sqlalchemy.testing.suite import (
ArrayTest,
Expand Down Expand Up @@ -53,7 +51,6 @@ class FutureFeature(Enum):
PROVISION = "event-driven engine configuration"
REGEXP = "_visit_regexp"
SANE_ROWCOUNT = "sane_rowcount support"
TBL_COMMENTS = "table comment reflection"
TBL_OPTS = "get_table_options method"
TEST_DESIGN = "required test-fixture overrides"
TUPLE_LITERAL = "tuple-like IN markers completely"
Expand Down Expand Up @@ -251,36 +248,7 @@ class FutureWeCanSetDefaultSchemaWEventsTest(FutureWeCanSetDefaultSchemaWEventsT
pass


class FutureTableDDLTest(FutureTableDDLTest):
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_add_table_comment(self):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_drop_table_comment(self):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass


class TableDDLTest(TableDDLTest):
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_add_table_comment(self, connection):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_drop_table_comment(self, connection):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass


class ComponentReflectionTest(ComponentReflectionTest):
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_multi_table_comment(self):
"""There are 84 permutations of this test that are skipped."""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_OPTS, True))
def test_multi_get_table_options_tables(self):
"""It's not clear what the expected output from this method would even _be_. Requires research."""
Expand All @@ -302,22 +270,6 @@ def test_get_multi_pk_constraint(self):
def test_get_multi_check_constraints(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_comments(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_comments_with_schema(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_comments_unicode(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_comments_unicode_full(self):
pass


class ComponentReflectionTestExtra(ComponentReflectionTestExtra):
@pytest.mark.skip(render_future_feature(FutureFeature.CHECK))
Expand Down
Loading