Skip to content

Commit 161b94e

Browse files
authored
Add e2e tests (databricks#12)
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent 6d62baa commit 161b94e

16 files changed

+1121
-10
lines changed

.github/workflows/code-quality-checks.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: Code Quality Checks
22
on: [push]
33
jobs:
4-
run-tests:
4+
run-unit-tests:
55
runs-on: ubuntu-latest
66
steps:
77
#----------------------------------------------
@@ -48,7 +48,7 @@ jobs:
4848
# run test suite
4949
#----------------------------------------------
5050
- name: Run tests
51-
run: poetry run pytest tests/
51+
run: poetry run python -m pytest tests/unit
5252
check-linting:
5353
runs-on: ubuntu-latest
5454
steps:

CONTRIBUTING.md

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,64 @@ This project uses [Poetry](https://python-poetry.org/) for dependency management
99
1. Clone this repository
1010
2. Run `poetry install`
1111

12-
### Unit Tests
12+
### Run tests
1313

14-
We use [Pytest](https://docs.pytest.org/en/7.1.x/) as our test runner. Invoke it with `poetry run pytest`, all other arguments are passed directly to `pytest`.
14+
We use [Pytest](https://docs.pytest.org/en/7.1.x/) as our test runner. Invoke it with `poetry run python -m pytest`, all other arguments are passed directly to `pytest`.
15+
16+
#### Unit tests
17+
18+
Unit tests do not require a Databricks account.
1519

16-
#### All tests
1720
```bash
18-
poetry run pytest tests
21+
poetry run python -m pytest tests/unit
1922
```
20-
2123
#### Only a specific test file
2224

2325
```bash
24-
poetry run pytest tests/tests.py
26+
poetry run python -m pytest tests/unit/tests.py
2527
```
2628

2729
#### Only a specific method
2830

2931
```bash
30-
poetry run pytest tests/tests.py::ClientTestSuite::test_closing_connection_closes_commands
32+
poetry run python -m pytest tests/unit/tests.py::ClientTestSuite::test_closing_connection_closes_commands
33+
```
34+
35+
#### e2e Tests
36+
37+
End-to-end tests require a Databricks account. Before you can run them, you must set connection details for a Databricks SQL endpoint in your environment:
38+
39+
```bash
40+
export host=""
41+
export http_path=""
42+
export access_token=""
3143
```
3244

45+
There are several e2e test suites available:
46+
- `PySQLCoreTestSuite`
47+
- `PySQLLargeQueriesSuite`
48+
- `PySQLRetryTestSuite.HTTP503Suite` **[not documented]**
49+
- `PySQLRetryTestSuite.HTTP429Suite` **[not documented]**
50+
- `PySQLUnityCatalogTestSuite` **[not documented]**
51+
52+
To execute the core test suite:
53+
54+
```bash
55+
poetry run python -m pytest tests/e2e/driver_tests.py::PySQLCoreTestSuite
56+
```
57+
58+
The suites marked `[not documented]` require additional configuration which will be documented at a later time.
3359
### Code formatting
3460

3561
This project uses [Black](https://pypi.org/project/black/).
3662

3763
```
38-
poetry run black src
64+
poetry run python3 -m black src --check
3965
```
66+
67+
Remove the `--check` flag to write reformatted files to disk.
68+
69+
To simplify reviews you can format your changes in a separate commit.
4070
## Pull Request Process
4171

4272
1. Update the [CHANGELOG.md](CHANGELOG.md) or similar documentation with details of changes you wish to make, if applicable.

tests/__init__.py

Whitespace-only changes.

tests/e2e/common/__init__.py

Whitespace-only changes.

tests/e2e/common/core_tests.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import decimal
2+
import datetime
3+
from collections import namedtuple
4+
5+
# All three failure records share the same core fields; ExecFailure
# additionally carries the exception that aborted the query.
_FAILURE_FIELDS = ("query,columnType,resultType,resultValue,"
                   "actualValue,actualType,description,conf")

TypeFailure = namedtuple("TypeFailure", _FAILURE_FIELDS)
ResultFailure = namedtuple("ResultFailure", _FAILURE_FIELDS)
ExecFailure = namedtuple("ExecFailure", _FAILURE_FIELDS + ",error")
14+
15+
16+
class SmokeTestMixin:
    """Minimal connectivity check mixin.

    Expects the host class to provide ``cursor()`` (a context manager
    yielding a DB-API cursor) and unittest-style ``assertEqual``.
    """

    def test_smoke_test(self):
        # One trivial query proves that connect/execute/fetch all work.
        with self.cursor() as cursor:
            cursor.execute("select 0")
            fetched = cursor.fetchall()
            self.assertEqual(len(fetched), 1)
            self.assertEqual(fetched[0][0], 0)
23+
24+
25+
class CoreTestMixin:
    """
    This mixin expects to be mixed with a CursorTest-like class with the following extra attributes:
    validate_row_value_type: bool
    validate_result: bool

    The host class must also provide ``cursor(conf)``, ``expected_column_types(columnType)``
    and unittest-style ``fail``.
    """

    # A list of (subquery, column_type, python_type, expected_result)
    # To be executed as "SELECT {} FROM RANGE(...)" and "SELECT {}"
    range_queries = [
        ("TRUE", 'boolean', bool, True),
        ("cast(1 AS TINYINT)", 'byte', int, 1),
        ("cast(1000 AS SMALLINT)", 'short', int, 1000),
        ("cast(100000 AS INTEGER)", 'integer', int, 100000),
        ("cast(10000000000000 AS BIGINT)", 'long', int, 10000000000000),
        ("cast(100.001 AS DECIMAL(6, 3))", 'decimal', decimal.Decimal, 100.001),
        ("date '2020-02-20'", 'date', datetime.date, datetime.date(2020, 2, 20)),
        ("unhex('f000')", 'binary', bytes, b'\xf0\x00'),  # pyodbc internal mismatch
        ("'foo'", 'string', str, 'foo'),
        # SPARK-32130: 6.x: "4 weeks 2 days" vs 7.x: "30 days"
        # ("interval 30 days", str, str, "interval 4 weeks 2 days"),
        # ("interval 3 days", str, str, "interval 3 days"),
        ("CAST(NULL AS DOUBLE)", 'double', type(None), None),
    ]

    # Full queries, only the first column of the first row is checked
    queries = [("NULL UNION (SELECT 1) order by 1", 'integer', type(None), None)]

    def run_tests_on_queries(self, default_conf):
        """Run every configured query, collecting failures instead of stopping on the first."""
        failures = []
        for (query, columnType, rowValueType, answer) in self.range_queries:
            with self.cursor(default_conf) as cursor:
                failures.extend(
                    self.run_query(cursor, query, columnType, rowValueType, answer, default_conf))
                failures.extend(
                    self.run_range_query(cursor, query, columnType, rowValueType, answer,
                                         default_conf))

        for (query, columnType, rowValueType, answer) in self.queries:
            with self.cursor(default_conf) as cursor:
                failures.extend(
                    self.run_query(cursor, query, columnType, rowValueType, answer, default_conf))

        if failures:
            self.fail("Failed testing result set with Arrow. "
                      "Failed queries: {}".format("\n\n".join([str(f) for f in failures])))

    def run_query(self, cursor, query, columnType, rowValueType, answer, conf):
        """Execute ``SELECT {query}`` and validate column type, row value type and value.

        Returns a list with zero or one failure records.
        """
        full_query = "SELECT {}".format(query)
        expected_column_types = self.expected_column_types(columnType)
        try:
            cursor.execute(full_query)
            (result, ) = cursor.fetchone()
            # Renamed genexp variable to avoid shadowing the builtin `type`.
            if not all(cursor.description[0][1] == expected_type
                       for expected_type in expected_column_types):
                return [
                    TypeFailure(full_query, expected_column_types, rowValueType, answer, result,
                                type(result), cursor.description, conf)
                ]
            if self.validate_row_value_type and type(result) is not rowValueType:
                return [
                    TypeFailure(full_query, expected_column_types, rowValueType, answer, result,
                                type(result), cursor.description, conf)
                ]
            if self.validate_result and str(answer) != str(result):
                # BUG FIX: the original passed 9 positional args (a stray extra `query`)
                # into the 8-field ResultFailure namedtuple, raising TypeError which the
                # except clause below then misreported as an ExecFailure.
                return [
                    ResultFailure(full_query, expected_column_types, rowValueType, answer,
                                  result, type(result), cursor.description, conf)
                ]
            return []
        except Exception as e:
            return [
                ExecFailure(full_query, columnType, rowValueType, None, None, None,
                            cursor.description, conf, e)
            ]

    def run_range_query(self, cursor, query, columnType, rowValueType, expected, conf):
        """Execute ``SELECT {query}, id FROM RANGE(5000)`` and validate every fetched row.

        Returns a list with zero or one failure records.
        """
        full_query = "SELECT {}, id FROM RANGE({})".format(query, 5000)
        expected_column_types = self.expected_column_types(columnType)
        try:
            cursor.execute(full_query)
            while True:
                rows = cursor.fetchmany(1000)
                if len(rows) <= 0:
                    break
                # Renamed `id` -> `row_id` (don't shadow the builtin) and dropped
                # the unused enumerate index.
                for (result, row_id) in rows:
                    if not all(cursor.description[0][1] == expected_type
                               for expected_type in expected_column_types):
                        return [
                            TypeFailure(full_query, expected_column_types, rowValueType, expected,
                                        result, type(result), cursor.description, conf)
                        ]
                    if self.validate_row_value_type and type(result) is not rowValueType:
                        return [
                            TypeFailure(full_query, expected_column_types, rowValueType, expected,
                                        result, type(result), cursor.description, conf)
                        ]
                    if self.validate_result and str(expected) != str(result):
                        return [
                            ResultFailure(full_query, expected_column_types, rowValueType, expected,
                                          result, type(result), cursor.description, conf)
                        ]
            return []
        except Exception as e:
            return [
                ExecFailure(full_query, columnType, rowValueType, None, None, None,
                            cursor.description, conf, e)
            ]

tests/e2e/common/decimal_tests.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from decimal import Decimal
2+
3+
import pyarrow
4+
5+
6+
class DecimalTestsMixin:
    """Verify DECIMAL columns arrive as correctly-typed Arrow decimal128 values.

    Expects the host class to provide ``cursor(conf)`` yielding a cursor with
    ``fetchmany_arrow`` / ``fetchall_arrow``, plus unittest-style helpers.
    """

    # (SQL expression to CAST, expected Python value, expected Arrow type)
    decimal_and_expected_results = [
        ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
        ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
        ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
        # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False
        #("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)),
        ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)),
        ("1 AS DECIMAL(1, 0)", Decimal("1"), pyarrow.decimal128(1, 0)),
        ("0.00000 AS DECIMAL(5, 3)", Decimal("0.000"), pyarrow.decimal128(5, 3)),
        ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
    ]

    # (list of CAST expressions unioned together, expected values, expected Arrow type)
    multi_decimals_and_expected_results = [
        (["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
         [Decimal("1.00"), Decimal("100.001"), None], pyarrow.decimal128(6, 3)),
        (["1 AS DECIMAL(6, 3)", "2 AS DECIMAL(5, 2)"],
         [Decimal('1.000'), Decimal('2.000')], pyarrow.decimal128(6, 3)),
    ]

    def test_decimals(self):
        with self.cursor({}) as cursor:
            for sql_expr, want_value, want_type in self.decimal_and_expected_results:
                query = "SELECT CAST ({})".format(sql_expr)
                with self.subTest(query=query):
                    cursor.execute(query)
                    table = cursor.fetchmany_arrow(1)
                    self.assertEqual(table.field(0).type, want_type)
                    # popitem() returns (column_name, values); [1][0] is the single value.
                    self.assertEqual(table.to_pydict().popitem()[1][0], want_value)

    def test_multi_decimals(self):
        with self.cursor({}) as cursor:
            for sql_exprs, want_values, want_type in self.multi_decimals_and_expected_results:
                selects = ["(SELECT CAST ({}))".format(expr) for expr in sql_exprs]
                query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(" UNION ".join(selects))

                with self.subTest(query=query):
                    cursor.execute(query)
                    table = cursor.fetchall_arrow()
                    self.assertEqual(table.field(0).type, want_type)
                    self.assertEqual(table.to_pydict().popitem()[1], want_values)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import logging
2+
import math
3+
import time
4+
5+
log = logging.getLogger(__name__)
6+
7+
8+
class LargeQueriesMixin:
    """
    This mixin expects to be mixed with a CursorTest-like class

    The host class must provide ``cursor()``, ``get_some_rows(cursor, n)`` and
    unittest-style assertion helpers.
    """

    def fetch_rows(self, cursor, row_count, fetchmany_size):
        """
        A generator for rows. Fetches until the end or up to 5 minutes.

        When the result set is exhausted before the timeout, asserts that
        exactly ``row_count`` rows were read.
        """
        # TODO: Remove fetchmany_size when we have fixed the performance issues with fetchone
        # in the Python client
        max_fetch_time = 5 * 60  # Fetch for at most 5 minutes

        rows = self.get_some_rows(cursor, fetchmany_size)
        start_time = time.time()
        n = 0
        while rows:
            for row in rows:
                n += 1
                yield row
            if time.time() - start_time >= max_fetch_time:
                log.warning("Fetching rows timed out")
                break
            rows = self.get_some_rows(cursor, fetchmany_size)
            if not rows:
                # Read all the rows, row_count should match
                self.assertEqual(n, row_count)

        num_fetches = max(math.ceil(n / 10000), 1)
        # BUG FIX: the original wrote `int(...), 1`, which built a
        # (latency, 1) tuple instead of the integer average latency.
        latency_ms = int((time.time() - start_time) * 1000 / num_fetches)
        print('Fetched {} rows with an avg latency of {} per fetch, '.format(n, latency_ms) +
              'assuming 10K fetch size.')

    def test_query_with_large_wide_result_set(self):
        """Fetch ~300 MB of 8 KiB-wide rows and verify ordering and cell width."""
        resultSize = 300 * 1000 * 1000  # 300 MB
        width = 8192  # B
        rows = resultSize // width
        cols = width // 36  # each uuid() string is 36 characters

        # Set the fetchmany_size to get 10MB of data a go
        fetchmany_size = 10 * 1024 * 1024 // width
        # This is used by PyHive tests to determine the buffer size
        self.arraysize = 1000
        with self.cursor() as cursor:
            uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
            cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
                self.assertEqual(row[0], row_id)  # Verify no rows are dropped in the middle.
                self.assertEqual(len(row[1]), 36)

    def test_query_with_large_narrow_result_set(self):
        """Fetch ~300 MB of single-long rows and verify no row is dropped."""
        resultSize = 300 * 1000 * 1000  # 300 MB
        width = 8  # sizeof(long)
        # BUG FIX: use integer division like the wide-result test above;
        # `/` produced a float that leaked into the RANGE({rows}) SQL text.
        rows = resultSize // width

        # Set the fetchmany_size to get 10MB of data a go
        fetchmany_size = 10 * 1024 * 1024 // width
        # This is used by PyHive tests to determine the buffer size
        self.arraysize = 10000000
        with self.cursor() as cursor:
            cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
                self.assertEqual(row[0], row_id)

    def test_long_running_query(self):
        """ Incrementally increase query size until it takes at least 5 minutes,
        and asserts that the query completes successfully.
        """
        minutes = 60
        min_duration = 5 * minutes

        duration = -1
        scale0 = 10000
        scale_factor = 1
        with self.cursor() as cursor:
            while duration < min_duration:
                # Guard against a query that never grows slow enough.
                self.assertLess(scale_factor, 512, msg="Detected infinite loop")
                start = time.time()

                cursor.execute("""SELECT count(*)
                        FROM RANGE({scale}) x
                        JOIN RANGE({scale0}) y
                        ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
                        """.format(scale=scale_factor * scale0, scale0=scale0))

                n, = cursor.fetchone()
                self.assertEqual(n, 0)

                duration = time.time() - start
                current_fraction = duration / min_duration
                print('Took {} s with scale factor={}'.format(duration, scale_factor))
                # Extrapolate linearly to reach 5 min and add 50% padding to push over the limit
                scale_factor = math.ceil(1.5 * scale_factor / current_fraction)

0 commit comments

Comments
 (0)