2
2
import pytest
3
3
from unittest import skipIf
4
4
from sqlalchemy import create_engine , select , insert , Column , MetaData , Table
5
- from sqlalchemy .orm import declarative_base , Session
5
+ from sqlalchemy .orm import Session
6
6
from sqlalchemy .types import SMALLINT , Integer , BOOLEAN , String , DECIMAL , Date
7
+ from sqlalchemy .engine import Engine
8
+
9
+ from typing import Tuple
10
+
11
+ try :
12
+ from sqlalchemy .orm import declarative_base
13
+ except ImportError :
14
+ from sqlalchemy .ext .declarative import declarative_base
7
15
8
16
9
17
USER_AGENT_TOKEN = "PySQL e2e Tests"
10
18
11
19
12
- @pytest .fixture
13
- def db_engine ():
20
+ def sqlalchemy_1_3 ():
21
+ import sqlalchemy
22
+
23
+ return sqlalchemy .__version__ .startswith ("1.3" )
24
+
25
+
26
+ def version_agnostic_select (object_to_select , * args , ** kwargs ):
27
+ """
28
+ SQLAlchemy==1.3.x requires arguments to select() to be a Python list
29
+
30
+ https://docs.sqlalchemy.org/en/20/changelog/migration_14.html#orm-query-is-internally-unified-with-select-update-delete-2-0-style-execution-available
31
+ """
32
+
33
+ if sqlalchemy_1_3 ():
34
+ return select ([object_to_select ], * args , ** kwargs )
35
+ else :
36
+ return select (object_to_select , * args , ** kwargs )
37
+
38
+
39
+ def version_agnostic_connect_arguments (catalog = None , schema = None ) -> Tuple [str , dict ]:
14
40
15
41
HOST = os .environ .get ("host" )
16
42
HTTP_PATH = os .environ .get ("http_path" )
17
43
ACCESS_TOKEN = os .environ .get ("access_token" )
18
- CATALOG = os .environ .get ("catalog" )
19
- SCHEMA = os .environ .get ("schema" )
44
+ CATALOG = catalog or os .environ .get ("catalog" )
45
+ SCHEMA = schema or os .environ .get ("schema" )
46
+
47
+ ua_connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
48
+
49
+ if sqlalchemy_1_3 ():
50
+ conn_string = f"databricks://token:{ ACCESS_TOKEN } @{ HOST } "
51
+ connect_args = {
52
+ ** ua_connect_args ,
53
+ "http_path" : HTTP_PATH ,
54
+ "server_hostname" : HOST ,
55
+ "catalog" : CATALOG ,
56
+ "schema" : SCHEMA ,
57
+ }
58
+
59
+ return conn_string , connect_args
60
+ else :
61
+ return (
62
+ f"databricks://token:{ ACCESS_TOKEN } @{ HOST } ?http_path={ HTTP_PATH } &catalog={ CATALOG } &schema={ SCHEMA } " ,
63
+ ua_connect_args ,
64
+ )
65
+
66
+
67
+ @pytest .fixture
68
+ def db_engine () -> Engine :
69
+ conn_string , connect_args = version_agnostic_connect_arguments ()
70
+ return create_engine (conn_string , connect_args = connect_args )
20
71
21
- connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
22
72
23
- engine = create_engine (
24
- f"databricks://token:{ ACCESS_TOKEN } @{ HOST } ?http_path={ HTTP_PATH } &catalog={ CATALOG } &schema={ SCHEMA } " ,
25
- connect_args = connect_args ,
73
+ @pytest .fixture
74
+ def samples_engine () -> Engine :
75
+
76
+ conn_string , connect_args = version_agnostic_connect_arguments (
77
+ catalog = "samples" , schema = "nyctaxi"
26
78
)
27
- return engine
79
+ return create_engine ( conn_string , connect_args = connect_args )
28
80
29
81
30
82
@pytest .fixture ()
@@ -62,6 +114,7 @@ def test_connect_args(db_engine):
62
114
assert expected in user_agent
63
115
64
116
117
+ @pytest .mark .skipif (sqlalchemy_1_3 (), reason = "Pandas requires SQLAlchemy >= 1.4" )
65
118
def test_pandas_upload (db_engine , metadata_obj ):
66
119
67
120
import pandas as pd
@@ -86,7 +139,7 @@ def test_pandas_upload(db_engine, metadata_obj):
86
139
db_engine .execute ("DROP TABLE mock_data" )
87
140
88
141
89
- def test_create_table_not_null (db_engine , metadata_obj ):
142
+ def test_create_table_not_null (db_engine , metadata_obj : MetaData ):
90
143
91
144
table_name = "PySQLTest_{}" .format (datetime .datetime .utcnow ().strftime ("%s" ))
92
145
@@ -95,7 +148,7 @@ def test_create_table_not_null(db_engine, metadata_obj):
95
148
metadata_obj ,
96
149
Column ("name" , String (255 )),
97
150
Column ("episodes" , Integer ),
98
- Column ("some_bool" , BOOLEAN , nullable = False ),
151
+ Column ("some_bool" , BOOLEAN ( create_constraint = False ) , nullable = False ),
99
152
)
100
153
101
154
metadata_obj .create_all ()
@@ -135,7 +188,7 @@ def test_bulk_insert_with_core(db_engine, metadata_obj, session):
135
188
metadata_obj .create_all ()
136
189
db_engine .execute (insert (SampleTable ).values (rows ))
137
190
138
- rows = db_engine .execute (select (SampleTable )).fetchall ()
191
+ rows = db_engine .execute (version_agnostic_select (SampleTable )).fetchall ()
139
192
140
193
assert len (rows ) == num_to_insert
141
194
@@ -148,7 +201,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
148
201
metadata_obj ,
149
202
Column ("name" , String (255 )),
150
203
Column ("episodes" , Integer ),
151
- Column ("some_bool" , BOOLEAN ),
204
+ Column ("some_bool" , BOOLEAN ( create_constraint = False ) ),
152
205
Column ("dollars" , DECIMAL (10 , 2 )),
153
206
)
154
207
@@ -161,7 +214,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
161
214
with db_engine .connect () as conn :
162
215
conn .execute (insert_stmt )
163
216
164
- select_stmt = select (SampleTable )
217
+ select_stmt = version_agnostic_select (SampleTable )
165
218
resp = db_engine .execute (select_stmt )
166
219
167
220
result = resp .fetchall ()
@@ -187,7 +240,7 @@ class SampleObject(base):
187
240
188
241
name = Column (String (255 ), primary_key = True )
189
242
episodes = Column (Integer )
190
- some_bool = Column (BOOLEAN )
243
+ some_bool = Column (BOOLEAN ( create_constraint = False ) )
191
244
192
245
base .metadata .create_all ()
193
246
@@ -197,11 +250,15 @@ class SampleObject(base):
197
250
session .add (sample_object_2 )
198
251
session .commit ()
199
252
200
- stmt = select (SampleObject ).where (
253
+ stmt = version_agnostic_select (SampleObject ).where (
201
254
SampleObject .name .in_ (["Bim Adewunmi" , "Miki Meek" ])
202
255
)
203
256
204
- output = [i for i in session .scalars (stmt )]
257
+ if sqlalchemy_1_3 ():
258
+ output = [i for i in session .execute (stmt )]
259
+ else :
260
+ output = [i for i in session .scalars (stmt )]
261
+
205
262
assert len (output ) == 2
206
263
207
264
base .metadata .drop_all ()
@@ -215,7 +272,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
215
272
metadata_obj ,
216
273
Column ("string_example" , String (255 )),
217
274
Column ("integer_example" , Integer ),
218
- Column ("boolean_example" , BOOLEAN ),
275
+ Column ("boolean_example" , BOOLEAN ( create_constraint = False ) ),
219
276
Column ("decimal_example" , DECIMAL (10 , 2 )),
220
277
Column ("date_example" , Date ),
221
278
)
@@ -239,7 +296,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
239
296
with db_engine .connect () as conn :
240
297
conn .execute (insert_stmt )
241
298
242
- select_stmt = select (SampleTable )
299
+ select_stmt = version_agnostic_select (SampleTable )
243
300
resp = db_engine .execute (select_stmt )
244
301
245
302
result = resp .fetchall ()
@@ -252,3 +309,34 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
252
309
assert this_row ["date_example" ] == date_example
253
310
254
311
metadata_obj .drop_all ()
312
+
313
+
314
+ def test_inspector_smoke_test (samples_engine : Engine ):
315
+ """It does not appear that 3L namespace is supported here"""
316
+
317
+ from sqlalchemy .engine .reflection import Inspector
318
+
319
+ schema , table = "nyctaxi" , "trips"
320
+
321
+ try :
322
+ inspector = Inspector .from_engine (samples_engine )
323
+ except Exception as e :
324
+ assert False , f"Could not build inspector: { e } "
325
+
326
+ # Expect six columns
327
+ columns = inspector .get_columns (table , schema = schema )
328
+
329
+ # Expect zero views, but the method should return
330
+ views = inspector .get_view_names (schema = schema )
331
+
332
+ assert (
333
+ len (columns ) == 6
334
+ ), "Dialect did not find the expected number of columns in samples.nyctaxi.trips"
335
+ assert len (views ) == 0 , "Views could not be fetched"
336
+
337
+
338
+ def test_get_table_names_smoke_test (samples_engine : Engine ):
339
+
340
+ with samples_engine .connect () as conn :
341
+ _names = samples_engine .table_names (schema = "nyctaxi" , connection = conn )
342
+ _names is not None , "get_table_names did not succeed"
0 commit comments