diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index 7aefb034..e9fd9f2b 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -1,5 +1,5 @@ import re -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, sqltypes import logging logger = logging.getLogger(__name__) @@ -39,17 +39,40 @@ def visit_identity_column(self, identity, **kw): ) return text + def visit_set_column_comment(self, create, **kw): + return "ALTER TABLE %s ALTER COLUMN %s COMMENT %s" % ( + self.preparer.format_table(create.element.table), + self.preparer.format_column(create.element), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_column_comment(self, create, **kw): + return "ALTER TABLE %s ALTER COLUMN %s COMMENT ''" % ( + self.preparer.format_table(create.element.table), + self.preparer.format_column(create.element), + ) + def get_column_specification(self, column, **kwargs): - """Currently we override this method only to emit a log message if a user attempts to set - autoincrement=True on a column. See comments in test_suite.py. We may implement implicit - IDENTITY using this feature in the future, similar to the Microsoft SQL Server dialect. + """ + Emit a log message if a user attempts to set autoincrement=True on a column. + See comments in test_suite.py. We may implement implicit IDENTITY using this + feature in the future, similar to the Microsoft SQL Server dialect. """ if column is column.table._autoincrement_column or column.autoincrement is True: logger.warn( "Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead." ) - return super().get_column_specification(column, **kwargs) + colspec = super().get_column_specification(column, **kwargs) + if column.comment is not None: + literal = self.sql_compiler.render_literal_value( + column.comment, sqltypes.STRINGTYPE + ) + colspec += " COMMENT " + literal + + return colspec class DatabricksStatementCompiler(compiler.SQLCompiler): diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py index 42c4774d..1e11cd70 100644 --- a/src/databricks/sqlalchemy/_parse.py +++ b/src/databricks/sqlalchemy/_parse.py @@ -354,6 +354,7 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu "type": final_col_type, "nullable": bool(thrift_resp_row.NULLABLE), "default": thrift_resp_row.COLUMN_DEF, + "comment": thrift_resp_row.REMARKS, } # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py index 3696356c..0c47f3e7 100644 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py @@ -18,6 +18,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column +from sqlalchemy.schema import DropColumnComment, SetColumnComment from sqlalchemy.types import BOOLEAN, DECIMAL, Date, DateTime, Integer, String try: @@ -188,6 +189,41 @@ def test_create_table_not_null(db_engine, metadata_obj: MetaData): metadata_obj.drop_all(db_engine) +def test_column_comment(db_engine, metadata_obj: MetaData): + table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")) + + column = Column("name", String(255), comment="some comment") + SampleTable = Table(table_name, metadata_obj, column) + + metadata_obj.create_all(db_engine) + connection = db_engine.connect() + + columns = db_engine.dialect.get_columns( + connection=connection, table_name=table_name + ) + + assert columns[0].get("comment") == "some comment" + + column.comment = "other comment" + connection.execute(SetColumnComment(column)) + + columns = db_engine.dialect.get_columns( + connection=connection, table_name=table_name + ) + + assert columns[0].get("comment") == "other comment" + + connection.execute(DropColumnComment(column)) + + columns = db_engine.dialect.get_columns( + connection=connection, table_name=table_name + ) + + assert columns[0].get("comment") == "" + + metadata_obj.drop_all(db_engine) + + def test_bulk_insert_with_core(db_engine, metadata_obj, session): import random diff --git a/src/databricks/sqlalchemy/test_local/test_ddl.py b/src/databricks/sqlalchemy/test_local/test_ddl.py new file mode 100644 index 00000000..eb8e7083 --- /dev/null +++ b/src/databricks/sqlalchemy/test_local/test_ddl.py @@ -0,0 +1,47 @@ +import pytest +from sqlalchemy import Column, MetaData, String, Table, create_engine +from sqlalchemy.schema import CreateTable, DropColumnComment, SetColumnComment + + +class TestTableCommentDDL: + engine = create_engine( + "databricks://token:****@****?http_path=****&catalog=****&schema=****" + ) + + def compile(self, stmt): + return str(stmt.compile(bind=self.engine)) + + @pytest.fixture + def metadata(self) -> MetaData: + """Assemble a metadata object with one table containing one column.""" + metadata = MetaData() + + column = Column("foo", String, comment="bar") + table = Table("foobar", metadata, column) + + return metadata + + @pytest.fixture + def table(self, metadata) -> Table: + return metadata.tables.get("foobar") + + @pytest.fixture + def column(self, table) -> Column: + return table.columns[0] + + def test_create_table_with_column_comment(self, table): + stmt = CreateTable(table) + output = self.compile(stmt) + + # output is a CREATE TABLE statement + assert "foo STRING COMMENT 'bar'" in output + + def test_alter_table_add_column_comment(self, column): + stmt = SetColumnComment(column) + output = self.compile(stmt) + assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT 'bar'" + + def test_alter_table_drop_column_comment(self, column): + stmt = DropColumnComment(column) + output = self.compile(stmt) + assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT ''"