# Tag advertised in the user agent: "sqlalchemy/" followed by the installed version.
SQLALCHEMY_TAG = f"sqlalchemy/{sqlalchemy.__version__}"
# Recognises an existing tag by its leading MAJOR.MINOR.PATCH digits. It is used
# with re.search, so surrounding text and version suffixes do not prevent a match.
sqlalchemy_version_tag_pat = r"sqlalchemy/(\d+\.\d+\.\d+)"


def add_sqla_tag_if_not_present(val: Optional[str] = None):
    """Return a user-agent entry that carries a versioned SQLAlchemy tag.

    Args:
        val: The user-agent entry supplied by the caller, if any.

    Returns:
        ``SQLALCHEMY_TAG`` alone when ``val`` is ``None`` or the empty string;
        ``val`` unchanged when it already contains text matching
        ``sqlalchemy_version_tag_pat``; otherwise ``val`` prefixed with
        ``SQLALCHEMY_TAG`` and the separator ``" + "``.
    """
    # Nothing supplied by the caller: the tag is the whole entry.
    if val in (None, ""):
        return SQLALCHEMY_TAG

    # An entry that is already tagged is passed through untouched, so feeding
    # this function its own output never stacks a second tag on top.
    already_tagged = re.search(sqlalchemy_version_tag_pat, val) is not None
    return val if already_tagged else f"{SQLALCHEMY_TAG} + {val}"
@event.listens_for(Engine, "do_connect") def receive_do_connect(dialect, conn_rec, cargs, cparams): """Helpful for DS on traffic from clients using SQLAlchemy in particular""" @@ -411,18 +427,6 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams): ua = cparams.get("_user_agent_entry", "") - def add_sqla_tag_if_not_present(val: str): - if not val: - output = "sqlalchemy" - - if val and "sqlalchemy" in val: - output = val - - else: - output = f"sqlalchemy + {val}" - - return output - cparams["_user_agent_entry"] = add_sqla_tag_if_not_present(ua) if sqlalchemy.__version__.startswith("1.3"): diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py index ce0b5d89..b571d363 100644 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py @@ -1,9 +1,11 @@ import datetime import decimal +import re from typing import Tuple, Union, List from unittest import skipIf import pytest +import sqlalchemy from sqlalchemy import ( Column, MetaData, @@ -20,6 +22,12 @@ from sqlalchemy.schema import DropColumnComment, SetColumnComment from sqlalchemy.types import BOOLEAN, DECIMAL, Date, Integer, String +from databricks.sqlalchemy.base import ( + SQLALCHEMY_TAG, + add_sqla_tag_if_not_present, + sqlalchemy_version_tag_pat, +) + try: from sqlalchemy.orm import declarative_base except ImportError: @@ -120,20 +128,6 @@ def test_can_connect(db_engine): assert len(result) == 1 -def test_connect_args(db_engine): - """Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI - - This will most commonly happen when partners supply a user agent entry - """ - - conn = db_engine.connect() - connection_headers = conn.connection.thrift_backend._transport._headers - user_agent = connection_headers["User-Agent"] - - expected = f"(sqlalchemy + {USER_AGENT_TOKEN})" - assert expected in user_agent - - @pytest.mark.skipif(sqlalchemy_1_3(), 
class TestUserAgent:
    """Checks on the User-Agent header sent by connections from ``db_engine``.

    Every test that opens a connection does so in a ``with`` block, so the
    connection is closed even when reading the header or an assertion fails.
    """

    @pytest.fixture(scope="class")
    def expected_sqlalchemy_tag(self):
        """Return the expected tag, ``sqlalchemy/<installed version>``.

        Built independently of the imported ``SQLALCHEMY_TAG`` constant.
        """
        return f"sqlalchemy/{sqlalchemy.__version__}"

    def get_conn_user_agent(self, conn):
        """Return the ``User-Agent`` header of ``conn``'s transport, or ``None`` if unset."""
        return conn.connection.dbapi_connection.thrift_backend._transport._headers.get(
            "User-Agent"
        )

    def test_user_agent_adjustment(self, db_engine):
        # If .connect() is called multiple times on an engine, don't keep pre-pending the user agent
        # https://github.com/databricks/databricks-sql-python/issues/192
        with db_engine.connect() as c1, db_engine.connect() as c2:
            ua1 = self.get_conn_user_agent(c1)
            ua2 = self.get_conn_user_agent(c2)

        assert ua1 == ua2, f"User agents didn't match \n {ua1} \n {ua2}"

    def test_sqlalchemy_user_agent_includes_version(self, db_engine):
        """So that we know when we can safely deprecate support for sqlalchemy 1.x"""

        version_str = sqlalchemy.__version__
        with db_engine.connect() as conn:
            ua = self.get_conn_user_agent(conn)

        assert version_str in ua

    def test_user_supplied_string(self, db_engine):
        """Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI

        This will most commonly happen when partners supply a user agent entry
        """

        with db_engine.connect() as conn:
            connection_headers = conn.connection.thrift_backend._transport._headers
            user_agent = connection_headers["User-Agent"]

        assert USER_AGENT_TOKEN in user_agent

    @pytest.mark.parametrize(
        "user_agent_entry, expected",
        (
            (None, "{}"),
            ("", "{}"),
            ("sqlalchemy connection", "{} + sqlalchemy connection"),
            (
                "reusable dialect compliance tests",
                "{} + reusable dialect compliance tests",
            ),
        ),
    )
    def test_user_agent_insertion_behavior(
        self,
        user_agent_entry: Union[str, None],
        expected: str,
        expected_sqlalchemy_tag: str,
    ):
        """Entries without a versioned tag get the tag; ``{}`` in ``expected`` stands for it."""
        assert add_sqla_tag_if_not_present(user_agent_entry) == expected.format(
            expected_sqlalchemy_tag
        )

    @pytest.mark.parametrize(
        "tag",
        ("sqlalchemy/1.4.0", "sqlalchemy/1.3.0", "sqlalchemy/2.0.0", SQLALCHEMY_TAG),
    )
    def test_sqlalchemy_tag_regexes_properly(self, tag):
        """The tag pattern matches fixed sample versions and the live ``SQLALCHEMY_TAG``."""
        assert re.search(sqlalchemy_version_tag_pat, tag)