Skip to content

Commit cbda49c

Browse files
mdesmet authored and hashhar committed
Escape url parameters in sqlalchemy connection strings
1 parent b4bd746 commit cbda49c

File tree

5 files changed

+286
-37
lines changed

5 files changed

+286
-37
lines changed

README.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,14 @@ Attributes can also be passed in the connection string.
100100

101101
```python
102102
from sqlalchemy import create_engine
103+
from trino.sqlalchemy import URL
103104

104105
engine = create_engine(
105-
'trino://user@localhost:8080/system',
106+
URL(
107+
host="localhost",
108+
port=8080,
109+
catalog="system"
110+
),
106111
connect_args={
107112
"session_properties": {'query_max_run_time': '1d'},
108113
"client_tags": ["tag1", "tag2"],
@@ -119,6 +124,14 @@ engine = create_engine(
119124
'&experimental_python_types=true'
120125
'&roles={"catalog1": "role1"}'
121126
)
127+
128+
# or using the URL factory method
129+
engine = create_engine(URL(
130+
host="localhost",
131+
port=8080,
132+
client_tags=["tag1", "tag2"],
133+
experimental_python_types=True
134+
))
122135
```
123136

124137
## Authentication mechanisms

tests/unit/sqlalchemy/test_dialect.py

Lines changed: 153 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,43 @@
99
from trino.dbapi import Connection
1010
from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect
1111
from trino.transaction import IsolationLevel
12+
from trino.sqlalchemy import URL as trino_url
1213

1314

1415
class TestTrinoDialect:
1516
def setup(self):
1617
self.dialect = TrinoDialect()
1718

