Skip to content

Commit 3ac3aab

Browse files
committed
Escape url parameters in sqlalchemy connection strings
1 parent 666d934 commit 3ac3aab

File tree

5 files changed

+180
-27
lines changed

5 files changed

+180
-27
lines changed

README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,13 @@ NOTE: `password` and `schema` are optional
7272
from sqlalchemy import create_engine
7373
from sqlalchemy.schema import Table, MetaData
7474
from sqlalchemy.sql.expression import select, text
75+
from trino.sqlalchemy import URL
7576

76-
engine = create_engine('trino://user@localhost:8080/system')
77+
engine = create_engine(URL(
78+
host="localhost",
79+
port=8080,
80+
catalog="system"
81+
))
7782
connection = engine.connect()
7883

7984
rows = connection.execute(text("SELECT * FROM runtime.nodes")).fetchall()
@@ -93,6 +98,7 @@ Attributes can also be passed in the connection string.
9398

9499
```python
95100
from sqlalchemy import create_engine
101+
from trino.sqlalchemy import URL
96102

97103
engine = create_engine(
98104
'trino://user@localhost:8080/system',
@@ -110,6 +116,14 @@ engine = create_engine(
110116
'&client_tags=["tag1", "tag2"]'
111117
'&experimental_python_types=true',
112118
)
119+
120+
# or using the URL factory method
121+
engine = create_engine(URL(
122+
host="localhost",
123+
port=8080,
124+
client_tags=["tag1", "tag2"],
125+
experimental_python_types=True
126+
))
113127
```
114128

115129
## Authentication mechanisms

tests/unit/sqlalchemy/test_dialect.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from trino.dbapi import Connection
1010
from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect
1111
from trino.transaction import IsolationLevel
12+
from trino.sqlalchemy import URL as trino_url
1213

1314

1415
class TestTrinoDialect:
@@ -19,17 +20,29 @@ def setup(self):
1920
"url, expected_args, expected_kwargs",
2021
[
2122
(
22-
make_url("trino://user@localhost"),
23+
make_url(trino_url(
24+
user="user",
25+
host="localhost",
26+
)),
2327
list(),
24-
dict(host="localhost", catalog="system", user="user", source="trino-sqlalchemy"),
28+
dict(host="localhost", catalog="system", user="user", port=8080, source="trino-sqlalchemy"),
2529
),
2630
(
27-
make_url("trino://user@localhost:8080"),
31+
make_url(trino_url(
32+
user="user",
33+
host="localhost",
34+
port=443,
35+
)),
2836
list(),
29-
dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"),
37+
dict(host="localhost", port=443, catalog="system", user="user", source="trino-sqlalchemy"),
3038
),
3139
(
32-
make_url("trino://user:pass@localhost:8080?source=trino-rulez"),
40+
make_url(trino_url(
41+
user="user",
42+
password="pass",
43+
host="localhost",
44+
source="trino-rulez",
45+
)),
3346
list(),
3447
dict(
3548
host="localhost",
@@ -42,13 +55,15 @@ def setup(self):
4255
),
4356
),
4457
(
45-
make_url(
46-
'trino://user@localhost:8080?'
47-
'session_properties={"query_max_run_time": "1d"}'
48-
'&http_headers={"trino": 1}'
49-
'&extra_credential=[("a", "b"), ("c", "d")]'
50-
'&client_tags=[1, "sql"]'
51-
'&experimental_python_types=true'),
58+
make_url(trino_url(
59+
user="user",
60+
host="localhost",
61+
session_properties={"query_max_run_time": "1d"},
62+
http_headers={"trino": 1},
63+
extra_credential=[("a", "b"), ("c", "d")],
64+
client_tags=["1", "sql"],
65+
experimental_python_types=True,
66+
)),
5267
list(),
5368
dict(
5469
host="localhost",
@@ -59,8 +74,40 @@ def setup(self):
5974
session_properties={"query_max_run_time": "1d"},
6075
http_headers={"trino": 1},
6176
extra_credential=[("a", "b"), ("c", "d")],
62-
client_tags=[1, "sql"],
77+
client_tags=["1", "sql"],
78+
experimental_python_types=True,
79+
),
80+
),
81+
# url encoding
82+
(
83+
make_url(trino_url(
84+
user="[email protected]/my_role",
85+
password="pass /*&",
86+
host="localhost",
87+
session_properties={"query_max_run_time": "1d"},
88+
http_headers={"trino": 1},
89+
extra_credential=[
90+
("[email protected]/my_role", "[email protected]/my_role"),
91+
("[email protected]/my_role", "[email protected]/my_role")],
92+
experimental_python_types=True,
93+
client_tags=["1 @& /\"", "sql"],
94+
)),
95+
list(),
96+
dict(
97+
host="localhost",
98+
port=8080,
99+
catalog="system",
100+
user="[email protected]/my_role",
101+
auth=BasicAuthentication("[email protected]/my_role", "pass /*&"),
102+
http_scheme="https",
103+
source="trino-sqlalchemy",
104+
session_properties={"query_max_run_time": "1d"},
105+
http_headers={"trino": 1},
106+
extra_credential=[
107+
("[email protected]/my_role", "[email protected]/my_role"),
108+
("[email protected]/my_role", "[email protected]/my_role")],
63109
experimental_python_types=True,
110+
client_tags=["1 @& /\"", "sql"],
64111
),
65112
),
66113
],

