Skip to content

[Issue 171] Improve sqlalchemy dialect backward compatibility with 1.3.24 #173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 11, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 2.7.x (Unreleased)

- Add support for Cloud Fetch
- Fix: Revised SQLAlchemy dialect and examples for compatibility with SQLAlchemy==1.3.x

## 2.7.0 (2023-06-26)

Expand Down
37 changes: 29 additions & 8 deletions examples/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@
"""

import os
from sqlalchemy.orm import declarative_base, Session
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy import Column, String, Integer, BOOLEAN, create_engine, select

try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base

host = os.getenv("DATABRICKS_SERVER_HOSTNAME")
http_path = os.getenv("DATABRICKS_HTTP_PATH")
access_token = os.getenv("DATABRICKS_TOKEN")
Expand All @@ -59,10 +65,20 @@
"_user_agent_entry": "PySQL Example Script",
}

engine = create_engine(
f"databricks://token:{access_token}@{host}?http_path={http_path}&catalog={catalog}&schema={schema}",
connect_args=extra_connect_args,
)
if sqlalchemy.__version__.startswith("1.3"):
# SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string
# Pass these in as connect_args instead

conn_string = f"databricks://token:{access_token}@{host}"
connect_args = dict(catalog=catalog, schema=schema, http_path=http_path)
all_connect_args = {**extra_connect_args, **connect_args}
engine = create_engine(conn_string, connect_args=all_connect_args)
else:
engine = create_engine(
f"databricks://token:{access_token}@{host}?http_path={http_path}&catalog={catalog}&schema={schema}",
connect_args=extra_connect_args,
)

session = Session(bind=engine)
base = declarative_base(bind=engine)

Expand All @@ -73,7 +89,7 @@ class SampleObject(base):

name = Column(String(255), primary_key=True)
episodes = Column(Integer)
some_bool = Column(BOOLEAN)
some_bool = Column(BOOLEAN(create_constraint=False))


base.metadata.create_all()
Expand All @@ -86,9 +102,14 @@ class SampleObject(base):

session.commit()

stmt = select(SampleObject).where(SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"]))
# SQLAlchemy 1.3 needs select() to receive a list and has no Session.scalars(), so it uses Session.execute() instead
if sqlalchemy.__version__.startswith("1.3"):
stmt = select([SampleObject]).where(SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"]))
output = [i for i in session.execute(stmt)]
else:
stmt = select(SampleObject).where(SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"]))
output = [i for i in session.scalars(stmt)]

output = [i for i in session.scalars(stmt)]
assert len(output) == 2

base.metadata.drop_all()
38 changes: 28 additions & 10 deletions src/databricks/sqlalchemy/dialect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import decimal, re, datetime
from dateutil.parser import parse

import sqlalchemy
from sqlalchemy import types, processors, event
from sqlalchemy.engine import default, Engine
from sqlalchemy.exc import DatabaseError
from sqlalchemy.exc import DatabaseError, SQLAlchemyError
from sqlalchemy.engine import reflection

from databricks import sql
Expand Down Expand Up @@ -153,9 +154,7 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
"date": DatabricksDate,
}

with self.get_driver_connection(
connection
)._dbapi_connection.dbapi_connection.cursor() as cur:
with self.get_connection_cursor(connection) as cur:
resp = cur.columns(
catalog_name=self.catalog,
schema_name=schema or self.schema,
Expand Down Expand Up @@ -244,9 +243,7 @@ def get_indexes(self, connection, table_name, schema=None, **kw):

def get_table_names(self, connection, schema=None, **kwargs):
TABLE_NAME = 1
with self.get_driver_connection(
connection
)._dbapi_connection.dbapi_connection.cursor() as cur:
with self.get_connection_cursor(connection) as cur:
sql_str = "SHOW TABLES FROM {}".format(
".".join([self.catalog, schema or self.schema])
)
Expand All @@ -257,9 +254,7 @@ def get_table_names(self, connection, schema=None, **kwargs):

def get_view_names(self, connection, schema=None, **kwargs):
VIEW_NAME = 1
with self.get_driver_connection(
connection
)._dbapi_connection.dbapi_connection.cursor() as cur:
with self.get_connection_cursor(connection) as cur:
sql_str = "SHOW VIEWS FROM {}".format(
".".join([self.catalog, schema or self.schema])
)
Expand Down Expand Up @@ -292,6 +287,19 @@ def has_table(self, connection, table_name, schema=None, **kwargs) -> bool:
else:
raise e

def get_connection_cursor(self, connection):
    """Return a DBAPI cursor for ``connection`` regardless of SQLAlchemy version.

    Added for backwards compatibility with 1.3.x. The connection object is
    probed, in this order, for ``_dbapi_connection``, ``raw_connection`` and
    ``connection``; the first attribute found supplies the underlying driver
    connection whose ``cursor()`` is returned.

    Raises:
        SQLAlchemyError: if none of the three attributes is present.
    """
    if hasattr(connection, "_dbapi_connection"):
        driver_connection = connection._dbapi_connection.dbapi_connection
    elif hasattr(connection, "raw_connection"):
        driver_connection = connection.raw_connection()
    elif hasattr(connection, "connection"):
        driver_connection = connection.connection
    else:
        raise SQLAlchemyError(
            "Databricks dialect can't obtain a cursor context manager from the dbapi"
        )

    return driver_connection.cursor()

@reflection.cache
def get_schema_names(self, connection, **kw):
# Equivalent to SHOW DATABASES
Expand All @@ -314,3 +322,13 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams):
new_user_agent = "sqlalchemy"

cparams["_user_agent_entry"] = new_user_agent

if sqlalchemy.__version__.startswith("1.3"):
# SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string
# These should be passed in as connect_args when building the Engine

if "schema" in cparams:
dialect.schema = cparams["schema"]

if "catalog" in cparams:
dialect.catalog = cparams["catalog"]
128 changes: 108 additions & 20 deletions tests/e2e/sqlalchemy/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,81 @@
import pytest
from unittest import skipIf
from sqlalchemy import create_engine, select, insert, Column, MetaData, Table
from sqlalchemy.orm import declarative_base, Session
from sqlalchemy.orm import Session
from sqlalchemy.types import SMALLINT, Integer, BOOLEAN, String, DECIMAL, Date
from sqlalchemy.engine import Engine

from typing import Tuple

try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base


USER_AGENT_TOKEN = "PySQL e2e Tests"


@pytest.fixture
def db_engine():
def sqlalchemy_1_3():
    """Return True when the installed SQLAlchemy reports a version starting with "1.3"."""
    # Imported inside the function, as before, so the check reads the version
    # at call time.
    from sqlalchemy import __version__ as installed_version

    return installed_version.startswith("1.3")


def version_agnostic_select(object_to_select, *args, **kwargs):
    """
    Build a select() statement that works on SQLAlchemy 1.3.x and on newer releases.

    SQLAlchemy==1.3.x requires arguments to select() to be a Python list, so
    the selectable is wrapped in one there and passed through unchanged
    otherwise. Extra positional and keyword arguments are forwarded as-is.

    https://docs.sqlalchemy.org/en/20/changelog/migration_14.html#orm-query-is-internally-unified-with-select-update-delete-2-0-style-execution-available
    """

    selectable = [object_to_select] if sqlalchemy_1_3() else object_to_select
    return select(selectable, *args, **kwargs)


def version_agnostic_connect_arguments(catalog=None, schema=None) -> Tuple[str, dict]:
    """Return the ``(connection string, connect_args)`` pair to hand to create_engine.

    Host, HTTP path and access token are read from the ``host``, ``http_path``
    and ``access_token`` environment variables. ``catalog`` and ``schema`` fall
    back to the environment variables of the same name when the argument is
    not supplied (or is falsy).

    On SQLAlchemy 1.3.x the URL carries only the token and host, and the
    remaining settings travel in ``connect_args``. On other versions they are
    placed in the URL query string and ``connect_args`` holds just the
    user-agent entry.
    """

    environ = os.environ
    hostname = environ.get("host")
    endpoint_path = environ.get("http_path")
    token = environ.get("access_token")
    resolved_catalog = catalog or environ.get("catalog")
    resolved_schema = schema or environ.get("schema")

    user_agent_args = {"_user_agent_entry": USER_AGENT_TOKEN}

    if not sqlalchemy_1_3():
        full_url = f"databricks://token:{token}@{hostname}?http_path={endpoint_path}&catalog={resolved_catalog}&schema={resolved_schema}"
        return (full_url, user_agent_args)

    # 1.3.x path: keep the URL minimal and pass everything else explicitly.
    legacy_args = dict(
        user_agent_args,
        http_path=endpoint_path,
        server_hostname=hostname,
        catalog=resolved_catalog,
        schema=resolved_schema,
    )
    return f"databricks://token:{token}@{hostname}", legacy_args


@pytest.fixture
def db_engine() -> Engine:
    """Engine fixture built from the default (environment-driven) connection settings."""
    engine_url, engine_connect_args = version_agnostic_connect_arguments()
    engine = create_engine(engine_url, connect_args=engine_connect_args)
    return engine

connect_args = {"_user_agent_entry": USER_AGENT_TOKEN}

engine = create_engine(
f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}",
connect_args=connect_args,
@pytest.fixture
def samples_engine() -> Engine:

conn_string, connect_args = version_agnostic_connect_arguments(
catalog="samples", schema="nyctaxi"
)
return engine
return create_engine(conn_string, connect_args=connect_args)


@pytest.fixture()
Expand Down Expand Up @@ -62,6 +114,7 @@ def test_connect_args(db_engine):
assert expected in user_agent


@pytest.mark.skipif(sqlalchemy_1_3(), reason="Pandas requires SQLAlchemy >= 1.4")
def test_pandas_upload(db_engine, metadata_obj):

import pandas as pd
Expand All @@ -86,7 +139,7 @@ def test_pandas_upload(db_engine, metadata_obj):
db_engine.execute("DROP TABLE mock_data")


def test_create_table_not_null(db_engine, metadata_obj):
def test_create_table_not_null(db_engine, metadata_obj: MetaData):

table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s"))

Expand All @@ -95,7 +148,7 @@ def test_create_table_not_null(db_engine, metadata_obj):
metadata_obj,
Column("name", String(255)),
Column("episodes", Integer),
Column("some_bool", BOOLEAN, nullable=False),
Column("some_bool", BOOLEAN(create_constraint=False), nullable=False),
)

metadata_obj.create_all()
Expand Down Expand Up @@ -135,7 +188,7 @@ def test_bulk_insert_with_core(db_engine, metadata_obj, session):
metadata_obj.create_all()
db_engine.execute(insert(SampleTable).values(rows))

rows = db_engine.execute(select(SampleTable)).fetchall()
rows = db_engine.execute(version_agnostic_select(SampleTable)).fetchall()

assert len(rows) == num_to_insert

Expand All @@ -148,7 +201,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
metadata_obj,
Column("name", String(255)),
Column("episodes", Integer),
Column("some_bool", BOOLEAN),
Column("some_bool", BOOLEAN(create_constraint=False)),
Column("dollars", DECIMAL(10, 2)),
)

Expand All @@ -161,7 +214,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
with db_engine.connect() as conn:
conn.execute(insert_stmt)

select_stmt = select(SampleTable)
select_stmt = version_agnostic_select(SampleTable)
resp = db_engine.execute(select_stmt)

result = resp.fetchall()
Expand All @@ -187,7 +240,7 @@ class SampleObject(base):

name = Column(String(255), primary_key=True)
episodes = Column(Integer)
some_bool = Column(BOOLEAN)
some_bool = Column(BOOLEAN(create_constraint=False))

base.metadata.create_all()

Expand All @@ -197,11 +250,15 @@ class SampleObject(base):
session.add(sample_object_2)
session.commit()

stmt = select(SampleObject).where(
stmt = version_agnostic_select(SampleObject).where(
SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"])
)

output = [i for i in session.scalars(stmt)]
if sqlalchemy_1_3():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops. This change belongs in a separate commit. Basically: version 1.3 doesn't have the .scalars method whereas 1.4 does.

output = [i for i in session.execute(stmt)]
else:
output = [i for i in session.scalars(stmt)]

assert len(output) == 2

base.metadata.drop_all()
Expand All @@ -215,7 +272,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
metadata_obj,
Column("string_example", String(255)),
Column("integer_example", Integer),
Column("boolean_example", BOOLEAN),
Column("boolean_example", BOOLEAN(create_constraint=False)),
Column("decimal_example", DECIMAL(10, 2)),
Column("date_example", Date),
)
Expand All @@ -239,7 +296,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
with db_engine.connect() as conn:
conn.execute(insert_stmt)

select_stmt = select(SampleTable)
select_stmt = version_agnostic_select(SampleTable)
resp = db_engine.execute(select_stmt)

result = resp.fetchall()
Expand All @@ -252,3 +309,34 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
assert this_row["date_example"] == date_example

metadata_obj.drop_all()


def test_inspector_smoke_test(samples_engine: Engine):
    """Smoke-test Inspector reflection against the nyctaxi.trips sample table.

    It does not appear that 3L namespace is supported here, so only the schema
    and table name are passed to the inspector.
    """

    from sqlalchemy.engine.reflection import Inspector

    target_schema = "nyctaxi"
    target_table = "trips"

    try:
        insp = Inspector.from_engine(samples_engine)
    except Exception as e:
        assert False, f"Could not build inspector: {e}"

    # Expect six columns
    found_columns = insp.get_columns(target_table, schema=target_schema)

    # Expect zero views, but the method should return
    found_views = insp.get_view_names(schema=target_schema)

    column_count = len(found_columns)
    view_count = len(found_views)

    assert (
        column_count == 6
    ), "Dialect did not find the expected number of columns in samples.nyctaxi.trips"
    assert view_count == 0, "Views could not be fetched"


def test_get_table_names_smoke_test(samples_engine: Engine):
    """Smoke-test that table names can be listed for the ``nyctaxi`` schema.

    Opens a connection on ``samples_engine`` and calls
    ``samples_engine.table_names`` with it; the test passes when a non-None
    result comes back.
    """

    with samples_engine.connect() as conn:
        _names = samples_engine.table_names(schema="nyctaxi", connection=conn)
        # The ``assert`` keyword is required: without it the line is a bare
        # tuple expression that is evaluated and discarded, so the test could
        # never fail on a None result.
        assert _names is not None, "get_table_names did not succeed"