1819
@pytest.mark.parametrize(
19-
"url, expected_args, expected_kwargs",
20+
"url, generated_url, expected_args, expected_kwargs",
2021
[
2122
(
22-
make_url("trino://user@localhost"),
23+
make_url(trino_url(
24+
user="user",
25+
host="localhost",
26+
)),
27+
'trino://user@localhost:8080?source=trino-sqlalchemy',
2328
list(),
24-
dict(host="localhost", catalog="system", user="user", source="trino-sqlalchemy"),
29+
dict(host="localhost", catalog="system", user="user", port=8080, source="trino-sqlalchemy"),
2530
),
2631
(
27-
make_url("trino://user@localhost:8080"),
32+
make_url(trino_url(
33+
user="user",
34+
host="localhost",
35+
port=443,
36+
)),
37+
'trino://user@localhost:443?source=trino-sqlalchemy',
2838
list(),
29-
dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"),
39+
dict(host="localhost", port=443, catalog="system", user="user", source="trino-sqlalchemy"),
3040
),
3141
(
32-
make_url("trino://user:pass@localhost:8080?source=trino-rulez"),
42+
make_url(trino_url(
43+
user="user",
44+
password="pass",
45+
host="localhost",
46+
source="trino-rulez",
47+
)),
48+
'trino://user:***@localhost:8080?source=trino-rulez',
3349
list(),
3450
dict(
3551
host="localhost",
@@ -42,13 +58,64 @@ def setup(self):
4258
),
4359
),
4460
(
45-
make_url(
46-
'trino://user@localhost:8080?'
47-
'session_properties={"query_max_run_time": "1d"}'
48-
'&http_headers={"trino": 1}'
49-
'&extra_credential=[("a", "b"), ("c", "d")]'
50-
'&client_tags=[1, "sql"]'
51-
'&experimental_python_types=true'),
61+
make_url(trino_url(
62+
user="user",
63+
host="localhost",
64+
cert="/my/path/to/cert",
65+
key="afdlsdfk%4#'",
66+
)),
67+
'trino://user@localhost:8080'
68+
'?cert=%2Fmy%2Fpath%2Fto%2Fcert'
69+
'&key=afdlsdfk%254%23%27'
70+
'&source=trino-sqlalchemy',
71+
list(),
72+
dict(
73+
host="localhost",
74+
port=8080,
75+
catalog="system",
76+
user="user",
77+
auth=CertificateAuthentication("/my/path/to/cert", "afdlsdfk%4#'"),
78+
http_scheme="https",
79+
source="trino-sqlalchemy"
80+
),
81+
),
82+
(
83+
make_url(trino_url(
84+
user="user",
85+
host="localhost",
86+
access_token="afdlsdfk%4#'",
87+
)),
88+
'trino://user@localhost:8080'
89+
'?access_token=afdlsdfk%254%23%27'
90+
'&source=trino-sqlalchemy',
91+
list(),
92+
dict(
93+
host="localhost",
94+
port=8080,
95+
catalog="system",
96+
user="user",
97+
auth=JWTAuthentication("afdlsdfk%4#'"),
98+
http_scheme="https",
99+
source="trino-sqlalchemy"
100+
),
101+
),
102+
(
103+
make_url(trino_url(
104+
user="user",
105+
host="localhost",
106+
session_properties={"query_max_run_time": "1d"},
107+
http_headers={"trino": 1},
108+
extra_credential=[("a", "b"), ("c", "d")],
109+
client_tags=["1", "sql"],
110+
experimental_python_types=True,
111+
)),
112+
'trino://user@localhost:8080'
113+
'?client_tags=%5B%221%22%2C+%22sql%22%5D'
114+
'&experimental_python_types=true'
115+
'&extra_credential=%5B%5B%22a%22%2C+%22b%22%5D%2C+%5B%22c%22%2C+%22d%22%5D%5D'
116+
'&http_headers=%7B%22trino%22%3A+1%7D'
117+
'&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D'
118+
'&source=trino-sqlalchemy',
52119
list(),
53120
dict(
54121
host="localhost",
@@ -59,23 +126,87 @@ def setup(self):
59126
session_properties={"query_max_run_time": "1d"},
60127
http_headers={"trino": 1},
61128
extra_credential=[("a", "b"), ("c", "d")],
62-
client_tags=[1, "sql"],
129+
client_tags=["1", "sql"],
63130
experimental_python_types=True,
64131
),
65132
),
133+
# url encoding
66134
(
67-
make_url('trino://user@localhost:8080?roles={"hive":"finance","system":"analyst"}'),
135+
make_url(trino_url(
136+
user="user@test.org/my_role",
137+
password="pass /*&",
138+
host="localhost",
139+
session_properties={"query_max_run_time": "1d"},
140+
http_headers={"trino": 1},
141+
extra_credential=[
142+
("user1@test.org/my_role", "user2@test.org/my_role"),
143+
("user3@test.org/my_role", "user36@test.org/my_role")],
144+
experimental_python_types=True,
145+
client_tags=["1 @& /\"", "sql"],
146+
verify=False,
147+
)),
148+
'trino://user%40test.org%2Fmy_role:***@localhost:8080'
149+
'?client_tags=%5B%221+%40%26+%2F%5C%22%22%2C+%22sql%22%5D'
150+
'&experimental_python_types=true'
151+
'&extra_credential=%5B%5B%22user1%40test.org%2Fmy_role%22%2C+'
152+
'%22user2%40test.org%2Fmy_role%22%5D%2C+'
153+
'%5B%22user3%40test.org%2Fmy_role%22%2C+'
154+
'%22user36%40test.org%2Fmy_role%22%5D%5D'
155+
'&http_headers=%7B%22trino%22%3A+1%7D'
156+
'&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D'
157+
'&source=trino-sqlalchemy'
158+
'&verify=false',
68159
list(),
69-
dict(host="localhost",
70-
port=8080,
71-
catalog="system",
72-
user="user",
73-
roles={"hive": "finance", "system": "analyst"},
74-
source="trino-sqlalchemy"),
160+
dict(
161+
host="localhost",
162+
port=8080,
163+
catalog="system",
164+
user="user@test.org/my_role",
165+
auth=BasicAuthentication("user@test.org/my_role", "pass /*&"),
166+
http_scheme="https",
167+
source="trino-sqlalchemy",
168+
session_properties={"query_max_run_time": "1d"},
169+
http_headers={"trino": 1},
170+
extra_credential=[
171+
("user1@test.org/my_role", "user2@test.org/my_role"),
172+
("user3@test.org/my_role", "user36@test.org/my_role")],
173+
experimental_python_types=True,
174+
client_tags=["1 @& /\"", "sql"],
175+
verify=False,
176+
),
177+
),
178+
(
179+
make_url(trino_url(
180+
user="user",
181+
host="localhost",
182+
roles={
183+
"hive": "finance",
184+
"system": "analyst",
185+
}
186+
)),
187+
'trino://user@localhost:8080'
188+
'?roles=%7B%22hive%22%3A+%22finance%22%2C+%22system%22%3A+%22analyst%22%7D&source=trino-sqlalchemy',
189+
list(),
190+
dict(
191+
host="localhost",
192+
port=8080,
193+
catalog="system",
194+
user="user",
195+
roles={"hive": "finance", "system": "analyst"},
196+
source="trino-sqlalchemy",
197+
),
75198
),
76199
],
77200
)
78-
def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]):
201+
def test_create_connect_args(
202+
self,
203+
url: URL,
204+
generated_url: str,
205+
expected_args: List[Any],
206+
expected_kwargs: Dict[str, Any]
207+
):
208+
assert repr(url) == generated_url
209+
79210
actual_args, actual_kwargs = self.dialect.create_connect_args(url)
80211