trino/sqlalchemy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
from sqlalchemy.dialects import registry
13+
from .util import _url as URL # noqa
1314

1415
registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect")

trino/sqlalchemy/dialect.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ast import literal_eval
1414
from textwrap import dedent
1515
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
16+
from urllib.parse import unquote_plus, unquote
1617

1718
from sqlalchemy import exc, sql
1819
from sqlalchemy.engine.base import Connection
@@ -80,49 +81,49 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
8081

8182
db_parts = (url.database or "system").split("/")
8283
if len(db_parts) == 1:
83-
kwargs["catalog"] = db_parts[0]
84+
kwargs["catalog"] = unquote_plus(db_parts[0])
8485
elif len(db_parts) == 2:
85-
kwargs["catalog"] = db_parts[0]
86-
kwargs["schema"] = db_parts[1]
86+
kwargs["catalog"] = unquote_plus(db_parts[0])
87+
kwargs["schema"] = unquote_plus(db_parts[1])
8788
else:
8889
raise ValueError(f"Unexpected database format {url.database}")
8990

9091
if url.username:
91-
kwargs["user"] = url.username
92+
kwargs["user"] = unquote(url.username)
9293

9394
if url.password:
9495
if not url.username:
9596
raise ValueError("Username is required when specify password in connection URL")
9697
kwargs["http_scheme"] = "https"
97-
kwargs["auth"] = BasicAuthentication(url.username, url.password)
98+
kwargs["auth"] = BasicAuthentication(unquote(url.username), unquote(url.password))
9899

99100
if "access_token" in url.query:
100101
kwargs["http_scheme"] = "https"
101-
kwargs["auth"] = JWTAuthentication(url.query["access_token"])
102+
kwargs["auth"] = JWTAuthentication(unquote(url.query["access_token"]))
102103

103104
if "cert" and "key" in url.query:
104105
kwargs["http_scheme"] = "https"
105-
kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key'])
106+
kwargs["auth"] = CertificateAuthentication(unquote(url.query['cert']), unquote(url.query['key']))
106107

107108
if "source" in url.query:
108-
kwargs["source"] = url.query["source"]
109+
kwargs["source"] = unquote(url.query["source"])
109110
else:
110111
kwargs["source"] = "trino-sqlalchemy"
111112

112113
if "session_properties" in url.query:
113-
kwargs["session_properties"] = json.loads(url.query["session_properties"])
114+
kwargs["session_properties"] = json.loads(unquote(url.query["session_properties"]))
114115

115116
if "http_headers" in url.query:
116-
kwargs["http_headers"] = json.loads(url.query["http_headers"])
117+
kwargs["http_headers"] = json.loads(unquote(url.query["http_headers"]))
117118

