Skip to content

Commit cd614ff

Browse files
laserkaplanhashhar
authored andcommitted
Add JSON handling to dbapi and dialect
1 parent 89c2769 commit cd614ff

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

tests/integration/test_sqlalchemy_integration.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import sqlalchemy as sqla
1414
from sqlalchemy.sql import and_, or_, not_
1515

16+
from trino.sqlalchemy.datatype import JSON
17+
1618

1719
@pytest.fixture
1820
def trino_connection(run_trino, request):
@@ -255,3 +257,41 @@ def test_cte(trino_connection):
255257
result = conn.execute(s)
256258
rows = result.fetchall()
257259
assert len(rows) == 15
260+
261+
262+
@pytest.mark.parametrize(
263+
'trino_connection,json_object',
264+
[
265+
('memory', None),
266+
('memory', 1),
267+
('memory', 'test'),
268+
('memory', [1, 'test']),
269+
('memory', {'test': 1}),
270+
],
271+
indirect=['trino_connection']
272+
)
273+
def test_json_column(trino_connection, json_object):
274+
engine, conn = trino_connection
275+
276+
if not engine.dialect.has_schema(engine, "test"):
277+
engine.execute(sqla.schema.CreateSchema("test"))
278+
metadata = sqla.MetaData()
279+
280+
try:
281+
table_with_json = sqla.Table(
282+
'table_with_json',
283+
metadata,
284+
sqla.Column('id', sqla.Integer),
285+
sqla.Column('json_column', JSON),
286+
schema="test"
287+
)
288+
metadata.create_all(engine)
289+
ins = table_with_json.insert()
290+
conn.execute(ins, {"id": 1, "json_column": json_object})
291+
query = sqla.select(table_with_json)
292+
result = conn.execute(query)
293+
rows = result.fetchall()
294+
assert len(rows) == 1
295+
assert rows[0] == (1, json_object)
296+
finally:
297+
metadata.drop_all(engine)

trino/sqlalchemy/compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ def visit_TIME(self, type_, **kw):
206206
datatype += " WITH TIME ZONE"
207207
return datatype
208208

209+
def visit_JSON(self, type_, **kw):
210+
return 'JSON'
211+
209212

210213
class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
211214
reserved_words = RESERVED_WORDS

trino/sqlalchemy/datatype.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12+
import json
1213
import re
1314
from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any
1415

1516
from sqlalchemy import util
1617
from sqlalchemy.sql import sqltypes
17-
from sqlalchemy.sql.type_api import TypeEngine
18+
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
19+
from sqlalchemy.types import String
1820

1921
SQLType = Union[TypeEngine, Type[TypeEngine]]
2022

@@ -71,6 +73,19 @@ def __init__(self, precision=None, timezone=False):
7173
self.precision = precision
7274

7375

76+
class JSON(TypeDecorator):
77+
impl = String
78+
79+
def process_bind_param(self, value, dialect):
80+
return json.dumps(value)
81+
82+
def process_result_value(self, value, dialect):
83+
return json.loads(value)
84+
85+
def get_col_spec(self, **kw):
86+
return 'JSON'
87+
88+
7489
# https://trino.io/docs/current/language/types.html
7590
_type_map = {
7691
# === Boolean ===
@@ -90,7 +105,7 @@ def __init__(self, precision=None, timezone=False):
90105
"varchar": sqltypes.VARCHAR,
91106
"char": sqltypes.CHAR,
92107
"varbinary": sqltypes.VARBINARY,
93-
"json": sqltypes.JSON,
108+
"json": JSON,
94109
# === Date and time ===
95110
"date": sqltypes.DATE,
96111
"time": TIME,

0 commit comments

Comments
 (0)