81212
assert actual_args == expected_args

trino/sqlalchemy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
from sqlalchemy.dialects import registry
13+
from .util import _url as URL # noqa
1314

1415
registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect")

trino/sqlalchemy/dialect.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
import json
13-
from ast import literal_eval
1413
from textwrap import dedent
1514
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
15+
from urllib.parse import unquote_plus
1616

1717
from sqlalchemy import exc, sql
1818
from sqlalchemy.engine.base import Connection
@@ -80,49 +80,54 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
8080

8181
db_parts = (url.database or "system").split("/")
8282
if len(db_parts) == 1:
83-
kwargs["catalog"] = db_parts[0]
83+
kwargs["catalog"] = unquote_plus(db_parts[0])
8484
elif len(db_parts) == 2:
85-
kwargs["catalog"] = db_parts[0]
86-
kwargs["schema"] = db_parts[1]
85+
kwargs["catalog"] = unquote_plus(db_parts[0])
86+
kwargs["schema"] = unquote_plus(db_parts[1])
8787
else:
8888
raise ValueError(f"Unexpected database format {url.database}")
8989

9090
if url.username:
91-
kwargs["user"] = url.username
91+
kwargs["user"] = unquote_plus(url.username)
9292

9393
if url.password:
9494
if not url.username:
9595
raise ValueError("Username is required when specify password in connection URL")
9696
kwargs["http_scheme"] = "https"
97-
kwargs["auth"] = BasicAuthentication(url.username, url.password)
97+
kwargs["auth"] = BasicAuthentication(unquote_plus(url.username), unquote_plus(url.password))
9898

9999
if "access_token" in url.query:
100100
kwargs["http_scheme"] = "https"
101-
kwargs["auth"] = JWTAuthentication(url.query["access_token"])
101+
kwargs["auth"] = JWTAuthentication(unquote_plus(url.query["access_token"]))
102102

103103
if "cert" and "key" in url.query:
104104
kwargs["http_scheme"] = "https"
105-
kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key'])
105+
kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key']))
106106

107107
if "source" in url.query:
108-
kwargs["source"] = url.query["source"]
108+
kwargs["source"] = unquote_plus(url.query["source"])
109109
else:
110110
kwargs["source"] = "trino-sqlalchemy"
111111

112112
if "session_properties" in url.query:
113-
kwargs["session_properties"] = json.loads(url.query["session_properties"])
113+
kwargs["session_properties"] = json.loads(unquote_plus(url.query["session_properties"]))
114114

115115
if "http_headers" in url.query:
116-
kwargs["http_headers"] = json.loads(url.query["http_headers"])
116+
kwargs["http_headers"] = json.loads(unquote_plus(url.query["http_headers"]))
117117

118118
if "extra_credential" in url.query:
119-
kwargs["extra_credential"] = literal_eval(url.query["extra_credential"])
119+
kwargs["extra_credential"] = [
120+
tuple(extra_credential) for extra_credential in json.loads(unquote_plus(url.query["extra_credential"]))
121+
]
120122

121123
if "client_tags" in url.query:
122-
kwargs["client_tags"] = json.loads(url.query["client_tags"])
124+
kwargs["client_tags"] = json.loads(unquote_plus(url.query["client_tags"]))
123125

124126
if "experimental_python_types" in url.query:
125-
kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"])
127+
kwargs["experimental_python_types"] = json.loads(unquote_plus(url.query["experimental_python_types"]))
128+
129+
if "verify" in url.query:
130+
kwargs["verify"] = json.loads(unquote_plus(url.query["verify"]))
126131

127132
if "roles" in url.query:
128133
kwargs["roles"] = json.loads(url.query["roles"])

0 commit comments

Comments
 (0)