118119
if "extra_credential" in url.query:
119-
kwargs["extra_credential"] = literal_eval(url.query["extra_credential"])
120+
kwargs["extra_credential"] = literal_eval(unquote(url.query["extra_credential"]))
120121

121122
if "client_tags" in url.query:
122-
kwargs["client_tags"] = json.loads(url.query["client_tags"])
123+
kwargs["client_tags"] = json.loads(unquote(url.query["client_tags"]))
123124

124125
if "experimental_python_types" in url.query:
125-
kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"])
126+
kwargs["experimental_python_types"] = json.loads(unquote(url.query["experimental_python_types"]))
126127

127128
return args, kwargs
128129

trino/sqlalchemy/util.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import json
2+
from urllib.parse import quote_plus
3+
4+
from typing import Optional, Dict, List, Union, Tuple
5+
from sqlalchemy import exc
6+
from sqlalchemy.engine.url import _rfc_1738_quote # noqa
7+
8+
9+
def _url(
10+
host: str,
11+
port: Optional[int] = 8080,
12+
user: Optional[str] = None,
13+
password: Optional[str] = None,
14+
catalog: Optional[str] = None,
15+
schema: Optional[str] = None,
16+
source: Optional[str] = "trino-sqlalchemy",
17+
session_properties: Dict[str, str] = None,
18+
http_headers: Dict[str, Union[str, int]] = None,
19+
extra_credential: Optional[List[Tuple[str, str]]] = None,
20+
client_tags: Optional[List[str]] = None,
21+
experimental_python_types: Optional[bool] = None,
22+
access_token: Optional[str] = None,
23+
cert: Optional[str] = None,
24+
key: Optional[str] = None,
25+
) -> str:
26+
"""
27+
Composes a SQLAlchemy connect string from the given database connection
28+
parameters.
29+
Parameters containing special characters (e.g., '@', '%') need to be encoded to be parsed correctly.
30+
"""
31+
32+
trino_url = "trino://"
33+
34+
if user is not None:
35+
trino_url += _rfc_1738_quote(user)
36+
37+
if password is not None:
38+
if user is None:
39+
raise exc.ArgumentError("user must be specified when specifying a password.")
40+
trino_url += f":{_rfc_1738_quote(password)}"
41+
42+
if user is not None:
43+
trino_url += "@"
44+
45+
if not host:
46+
raise exc.ArgumentError("host must be specified.")
47+
48+
trino_url += host
49+
50+
if not port:
51+
raise exc.ArgumentError("port must be specified.")
52+
53+
trino_url += f":{port}"
54+
55+
if catalog is not None:
56+
trino_url += f"/{quote_plus(catalog)}"
57+
58+
if schema is not None:
59+
if catalog is None:
60+
raise exc.ArgumentError("catalog must be specified when specifying a default schema.")
61+
trino_url += f"/{quote_plus(schema)}"
62+
63+
assert source
64+
trino_url += f"?source={quote_plus(source)}"
65+
66+
if session_properties is not None:
67+
trino_url += f"&session_properties={quote_plus(json.dumps(session_properties))}"
68+
69+
if http_headers is not None:
70+
trino_url += f"&http_headers={quote_plus(json.dumps(http_headers))}"
71+
72+
if extra_credential is not None:
73+
trino_url += f"&extra_credential={quote_plus(repr(extra_credential))}"
74+
75+
if client_tags is not None:
76+
trino_url += f"&client_tags={quote_plus(json.dumps(client_tags))}"
77+
78+
if experimental_python_types is not None:
79+
trino_url += f"&experimental_python_types={quote_plus(json.dumps(experimental_python_types))}"
80+
81+
if access_token is not None:
82+
trino_url += f"&access_token={quote_plus(access_token)}"
83+
84+
if cert is not None:
85+
trino_url += f"&cert={quote_plus(cert)}"
86+
87+
if key is not None:
88+
trino_url += f"&key={quote_plus(key)}"
89+
90+
return trino_url

0 commit comments

Comments
 (0)