From 3db45630f97247ae68f67a68af50f58466f3b219 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Thu, 25 Jan 2024 19:37:24 -0500 Subject: [PATCH 1/7] Encode sqlalchemy version in the user agent and reorganize tests to reuse code where available Signed-off-by: Jesse Whitehouse --- src/databricks/sqlalchemy/base.py | 7 ++-- .../sqlalchemy/test_local/e2e/test_basic.py | 37 +++++++++++++------ 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 40af61fe..93e4595c 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -411,14 +411,15 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams): ua = cparams.get("_user_agent_entry", "") def add_sqla_tag_if_not_present(val: str): + tag = f"sqlalchemy=={sqlalchemy.__version__}" if not val: - output = "sqlalchemy" + output = tag - if val and "sqlalchemy" in val: + if val and tag in val: output = val else: - output = f"sqlalchemy + {val}" + output = f"{tag} + {val}" return output diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py index ec54c282..79405530 100644 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py @@ -449,26 +449,39 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine): conn.execute(text("DROP TABLE test_has_table;")) -def test_user_agent_adjustment(db_engine): - # If .connect() is called multiple times on an engine, don't keep pre-pending the user agent - # https://github.com/databricks/databricks-sql-python/issues/192 - c1 = db_engine.connect() - c2 = db_engine.connect() +class TestSQLAlchemyUserAgent: - def get_conn_user_agent(conn): + def get_conn_user_agent(self, conn): return conn.connection.dbapi_connection.thrift_backend._transport._headers.get( "User-Agent" ) + + def test_user_agent_adjustment(self, db_engine): + # If .connect() is called multiple times on an engine, don't keep pre-pending the user agent + # https://github.com/databricks/databricks-sql-python/issues/192 + c1 = db_engine.connect() + c2 = db_engine.connect() - ua1 = get_conn_user_agent(c1) - ua2 = get_conn_user_agent(c2) - same_ua = ua1 == ua2 - c1.close() - c2.close() + ua1 = self.get_conn_user_agent(c1) + ua2 = self.get_conn_user_agent(c2) + same_ua = ua1 == ua2 - assert same_ua, f"User agents didn't match \n {ua1} \n {ua2}" + c1.close() + c2.close() + assert same_ua, f"User agents didn't match \n {ua1} \n {ua2}" + + def test_sqlalchemy_user_agent_includes_version(self, db_engine): + """So that we know when we can safely deprecate support for sqlalchemy 1.x + """ + + import sqlalchemy + version_str = sqlalchemy.__version__ + c = db_engine.connect() + ua = self.get_conn_user_agent(c) + + assert version_str in ua @pytest.fixture def sample_table(metadata_obj: MetaData, db_engine: Engine): From a6c78d12e313d5343ea90c7f62834e574dd033b5 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Thu, 25 Jan 2024 19:39:54 -0500 Subject: [PATCH 2/7] Move connect args test into TestUserAgent Signed-off-by: Jesse Whitehouse --- .../sqlalchemy/test_local/e2e/test_basic.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py index 79405530..69b9f326 100644 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py @@ -120,20 +120,6 @@ def test_can_connect(db_engine): assert len(result) == 1 -def test_connect_args(db_engine): - """Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI - - This will most commonly happen when partners supply a user agent entry - """ - - conn = db_engine.connect() - connection_headers = conn.connection.thrift_backend._transport._headers - user_agent = connection_headers["User-Agent"] - - expected = f"(sqlalchemy + {USER_AGENT_TOKEN})" - assert expected in user_agent - - @pytest.mark.skipif(sqlalchemy_1_3(), reason="Pandas requires SQLAlchemy >= 1.4") @pytest.mark.skip( reason="DBR is currently limited to 256 parameters per call to .execute(). Test cannot pass." @@ -483,6 +469,19 @@ def test_sqlalchemy_user_agent_includes_version(self, db_engine): assert version_str in ua + def test_user_supplied_string(self, db_engine): + """Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI + + This will most commonly happen when partners supply a user agent entry + """ + + conn = db_engine.connect() + connection_headers = conn.connection.thrift_backend._transport._headers + user_agent = connection_headers["User-Agent"] + + assert USER_AGENT_TOKEN in user_agent + + @pytest.fixture def sample_table(metadata_obj: MetaData, db_engine: Engine): """This fixture creates a sample table and cleans it up after the test is complete.""" From 36a8aca813da6f8826f1b2099a9bf8c2742bd584 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Thu, 25 Jan 2024 19:40:30 -0500 Subject: [PATCH 3/7] Format the file with black Signed-off-by: Jesse Whitehouse --- src/databricks/sqlalchemy/test_local/e2e/test_basic.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py index 69b9f326..ee62528e 100644 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py @@ -435,20 +435,18 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine): conn.execute(text("DROP TABLE test_has_table;")) -class TestSQLAlchemyUserAgent: - +class TestUserAgent: def get_conn_user_agent(self, conn): return conn.connection.dbapi_connection.thrift_backend._transport._headers.get( "User-Agent" ) - + def test_user_agent_adjustment(self, db_engine): # If .connect() is called multiple times on an engine, don't keep pre-pending the user agent # https://github.com/databricks/databricks-sql-python/issues/192 c1 = db_engine.connect() c2 = db_engine.connect() - ua1 = self.get_conn_user_agent(c1) ua2 = self.get_conn_user_agent(c2) same_ua = ua1 == ua2 @@ -459,10 +457,10 @@ def test_user_agent_adjustment(self, db_engine): assert same_ua, f"User agents didn't match \n {ua1} \n {ua2}" def test_sqlalchemy_user_agent_includes_version(self, db_engine): - """So that we know when we can safely deprecate support for sqlalchemy 1.x - """ + """So that we know when we can safely deprecate support for sqlalchemy 1.x""" import sqlalchemy + version_str = sqlalchemy.__version__ c = db_engine.connect() ua = self.get_conn_user_agent(c) From 8905b6436219c4e50ecfb87a073e0600b6ed3bd8 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Thu, 25 Jan 2024 20:10:06 -0500 Subject: [PATCH 4/7] Use preferred separator in user agent Signed-off-by: Jesse Whitehouse --- src/databricks/sqlalchemy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 93e4595c..c984efac 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -411,7 +411,7 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams): ua = cparams.get("_user_agent_entry", "") def add_sqla_tag_if_not_present(val: str): - tag = f"sqlalchemy=={sqlalchemy.__version__}" + tag = f"sqlalchemy/{sqlalchemy.__version__}" if not val: output = tag From dac30369607a51424e9ff4b04f45dcd7d411259c Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Fri, 26 Jan 2024 11:00:05 -0500 Subject: [PATCH 5/7] Add further unit tests for sqlalchemy tag insertion plus extract format function outside of its existing declaration so I can unit test it Signed-off-by: Jesse Whitehouse --- src/databricks/sqlalchemy/base.py | 57 ++++++++++--------- .../sqlalchemy/test_local/e2e/test_basic.py | 47 +++++++++++++-- 2 files changed, 73 insertions(+), 31 deletions(-) diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index c984efac..54cf4038 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -1,4 +1,17 @@ -from typing import Any, List, Optional, Dict, Union +import re +from typing import Any, Dict, List, Optional, Union + +import sqlalchemy +from sqlalchemy import DDL, event +from sqlalchemy.engine import Connection, Engine, default, reflection +from sqlalchemy.engine.interfaces import ( + ReflectedColumn, + ReflectedForeignKeyConstraint, + ReflectedPrimaryKeyConstraint, + ReflectedTableComment, +) +from sqlalchemy.engine.reflection import ReflectionDefaults +from sqlalchemy.exc import DatabaseError, SQLAlchemyError import databricks.sqlalchemy._ddl as dialect_ddl_impl import databricks.sqlalchemy._types as dialect_type_impl @@ -8,24 +21,12 @@ _match_table_not_found_string, build_fk_dict, build_pk_dict, + get_comment_from_dte_output, get_fk_strings_from_dte_output, get_pk_strings_from_dte_output, - get_comment_from_dte_output, parse_column_info_from_tgetcolumnsresponse, ) -import sqlalchemy -from sqlalchemy import DDL, event -from sqlalchemy.engine import Connection, Engine, default, reflection -from sqlalchemy.engine.interfaces import ( - ReflectedForeignKeyConstraint, - ReflectedPrimaryKeyConstraint, - ReflectedColumn, - ReflectedTableComment, -) -from sqlalchemy.engine.reflection import ReflectionDefaults -from sqlalchemy.exc import DatabaseError, SQLAlchemyError - try: import alembic except ImportError: @@ -400,6 +401,21 @@ def get_table_comment( return ReflectionDefaults.table_comment() +SQLALCHEMY_TAG = f"sqlalchemy/{sqlalchemy.__version__}" +sqlalchemy_version_tag_pat = r"sqlalchemy/(\d+\.\d+\.\d+)" + + +def add_sqla_tag_if_not_present(val: str): + if val is None or val == "": + output = SQLALCHEMY_TAG + elif re.search(sqlalchemy_version_tag_pat, val): + output = val + else: + output = f"{SQLALCHEMY_TAG} + {val}" + + return output + + @event.listens_for(Engine, "do_connect") def receive_do_connect(dialect, conn_rec, cargs, cparams): """Helpful for DS on traffic from clients using SQLAlchemy in particular""" @@ -410,19 +426,6 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams): ua = cparams.get("_user_agent_entry", "") - def add_sqla_tag_if_not_present(val: str): - tag = f"sqlalchemy/{sqlalchemy.__version__}" - if not val: - output = tag - - if val and tag in val: - output = val - - else: - output = f"{tag} + {val}" - - return output - cparams["_user_agent_entry"] = add_sqla_tag_if_not_present(ua) if sqlalchemy.__version__.startswith("1.3"): diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py index ee62528e..a813a6d7 100644 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py @@ -1,10 +1,12 @@ import datetime import decimal import os -from typing import Tuple, Union, List +import re +from typing import List, Tuple, Union from unittest import skipIf import pytest +import sqlalchemy from sqlalchemy import ( Column, MetaData, @@ -19,7 +21,13 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.schema import DropColumnComment, SetColumnComment -from sqlalchemy.types import BOOLEAN, DECIMAL, Date, DateTime, Integer, String +from sqlalchemy.types import BOOLEAN, DECIMAL, Date, Integer, String + +from databricks.sqlalchemy.base import ( + SQLALCHEMY_TAG, + add_sqla_tag_if_not_present, + sqlalchemy_version_tag_pat, +) try: from sqlalchemy.orm import declarative_base @@ -436,6 +444,13 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine): class TestUserAgent: + @pytest.fixture(scope="class") + def expected_sqlalchemy_tag(self): + import sqlalchemy + + user_agent_tag = f"sqlalchemy/{sqlalchemy.__version__}" + return user_agent_tag + def get_conn_user_agent(self, conn): return conn.connection.dbapi_connection.thrift_backend._transport._headers.get( "User-Agent" @@ -459,8 +474,6 @@ def test_user_agent_adjustment(self, db_engine): def test_sqlalchemy_user_agent_includes_version(self, db_engine): """So that we know when we can safely deprecate support for sqlalchemy 1.x""" - import sqlalchemy - version_str = sqlalchemy.__version__ c = db_engine.connect() ua = self.get_conn_user_agent(c) @@ -479,6 +492,32 @@ def test_user_supplied_string(self, db_engine): assert USER_AGENT_TOKEN in user_agent + @pytest.mark.parametrize( + "input, expected", + ( + (None, "{}"), + ("", "{}"), + ("sqlalchemy connection", "{} + sqlalchemy connection"), + ( + "reusable dialect compliance tests", + "{} + reusable dialect compliance tests", + ), + ), + ) + def test_user_agent_insertion_behavior( + self, input: Union[str, None], expected: str, expected_sqlalchemy_tag: str + ): + assert add_sqla_tag_if_not_present(input) == expected.format( + expected_sqlalchemy_tag + ) + + @pytest.mark.parametrize( + "input", + ("sqlalchemy/1.4.0", "sqlalchemy/1.3.0", "sqlalchemy/2.0.0", SQLALCHEMY_TAG), + ) + def test_sqlalchemy_tag_regexes_properly(self, input): + assert re.search(sqlalchemy_version_tag_pat, input) + @pytest.fixture def sample_table(metadata_obj: MetaData, db_engine: Engine): From 200490977858e4d4b6b9588a7780b116e5ca1ec8 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Fri, 26 Jan 2024 11:47:18 -0500 Subject: [PATCH 6/7] For some reason the user-provided user agent is no longer appearing in query history. Need to investigate this before merging. Signed-off-by: Jesse Whitehouse From 9466bcfa3f9d9c1dd08fd31cd8679d123ca852d0 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 30 May 2024 18:26:37 +0300 Subject: [PATCH 7/7] Fix mypy errors Signed-off-by: Levko Kravets --- src/databricks/sqlalchemy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 31ec0f89..b5643f1d 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -406,7 +406,7 @@ def get_table_comment( sqlalchemy_version_tag_pat = r"sqlalchemy/(\d+\.\d+\.\d+)" -def add_sqla_tag_if_not_present(val: str): +def add_sqla_tag_if_not_present(val: Optional[str] = None): if val is None or val == "": output = SQLALCHEMY_TAG elif re.search(sqlalchemy_version_tag_pat, val):