
Commit 439fe6a

Author: Jesse Whitehouse
Commit message: Merge branch 'sqlalchemy-staging' into table-comment
Signed-off-by: Jesse Whitehouse <[email protected]>
2 parents: ba5ad75 + a7f4773

9 files changed (+156 lines, -9 lines)

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,9 @@
 # Release History

+# 3.1.0 (TBD)
+
+- Fix: `server_hostname` URIs that included `https://` would raise an exception
+
 ## 3.0.1 (2023-12-01)

 - Other: updated docstring comment about default parameterization approach (#287)

src/databricks/__init__.py

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,7 @@
-# https://packaging.python.org/guides/packaging-namespace-packages/#pkgutil-style-namespace-packages
-# This file should only contain the following line. Otherwise other sub-packages databricks.* namespace
-# may not be importable.
+# See: https://packaging.python.org/guides/packaging-namespace-packages/#pkgutil-style-namespace-packages
+#
+# This file must only contain the following line, or other packages in the databricks.* namespace
+# may not be importable. The contents of this file must be byte-for-byte equivalent across all packages.
+# If they are not, parallel package installation may lead to clobbered and invalid files.
+# Also see https://github.com/databricks/databricks-sdk-py/issues/343.
 __path__ = __import__("pkgutil").extend_path(__path__, __name__)
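
The byte-for-byte requirement described in the comment above is what the new tests/unit/test_init_file.py (further down in this diff) pins with a SHA-1 hash. As a quick local check, a minimal sketch, assuming the repository layout shown here (src/databricks/__init__.py):

import hashlib

# Print the SHA-1 of the namespace-package __init__.py so it can be compared
# against the expected hash recorded in tests/unit/test_init_file.py.
with open("src/databricks/__init__.py", "rb") as f:
    print(hashlib.sha1(f.read()).hexdigest())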

src/databricks/sql/thrift_backend.py

Lines changed: 4 additions & 2 deletions
@@ -141,9 +141,11 @@ def __init__(
         if kwargs.get("_connection_uri"):
             uri = kwargs.get("_connection_uri")
         elif server_hostname and http_path:
-            uri = "https://{host}:{port}/{path}".format(
-                host=server_hostname, port=port, path=http_path.lstrip("/")
+            uri = "{host}:{port}/{path}".format(
+                host=server_hostname.rstrip("/"), port=port, path=http_path.lstrip("/")
             )
+            if not uri.startswith("https://"):
+                uri = "https://" + uri
         else:
             raise ValueError("No valid connection settings.")
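
For illustration only, the normalization above can be read as a small standalone helper. This is a sketch that mirrors the logic of the hunk; `normalize_uri` is a hypothetical name, not a function in the connector:

def normalize_uri(server_hostname: str, port: int, http_path: str) -> str:
    # Drop any trailing slash on the host and leading slash on the path,
    # then prepend the scheme only if the hostname did not already include it.
    uri = "{host}:{port}/{path}".format(
        host=server_hostname.rstrip("/"), port=port, path=http_path.lstrip("/")
    )
    if not uri.startswith("https://"):
        uri = "https://" + uri
    return uri


# Both spellings collapse to the same URI, matching the new unit tests below:
assert normalize_uri("hostname", 123, "path_value") == "https://hostname:123/path_value"
assert normalize_uri("https://hostname/", 123, "path_value") == "https://hostname:123/path_value"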

src/databricks/sqlalchemy/_ddl.py

Lines changed: 27 additions & 4 deletions
@@ -45,17 +45,40 @@ def visit_identity_column(self, identity, **kw):
         )
         return text

+    def visit_set_column_comment(self, create, **kw):
+        return "ALTER TABLE %s ALTER COLUMN %s COMMENT %s" % (
+            self.preparer.format_table(create.element.table),
+            self.preparer.format_column(create.element),
+            self.sql_compiler.render_literal_value(
+                create.element.comment, sqltypes.String()
+            ),
+        )
+
+    def visit_drop_column_comment(self, create, **kw):
+        return "ALTER TABLE %s ALTER COLUMN %s COMMENT ''" % (
+            self.preparer.format_table(create.element.table),
+            self.preparer.format_column(create.element),
+        )
+
     def get_column_specification(self, column, **kwargs):
-        """Currently we override this method only to emit a log message if a user attempts to set
-        autoincrement=True on a column. See comments in test_suite.py. We may implement implicit
-        IDENTITY using this feature in the future, similar to the Microsoft SQL Server dialect.
+        """
+        Emit a log message if a user attempts to set autoincrement=True on a column.
+        See comments in test_suite.py. We may implement implicit IDENTITY using this
+        feature in the future, similar to the Microsoft SQL Server dialect.
         """
         if column is column.table._autoincrement_column or column.autoincrement is True:
             logger.warning(
                 "Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead."
             )

-        return super().get_column_specification(column, **kwargs)
+        colspec = super().get_column_specification(column, **kwargs)
+        if column.comment is not None:
+            literal = self.sql_compiler.render_literal_value(
+                column.comment, sqltypes.STRINGTYPE
+            )
+            colspec += " COMMENT " + literal
+
+        return colspec


 class DatabricksStatementCompiler(compiler.SQLCompiler):
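
Because the comment text is passed through render_literal_value rather than concatenated directly, quoting and escaping are left to the compiler. A minimal sketch of how the new visitor is reached through SQLAlchemy's DDL constructs; the engine URL placeholders mirror the unit tests added below, and the rendered SQL in the comment is the expected shape, not output verified against a live warehouse:

from sqlalchemy import Column, MetaData, String, Table, create_engine
from sqlalchemy.schema import SetColumnComment

# Placeholder URL; compiling DDL does not open a connection.
engine = create_engine("databricks://token:****@****?http_path=****&catalog=****&schema=****")

metadata = MetaData()
column = Column("foo", String, comment="it's a comment")
Table("foobar", metadata, column)

# Expected to render with the embedded quote escaped by the compiler, e.g.
#   ALTER TABLE foobar ALTER COLUMN foo COMMENT 'it''s a comment'
print(SetColumnComment(column).compile(bind=engine))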

src/databricks/sqlalchemy/_parse.py

Lines changed: 1 addition & 0 deletions
@@ -354,6 +354,7 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColumn
         "type": final_col_type,
         "nullable": bool(thrift_resp_row.NULLABLE),
         "default": thrift_resp_row.COLUMN_DEF,
+        "comment": thrift_resp_row.REMARKS or None,
     }

     # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
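
With the comment key populated from REMARKS, reflected column comments surface through SQLAlchemy's standard inspection API. A brief sketch, assuming a reachable warehouse and an existing table named foobar (both placeholders):

from sqlalchemy import create_engine, inspect

# Placeholder credentials; reflection requires a live connection.
engine = create_engine("databricks://token:****@****?http_path=****&catalog=****&schema=****")

for col in inspect(engine).get_columns("foobar"):
    # Each column dict now carries the REMARKS value, or None when no comment is set.
    print(col["name"], col.get("comment"))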

src/databricks/sqlalchemy/test_local/e2e/test_basic.py

Lines changed: 36 additions & 0 deletions
@@ -18,6 +18,7 @@
 from sqlalchemy.engine import Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
+from sqlalchemy.schema import DropColumnComment, SetColumnComment
 from sqlalchemy.types import BOOLEAN, DECIMAL, Date, DateTime, Integer, String

 try:
@@ -188,6 +189,41 @@ def test_create_table_not_null(db_engine, metadata_obj: MetaData):
     metadata_obj.drop_all(db_engine)


+def test_column_comment(db_engine, metadata_obj: MetaData):
+    table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s"))
+
+    column = Column("name", String(255), comment="some comment")
+    SampleTable = Table(table_name, metadata_obj, column)
+
+    metadata_obj.create_all(db_engine)
+    connection = db_engine.connect()
+
+    columns = db_engine.dialect.get_columns(
+        connection=connection, table_name=table_name
+    )
+
+    assert columns[0].get("comment") == "some comment"
+
+    column.comment = "other comment"
+    connection.execute(SetColumnComment(column))
+
+    columns = db_engine.dialect.get_columns(
+        connection=connection, table_name=table_name
+    )
+
+    assert columns[0].get("comment") == "other comment"
+
+    connection.execute(DropColumnComment(column))
+
+    columns = db_engine.dialect.get_columns(
+        connection=connection, table_name=table_name
+    )
+
+    assert columns[0].get("comment") == ""
+
+    metadata_obj.drop_all(db_engine)
+
+
 def test_bulk_insert_with_core(db_engine, metadata_obj, session):
     import random

(new test file)

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+import pytest
+from sqlalchemy import Column, MetaData, String, Table, create_engine
+from sqlalchemy.schema import CreateTable, DropColumnComment, SetColumnComment
+
+
+class TestTableCommentDDL:
+    engine = create_engine(
+        "databricks://token:****@****?http_path=****&catalog=****&schema=****"
+    )
+
+    def compile(self, stmt):
+        return str(stmt.compile(bind=self.engine))
+
+    @pytest.fixture
+    def metadata(self) -> MetaData:
+        """Assemble a metadata object with one table containing one column."""
+        metadata = MetaData()
+
+        column = Column("foo", String, comment="bar")
+        table = Table("foobar", metadata, column)
+
+        return metadata
+
+    @pytest.fixture
+    def table(self, metadata) -> Table:
+        return metadata.tables.get("foobar")
+
+    @pytest.fixture
+    def column(self, table) -> Column:
+        return table.columns[0]
+
+    def test_create_table_with_column_comment(self, table):
+        stmt = CreateTable(table)
+        output = self.compile(stmt)
+
+        # output is a CREATE TABLE statement
+        assert "foo STRING COMMENT 'bar'" in output
+
+    def test_alter_table_add_column_comment(self, column):
+        stmt = SetColumnComment(column)
+        output = self.compile(stmt)
+        assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT 'bar'"
+
+    def test_alter_table_drop_column_comment(self, column):
+        stmt = DropColumnComment(column)
+        output = self.compile(stmt)
+        assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT ''"

tests/unit/test_init_file.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import hashlib
+
+
+class TestInitFile:
+    """
+    Micro test to confirm the contents of `databricks/__init__.py` does not change.
+
+    Also see https://github.com/databricks/databricks-sdk-py/issues/343#issuecomment-1866029118.
+    """
+
+    def test_init_file_contents(self):
+        with open("src/databricks/__init__.py") as f:
+            init_file_contents = f.read()
+
+        # This hash is the expected hash of the contents of `src/databricks/__init__.py`.
+        # It must not change, or else parallel package installation may lead to clobbered and invalid files.
+        expected_sha1 = "2772edbf52e517542acf8c039479c4b57b6ca2cd"
+        actual_sha1 = hashlib.sha1(init_file_contents.encode("utf-8")).hexdigest()
+        assert expected_sha1 == actual_sha1

tests/unit/test_thrift_backend.py

Lines changed: 12 additions & 0 deletions
@@ -212,6 +212,18 @@ def test_port_and_host_are_respected(self, t_http_client_class):
         self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"],
                          "https://hostname:123/path_value")

+    @patch("databricks.sql.auth.thrift_http_client.THttpClient")
+    def test_host_with_https_does_not_duplicate(self, t_http_client_class):
+        ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider())
+        self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"],
+                         "https://hostname:123/path_value")
+
+    @patch("databricks.sql.auth.thrift_http_client.THttpClient")
+    def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class):
+        ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider())
+        self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"],
+                         "https://hostname:123/path_value")
+
     @patch("databricks.sql.auth.thrift_http_client.THttpClient")
     def test_socket_timeout_is_propagated(self, t_http_client_class):
         ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=129)
