|
| 1 | +import math |
| 2 | +import pytest |
| 3 | +from decimal import Decimal |
| 4 | +import trino |
| 5 | + |
| 6 | + |
| 7 | +@pytest.fixture |
| 8 | +def trino_connection(run_trino): |
| 9 | + _, host, port = run_trino |
| 10 | + |
| 11 | + yield trino.dbapi.Connection( |
| 12 | + host=host, port=port, user="test", source="test", max_attempts=1 |
| 13 | + ) |
| 14 | + |
| 15 | + |
| 16 | +def test_boolean(trino_connection): |
| 17 | + SqlTest(trino_connection) \ |
| 18 | + .add_field(sql="CAST(null AS BOOLEAN)", python=None) \ |
| 19 | + .add_field(sql="false", python=False) \ |
| 20 | + .add_field(sql="true", python=True) \ |
| 21 | + .execute() |
| 22 | + |
| 23 | + |
| 24 | +def test_tinyint(trino_connection): |
| 25 | + SqlTest(trino_connection) \ |
| 26 | + .add_field(sql="CAST(null AS TINYINT)", python=None) \ |
| 27 | + .add_field(sql="CAST(-128 AS TINYINT)", python=-128) \ |
| 28 | + .add_field(sql="CAST(42 AS TINYINT)", python=42) \ |
| 29 | + .add_field(sql="CAST(127 AS TINYINT)", python=127) \ |
| 30 | + .execute() |
| 31 | + |
| 32 | + |
| 33 | +def test_smallint(trino_connection): |
| 34 | + SqlTest(trino_connection) \ |
| 35 | + .add_field(sql="CAST(null AS SMALLINT)", python=None) \ |
| 36 | + .add_field(sql="CAST(-32768 AS SMALLINT)", python=-32768) \ |
| 37 | + .add_field(sql="CAST(42 AS SMALLINT)", python=42) \ |
| 38 | + .add_field(sql="CAST(32767 AS SMALLINT)", python=32767) \ |
| 39 | + .execute() |
| 40 | + |
| 41 | + |
| 42 | +def test_int(trino_connection): |
| 43 | + SqlTest(trino_connection) \ |
| 44 | + .add_field(sql="CAST(null AS INTEGER)", python=None) \ |
| 45 | + .add_field(sql="CAST(-2147483648 AS INTEGER)", python=-2147483648) \ |
| 46 | + .add_field(sql="CAST(83648 AS INTEGER)", python=83648) \ |
| 47 | + .add_field(sql="CAST(2147483647 AS INTEGER)", python=2147483647) \ |
| 48 | + .execute() |
| 49 | + |
| 50 | + |
| 51 | +def test_bigint(trino_connection): |
| 52 | + SqlTest(trino_connection) \ |
| 53 | + .add_field(sql="CAST(null AS BIGINT)", python=None) \ |
| 54 | + .add_field(sql="CAST(-9223372036854775808 AS BIGINT)", python=-9223372036854775808) \ |
| 55 | + .add_field(sql="CAST(9223 AS BIGINT)", python=9223) \ |
| 56 | + .add_field(sql="CAST(9223372036854775807 AS BIGINT)", python=9223372036854775807) \ |
| 57 | + .execute() |
| 58 | + |
| 59 | + |
| 60 | +def test_real(trino_connection): |
| 61 | + SqlTest(trino_connection) \ |
| 62 | + .add_field(sql="CAST(null AS REAL)", python=None) \ |
| 63 | + .add_field(sql="CAST('NaN' AS REAL)", python=math.nan) \ |
| 64 | + .add_field(sql="CAST('-Infinity' AS REAL)", python=-math.inf) \ |
| 65 | + .add_field(sql="CAST(3.4028235E38 AS REAL)", python=3.4028235e+38) \ |
| 66 | + .add_field(sql="CAST(1.4E-45 AS REAL)", python=1.4e-45) \ |
| 67 | + .add_field(sql="CAST('Infinity' AS REAL)", python=math.inf) \ |
| 68 | + .execute() |
| 69 | + |
| 70 | + |
| 71 | +def test_double(trino_connection): |
| 72 | + SqlTest(trino_connection) \ |
| 73 | + .add_field(sql="CAST(null AS DOUBLE)", python=None) \ |
| 74 | + .add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan) \ |
| 75 | + .add_field(sql="CAST('-Infinity' AS DOUBLE)", python=-math.inf) \ |
| 76 | + .add_field(sql="CAST(1.7976931348623157E308 AS DOUBLE)", python=1.7976931348623157e+308) \ |
| 77 | + .add_field(sql="CAST(4.9E-324 AS DOUBLE)", python=5e-324) \ |
| 78 | + .add_field(sql="CAST('Infinity' AS DOUBLE)", python=math.inf) \ |
| 79 | + .execute() |
| 80 | + |
| 81 | + |
| 82 | +def test_decimal(trino_connection): |
| 83 | + SqlTest(trino_connection) \ |
| 84 | + .add_field(sql="CAST(null AS DECIMAL)", python=None) \ |
| 85 | + .add_field(sql="CAST(null AS DECIMAL(38,0))", python=None) \ |
| 86 | + .add_field(sql="DECIMAL '10.3'", python=Decimal('10.3')) \ |
| 87 | + .add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,18))", python=Decimal('0.123456789123456789')) \ |
| 88 | + .add_field(sql="CAST(null AS DECIMAL(18,18))", python=None) \ |
| 89 | + .add_field(sql="CAST('234.123456789123456789' AS DECIMAL(18,4))", python=Decimal('234.1235')) \ |
| 90 | + .add_field(sql="CAST('10.3' AS DECIMAL(38,1))", python=Decimal('10.3')) \ |
| 91 | + .add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,2))", python=Decimal('0.12')) \ |
| 92 | + .add_field(sql="CAST('0.3123' AS DECIMAL(38,38))", python=Decimal('0.3123')) \ |
| 93 | + .execute() |
| 94 | + |
| 95 | + |
| 96 | +def test_varchar(trino_connection): |
| 97 | + SqlTest(trino_connection) \ |
| 98 | + .add_field(sql="'aaa'", python='aaa') \ |
| 99 | + .add_field(sql="U&'Hello winter \2603 !'", python='Hello winter °3 !') \ |
| 100 | + .add_field(sql="CAST(null AS VARCHAR)", python=None) \ |
| 101 | + .add_field(sql="CAST('bbb' AS VARCHAR(1))", python='b') \ |
| 102 | + .add_field(sql="CAST(null AS VARCHAR(1))", python=None) \ |
| 103 | + .execute() |
| 104 | + |
| 105 | + |
| 106 | +def test_char(trino_connection): |
| 107 | + SqlTest(trino_connection) \ |
| 108 | + .add_field(sql="CAST('ccc' AS CHAR)", python='c') \ |
| 109 | + .add_field(sql="CAST(null AS CHAR)", python=None) \ |
| 110 | + .add_field(sql="CAST('ddd' AS CHAR(1))", python='d') \ |
| 111 | + .add_field(sql="CAST('😂' AS CHAR(1))", python='😂') \ |
| 112 | + .add_field(sql="CAST(null AS CHAR(1))", python=None) \ |
| 113 | + .execute() |
| 114 | + |
| 115 | + |
| 116 | +def test_varbinary(trino_connection): |
| 117 | + SqlTest(trino_connection) \ |
| 118 | + .add_field(sql="X'65683F'", python='ZWg/') \ |
| 119 | + .add_field(sql="X''", python='') \ |
| 120 | + .add_field(sql="CAST('' AS VARBINARY)", python='') \ |
| 121 | + .add_field(sql="from_utf8(CAST('😂😂😂😂😂😂' AS VARBINARY))", python='😂😂😂😂😂😂') \ |
| 122 | + .add_field(sql="CAST(null AS VARBINARY)", python=None) \ |
| 123 | + .execute() |
| 124 | + |
| 125 | + |
| 126 | +def test_varbinary_failure(trino_connection): |
| 127 | + SqlExpectFailureTest(trino_connection) \ |
| 128 | + .execute("CAST(42 AS VARBINARY)") |
| 129 | + |
| 130 | + |
| 131 | +def test_json(trino_connection): |
| 132 | + SqlTest(trino_connection) \ |
| 133 | + .add_field(sql="CAST('{}' AS JSON)", python='"{}"') \ |
| 134 | + .add_field(sql="CAST('null' AS JSON)", python='"null"') \ |
| 135 | + .add_field(sql="CAST(null AS JSON)", python=None) \ |
| 136 | + .add_field(sql="CAST('3.14' AS JSON)", python='"3.14"') \ |
| 137 | + .add_field(sql="CAST('a string' AS JSON)", python='"a string"') \ |
| 138 | + .add_field(sql="CAST('a \" complex '' string :' AS JSON)", python='"a \\" complex \' string :"') \ |
| 139 | + .add_field(sql="CAST('[]' AS JSON)", python='"[]"') \ |
| 140 | + .execute() |
| 141 | + |
| 142 | + |
| 143 | +def test_interval(trino_connection): |
| 144 | + SqlTest(trino_connection) \ |
| 145 | + .add_field(sql="CAST(null AS INTERVAL YEAR TO MONTH)", python=None) \ |
| 146 | + .add_field(sql="CAST(null AS INTERVAL DAY TO SECOND)", python=None) \ |
| 147 | + .add_field(sql="INTERVAL '3' MONTH", python='0-3') \ |
| 148 | + .add_field(sql="INTERVAL '2' DAY", python='2 00:00:00.000') \ |
| 149 | + .add_field(sql="INTERVAL '-2' DAY", python='-2 00:00:00.000') \ |
| 150 | + .execute() |
| 151 | + |
| 152 | + |
| 153 | +def test_array(trino_connection): |
| 154 | + SqlTest(trino_connection) \ |
| 155 | + .add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \ |
| 156 | + .add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \ |
| 157 | + .execute() |
| 158 | + |
| 159 | + |
| 160 | +def test_map(trino_connection): |
| 161 | + SqlTest(trino_connection) \ |
| 162 | + .add_field(sql="CAST(null AS MAP(VARCHAR, INTEGER))", python=None) \ |
| 163 | + .add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[1, null])", python={'a': 1, 'b': None}) \ |
| 164 | + .execute() |
| 165 | + |
| 166 | + |
| 167 | +def test_row(trino_connection): |
| 168 | + SqlTest(trino_connection) \ |
| 169 | + .add_field(sql="CAST(null AS ROW(x BIGINT, y DOUBLE))", python=None) \ |
| 170 | + .add_field(sql="CAST(ROW(1, 2e0) AS ROW(x BIGINT, y DOUBLE))", python=(1, 2.0)) \ |
| 171 | + .execute() |
| 172 | + |
| 173 | + |
| 174 | +def test_ipaddress(trino_connection): |
| 175 | + SqlTest(trino_connection) \ |
| 176 | + .add_field(sql="CAST(null AS IPADDRESS)", python=None) \ |
| 177 | + .add_field(sql="IPADDRESS '2001:db8::1'", python='2001:db8::1') \ |
| 178 | + .execute() |
| 179 | + |
| 180 | + |
| 181 | +def test_uuid(trino_connection): |
| 182 | + SqlTest(trino_connection) \ |
| 183 | + .add_field(sql="CAST(null AS UUID)", python=None) \ |
| 184 | + .add_field(sql="UUID '12151fd2-7586-11e9-8f9e-2a86e4085a59'", python='12151fd2-7586-11e9-8f9e-2a86e4085a59') \ |
| 185 | + .execute() |
| 186 | + |
| 187 | + |
| 188 | +def test_digest(trino_connection): |
| 189 | + SqlTest(trino_connection) \ |
| 190 | + .add_field(sql="CAST(null AS HyperLogLog)", python=None) \ |
| 191 | + .add_field(sql="CAST(null AS P4HyperLogLog)", python=None) \ |
| 192 | + .add_field(sql="CAST(null AS SetDigest)", python=None) \ |
| 193 | + .add_field(sql="CAST(null AS QDigest(BIGINT))", python=None) \ |
| 194 | + .add_field(sql="CAST(null AS TDigest)", python=None) \ |
| 195 | + .add_field(sql="approx_set(1)", python='AgwBAIADRAA=') \ |
| 196 | + .add_field(sql="CAST(approx_set(1) AS P4HyperLogLog)", python='AwwAAAAg' + 'A' * 2730 + '==') \ |
| 197 | + .add_field(sql="make_set_digest(1)", python='AQgAAAACCwEAgANEAAAgAAABAAAASsQF+7cDRAABAA==') \ |
| 198 | + .add_field(sql="tdigest_agg(1)", |
| 199 | + python='AAAAAAAAAPA/AAAAAAAA8D8AAAAAAABZQAAAAAAAAPA/AQAAAAAAAAAAAPA/AAAAAAAA8D8=') \ |
| 200 | + .execute() |
| 201 | + |
| 202 | + |
| 203 | +class SqlTest: |
| 204 | + def __init__(self, trino_connection): |
| 205 | + self.cur = trino_connection.cursor(experimental_python_types=True) |
| 206 | + self.sql_args = [] |
| 207 | + self.expected_result = [] |
| 208 | + |
| 209 | + def add_field(self, sql, python): |
| 210 | + self.sql_args.append(sql) |
| 211 | + self.expected_result.append(python) |
| 212 | + return self |
| 213 | + |
| 214 | + def execute(self): |
| 215 | + sql = 'SELECT ' + ',\n'.join(self.sql_args) |
| 216 | + |
| 217 | + self.cur.execute(sql) |
| 218 | + actual_result = self.cur.fetchall() |
| 219 | + self._compare_results(actual_result[0], self.expected_result) |
| 220 | + |
| 221 | + def _compare_results(self, actual, expected): |
| 222 | + assert len(actual) == len(expected) |
| 223 | + |
| 224 | + for idx, actual_val in enumerate(actual): |
| 225 | + expected_val = expected[idx] |
| 226 | + if type(actual_val) == float and math.isnan(actual_val) \ |
| 227 | + and type(expected_val) == float and math.isnan(expected_val): |
| 228 | + continue |
| 229 | + |
| 230 | + assert actual_val == expected_val |
| 231 | + |
| 232 | + |
| 233 | +class SqlExpectFailureTest: |
| 234 | + def __init__(self, trino_connection): |
| 235 | + self.cur = trino_connection.cursor(experimental_python_types=True) |
| 236 | + |
| 237 | + def execute(self, field): |
| 238 | + sql = 'SELECT ' + field |
| 239 | + |
| 240 | + try: |
| 241 | + self.cur.execute(sql) |
| 242 | + self.cur.fetchall() |
| 243 | + success = True |
| 244 | + except Exception: |
| 245 | + success = False |
| 246 | + |
| 247 | + assert not success, "Test not expected to succeed" |
0 commit comments