6
6
from sqlalchemy .types import SMALLINT , Integer , BOOLEAN , String , DECIMAL , Date
7
7
from sqlalchemy .engine import Engine
8
8
9
+ from typing import Tuple
10
+
9
11
try :
10
12
from sqlalchemy .orm import declarative_base
11
13
except ImportError :
@@ -30,54 +32,42 @@ def version_agnostic_select(object_to_select, *args, **kwargs):
30
32
else :
31
33
return select (object_to_select , * args , ** kwargs )
32
34
33
-
34
-
35
- @pytest .fixture
36
- def db_engine () -> Engine :
37
-
35
+ def version_agnostic_connect_arguments (catalog = None , schema = None ) -> Tuple [str , dict ]:
36
+
38
37
HOST = os .environ .get ("host" )
39
38
HTTP_PATH = os .environ .get ("http_path" )
40
39
ACCESS_TOKEN = os .environ .get ("access_token" )
41
- CATALOG = os .environ .get ("catalog" )
42
- SCHEMA = os .environ .get ("schema" )
40
+ CATALOG = catalog or os .environ .get ("catalog" )
41
+ SCHEMA = schema or os .environ .get ("schema" )
43
42
44
- connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
45
-
46
- connect_args = {
47
- ** connect_args ,
48
- "http_path" : HTTP_PATH ,
49
- "server_hostname" : HOST ,
50
- "catalog" : CATALOG ,
51
- "schema" : SCHEMA
52
- }
53
-
54
- engine = create_engine (
55
- f"databricks://token:{ ACCESS_TOKEN } @{ HOST } " ,
56
- connect_args = connect_args ,
57
- )
58
- return engine
43
+ ua_connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
44
+
45
+ if sqlalchemy_1_3 ():
46
+ conn_string = f"databricks://token:{ ACCESS_TOKEN } @{ HOST } "
47
+ connect_args = {** ua_connect_args ,
48
+ "http_path" : HTTP_PATH ,
49
+ "server_hostname" : HOST ,
50
+ "catalog" : CATALOG ,
51
+ "schema" : SCHEMA
52
+ }
53
+
54
+ return conn_string , connect_args
55
+ else :
56
+ return f"databricks://token:{ ACCESS_TOKEN } @{ HOST } ?http_path={ HTTP_PATH } &catalog={ CATALOG } &schema={ SCHEMA } " , ua_connect_args
57
+
58
+
59
+
60
+
61
+ @pytest .fixture
62
+ def db_engine () -> Engine :
63
+ conn_string , connect_args = version_agnostic_connect_arguments ()
64
+ return create_engine (conn_string , connect_args = connect_args )
59
65
60
66
@pytest .fixture
61
67
def samples_engine () -> Engine :
62
- HOST = os .environ .get ("host" )
63
- HTTP_PATH = os .environ .get ("http_path" )
64
- ACCESS_TOKEN = os .environ .get ("access_token" )
65
- CATALOG = "samples"
66
68
67
- connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
68
-
69
- connect_args = {
70
- ** connect_args ,
71
- "http_path" : HTTP_PATH ,
72
- "server_hostname" : HOST ,
73
- "catalog" : CATALOG ,
74
- }
75
-
76
- engine = create_engine (
77
- f"databricks://token:{ ACCESS_TOKEN } @{ HOST } " ,
78
- connect_args = connect_args ,
79
- )
80
- return engine
69
+ conn_string , connect_args = version_agnostic_connect_arguments (catalog = "samples" , schema = "nyctaxi" )
70
+ return create_engine (conn_string , connect_args = connect_args )
81
71
82
72
83
73
@pytest .fixture ()
0 commit comments