Skip to content

Commit a71b8bc

Browse files
mdesmetebyhr
authored andcommitted
Integration tests for common sqlalchemy operations
1 parent ad0d76b commit a71b8bc

File tree

1 file changed

+257
-0
lines changed

1 file changed

+257
-0
lines changed
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License
12+
import pytest
13+
import sqlalchemy as sqla
14+
from sqlalchemy.sql import and_, or_, not_
15+
16+
17+
@pytest.fixture
18+
def trino_connection(run_trino, request):
19+
_, host, port = run_trino
20+
engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}",
21+
connect_args={"source": "test", "max_attempts": 1})
22+
yield engine, engine.connect()
23+
24+
25+
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
26+
def test_select_query(trino_connection):
27+
_, conn = trino_connection
28+
metadata = sqla.MetaData()
29+
nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
30+
assert_column(nations, "nationkey", sqla.sql.sqltypes.BigInteger)
31+
assert_column(nations, "name", sqla.sql.sqltypes.String)
32+
assert_column(nations, "regionkey", sqla.sql.sqltypes.BigInteger)
33+
assert_column(nations, "comment", sqla.sql.sqltypes.String)
34+
query = sqla.select(nations)
35+
result = conn.execute(query)
36+
rows = result.fetchall()
37+
assert len(rows) == 25
38+
for row in rows:
39+
assert isinstance(row['nationkey'], int)
40+
assert isinstance(row['name'], str)
41+
assert isinstance(row['regionkey'], int)
42+
assert isinstance(row['comment'], str)
43+
44+
45+
def assert_column(table, column_name, column_type):
46+
assert getattr(table.c, column_name).name == column_name
47+
assert isinstance(getattr(table.c, column_name).type, column_type)
48+
49+
50+
@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
51+
def test_select_specific_columns(trino_connection):
52+
_, conn = trino_connection
53+
metadata = sqla.MetaData()
54+
nodes = sqla.Table('nodes', metadata, schema='runtime', autoload_with=conn)
55+
assert_column(nodes, "node_id", sqla.sql.sqltypes.String)
56+
assert_column(nodes, "state", sqla.sql.sqltypes.String)
57+
query = sqla.select(nodes.c.node_id, nodes.c.state)
58+
result = conn.execute(query)
59+
rows = result.fetchall()
60+
assert len(rows) > 0
61+
for row in rows:
62+
assert isinstance(row['node_id'], str)
63+
assert isinstance(row['state'], str)
64+
65+
66+
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
67+
def test_define_and_create_table(trino_connection):
68+
engine, conn = trino_connection
69+
if not engine.dialect.has_schema(engine, "test"):
70+
engine.execute(sqla.schema.CreateSchema("test"))
71+
metadata = sqla.MetaData()
72+
try:
73+
sqla.Table('users',
74+
metadata,
75+
sqla.Column('id', sqla.Integer),
76+
sqla.Column('name', sqla.String),
77+
sqla.Column('fullname', sqla.String),
78+
schema="test")
79+
metadata.create_all(engine)
80+
assert sqla.inspect(engine).has_table('users', schema="test")
81+
users = sqla.Table('users', metadata, schema='test', autoload_with=conn)
82+
assert_column(users, "id", sqla.sql.sqltypes.Integer)
83+
assert_column(users, "name", sqla.sql.sqltypes.String)
84+
assert_column(users, "fullname", sqla.sql.sqltypes.String)
85+
finally:
86+
metadata.drop_all(engine)
87+
88+
89+
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
90+
def test_insert(trino_connection):
91+
engine, conn = trino_connection
92+
93+
if not engine.dialect.has_schema(engine, "test"):
94+
engine.execute(sqla.schema.CreateSchema("test"))
95+
metadata = sqla.MetaData()
96+
try:
97+
users = sqla.Table('users',
98+
metadata,
99+
sqla.Column('id', sqla.Integer),
100+
sqla.Column('name', sqla.String),
101+
sqla.Column('fullname', sqla.String),
102+
schema="test")
103+
metadata.create_all(engine)
104+
ins = users.insert()
105+
conn.execute(ins, {"id": 2, "name": "wendy", "fullname": "Wendy Williams"})
106+
query = sqla.select(users)
107+
result = conn.execute(query)
108+
rows = result.fetchall()
109+
assert len(rows) == 1
110+
assert rows[0] == (2, "wendy", "Wendy Williams")
111+
finally:
112+
metadata.drop_all(engine)
113+
114+
115+
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
116+
def test_insert_multiple_statements(trino_connection):
117+
engine, conn = trino_connection
118+
if not engine.dialect.has_schema(engine, "test"):
119+
engine.execute(sqla.schema.CreateSchema("test"))
120+
metadata = sqla.MetaData()
121+
users = sqla.Table('users',
122+
metadata,
123+
sqla.Column('id', sqla.Integer),
124+
sqla.Column('name', sqla.String),
125+
sqla.Column('fullname', sqla.String),
126+
schema="test")
127+
metadata.create_all(engine)
128+
ins = users.insert()
129+
conn.execute(ins, [
130+
{"id": 2, "name": "wendy", "fullname": "Wendy Williams"},
131+
{"id": 3, "name": "john", "fullname": "John Doe"},
132+
{"id": 4, "name": "mary", "fullname": "Mary Hopkins"},
133+
])
134+
query = sqla.select(users)
135+
result = conn.execute(query)
136+
rows = result.fetchall()
137+
assert len(rows) == 3
138+
assert frozenset(rows) == frozenset([
139+
(2, "wendy", "Wendy Williams"),
140+
(3, "john", "John Doe"),
141+
(4, "mary", "Mary Hopkins"),
142+
])
143+
metadata.drop_all(engine)
144+
145+
146+
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
147+
def test_operators(trino_connection):
148+
_, conn = trino_connection
149+
metadata = sqla.MetaData()
150+
customers = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
151+
query = sqla.select(customers).where(customers.c.nationkey == 2)
152+
result = conn.execute(query)
153+
rows = result.fetchall()
154+
assert len(rows) == 1
155+
for row in rows:
156+
assert isinstance(row['nationkey'], int)
157+
assert isinstance(row['name'], str)
158+
assert isinstance(row['regionkey'], int)
159+
assert isinstance(row['comment'], str)
160+
161+
162+
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
163+
def test_conjunctions(trino_connection):
164+
_, conn = trino_connection
165+
metadata = sqla.MetaData()
166+
customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn)
167+
query = sqla.select(customers).where(and_(
168+
customers.c.name.like('%12%'),
169+
customers.c.nationkey == 15,
170+
or_(
171+
customers.c.mktsegment == 'AUTOMOBILE',
172+
customers.c.mktsegment == 'HOUSEHOLD'
173+
),
174+
not_(customers.c.acctbal < 0)))
175+
result = conn.execute(query)
176+
rows = result.fetchall()
177+
assert len(rows) == 1
178+
179+
180+
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
181+
def test_textual_sql(trino_connection):
182+
_, conn = trino_connection
183+
s = sqla.text("SELECT * from tiny.customer where nationkey = :e1 AND acctbal < :e2")
184+
result = conn.execute(s, {"e1": 15, "e2": 0})
185+
rows = result.fetchall()
186+
assert len(rows) == 3
187+
for row in rows:
188+
assert isinstance(row['custkey'], int)
189+
assert isinstance(row['name'], str)
190+
assert isinstance(row['address'], str)
191+
assert isinstance(row['nationkey'], int)
192+
assert isinstance(row['phone'], str)
193+
assert isinstance(row['acctbal'], float)
194+
assert isinstance(row['mktsegment'], str)
195+
assert isinstance(row['comment'], str)
196+
197+
198+
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
199+
def test_alias(trino_connection):
200+
_, conn = trino_connection
201+
metadata = sqla.MetaData()
202+
nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
203+
nations1 = nations.alias("o1")
204+
nations2 = nations.alias("o2")
205+
s = sqla.select(nations1) \
206+
.join(nations2, and_(
207+
nations1.c.regionkey == nations2.c.regionkey,
208+
nations1.c.nationkey != nations2.c.nationkey,
209+
nations1.c.regionkey == 1
210+
)) \
211+
.distinct()
212+
result = conn.execute(s)
213+
rows = result.fetchall()
214+
assert len(rows) == 5
215+
216+
217+
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
218+
def test_subquery(trino_connection):
219+
_, conn = trino_connection
220+
metadata = sqla.MetaData()
221+
nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
222+
customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn)
223+
automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900)
224+
automobile_customers_subquery = automobile_customers.subquery()
225+
s = sqla.select(nations.c.name).where(nations.c.nationkey.in_(sqla.select(automobile_customers_subquery)))
226+
result = conn.execute(s)
227+
rows = result.fetchall()
228+
assert len(rows) == 15
229+
230+
231+
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
232+
def test_joins(trino_connection):
233+
_, conn = trino_connection
234+
metadata = sqla.MetaData()
235+
nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
236+
customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn)
237+
s = sqla.select(nations.c.name) \
238+
.select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey)) \
239+
.where(customers.c.acctbal < -900) \
240+
.distinct()
241+
result = conn.execute(s)
242+
rows = result.fetchall()
243+
assert len(rows) == 15
244+
245+
246+
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
247+
def test_cte(trino_connection):
248+
_, conn = trino_connection
249+
metadata = sqla.MetaData()
250+
nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
251+
customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn)
252+
automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900)
253+
automobile_customers_cte = automobile_customers.cte()
254+
s = sqla.select(nations).where(nations.c.nationkey.in_(sqla.select(automobile_customers_cte)))
255+
result = conn.execute(s)
256+
rows = result.fetchall()
257+
assert len(rows) == 15

0 commit comments

Comments
 (0)