Skip to content

Escape url parameters in sqlalchemy connection strings #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,14 @@ Attributes can also be passed in the connection string.

```python
from sqlalchemy import create_engine
from trino.sqlalchemy import URL

engine = create_engine(
'trino://user@localhost:8080/system',
URL(
host="localhost",
port=8080,
catalog="system"
),
connect_args={
"session_properties": {'query_max_run_time': '1d'},
"client_tags": ["tag1", "tag2"],
Expand All @@ -119,6 +124,14 @@ engine = create_engine(
'&experimental_python_types=true'
'&roles={"catalog1": "role1"}'
)

# or using the URL factory method
engine = create_engine(URL(
host="localhost",
port=8080,
client_tags=["tag1", "tag2"],
experimental_python_types=True
))
```

## Authentication mechanisms
Expand Down
175 changes: 153 additions & 22 deletions tests/unit/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,43 @@
from trino.dbapi import Connection
from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect
from trino.transaction import IsolationLevel
from trino.sqlalchemy import URL as trino_url


class TestTrinoDialect:
def setup(self):
self.dialect = TrinoDialect()

@pytest.mark.parametrize(
"url, expected_args, expected_kwargs",
"url, generated_url, expected_args, expected_kwargs",
[
(
make_url("trino://user@localhost"),
make_url(trino_url(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth to keep existing tests and create new ones with URL method next to them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can simply assert that the string being generated is equal to the above string?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

user="user",
host="localhost",
)),
'trino://user@localhost:8080?source=trino-sqlalchemy',
list(),
dict(host="localhost", catalog="system", user="user", source="trino-sqlalchemy"),
dict(host="localhost", catalog="system", user="user", port=8080, source="trino-sqlalchemy"),
),
(
make_url("trino://user@localhost:8080"),
make_url(trino_url(
user="user",
host="localhost",
port=443,
)),
'trino://user@localhost:443?source=trino-sqlalchemy',
list(),
dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"),
dict(host="localhost", port=443, catalog="system", user="user", source="trino-sqlalchemy"),
),
(
make_url("trino://user:pass@localhost:8080?source=trino-rulez"),
make_url(trino_url(
user="user",
password="pass",
host="localhost",
source="trino-rulez",
)),
'trino://user:***@localhost:8080?source=trino-rulez',
list(),
dict(
host="localhost",
Expand All @@ -42,13 +58,64 @@ def setup(self):
),
),
(
make_url(
'trino://user@localhost:8080?'
'session_properties={"query_max_run_time": "1d"}'
'&http_headers={"trino": 1}'
'&extra_credential=[("a", "b"), ("c", "d")]'
'&client_tags=[1, "sql"]'
'&experimental_python_types=true'),
make_url(trino_url(
user="user",
host="localhost",
cert="/my/path/to/cert",
key="afdlsdfk%4#'",
)),
'trino://user@localhost:8080'
'?cert=%2Fmy%2Fpath%2Fto%2Fcert'
'&key=afdlsdfk%254%23%27'
'&source=trino-sqlalchemy',
list(),
dict(
host="localhost",
port=8080,
catalog="system",
user="user",
auth=CertificateAuthentication("/my/path/to/cert", "afdlsdfk%4#'"),
http_scheme="https",
source="trino-sqlalchemy"
),
),
(
make_url(trino_url(
user="user",
host="localhost",
access_token="afdlsdfk%4#'",
)),
'trino://user@localhost:8080'
'?access_token=afdlsdfk%254%23%27'
'&source=trino-sqlalchemy',
list(),
dict(
host="localhost",
port=8080,
catalog="system",
user="user",
auth=JWTAuthentication("afdlsdfk%4#'"),
http_scheme="https",
source="trino-sqlalchemy"
),
),
(
make_url(trino_url(
user="user",
host="localhost",
session_properties={"query_max_run_time": "1d"},
http_headers={"trino": 1},
extra_credential=[("a", "b"), ("c", "d")],
client_tags=["1", "sql"],
experimental_python_types=True,
)),
'trino://user@localhost:8080'
'?client_tags=%5B%221%22%2C+%22sql%22%5D'
'&experimental_python_types=true'
'&extra_credential=%5B%5B%22a%22%2C+%22b%22%5D%2C+%5B%22c%22%2C+%22d%22%5D%5D'
'&http_headers=%7B%22trino%22%3A+1%7D'
'&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D'
'&source=trino-sqlalchemy',
list(),
dict(
host="localhost",
Expand All @@ -59,23 +126,87 @@ def setup(self):
session_properties={"query_max_run_time": "1d"},
http_headers={"trino": 1},
extra_credential=[("a", "b"), ("c", "d")],
client_tags=[1, "sql"],
client_tags=["1", "sql"],
experimental_python_types=True,
),
),
# url encoding
(
make_url('trino://user@localhost:8080?roles={"hive":"finance","system":"analyst"}'),
make_url(trino_url(
user="[email protected]/my_role",
password="pass /*&",
host="localhost",
session_properties={"query_max_run_time": "1d"},
http_headers={"trino": 1},
extra_credential=[
("[email protected]/my_role", "[email protected]/my_role"),
("[email protected]/my_role", "[email protected]/my_role")],
experimental_python_types=True,
client_tags=["1 @& /\"", "sql"],
verify=False,
)),
'trino://user%40test.org%2Fmy_role:***@localhost:8080'
'?client_tags=%5B%221+%40%26+%2F%5C%22%22%2C+%22sql%22%5D'
'&experimental_python_types=true'
'&extra_credential=%5B%5B%22user1%40test.org%2Fmy_role%22%2C+'
'%22user2%40test.org%2Fmy_role%22%5D%2C+'
'%5B%22user3%40test.org%2Fmy_role%22%2C+'
'%22user36%40test.org%2Fmy_role%22%5D%5D'
'&http_headers=%7B%22trino%22%3A+1%7D'
'&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D'
'&source=trino-sqlalchemy'
'&verify=false',
list(),
dict(host="localhost",
port=8080,
catalog="system",
user="user",
roles={"hive": "finance", "system": "analyst"},
source="trino-sqlalchemy"),
dict(
host="localhost",
port=8080,
catalog="system",
user="[email protected]/my_role",
auth=BasicAuthentication("[email protected]/my_role", "pass /*&"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test for other supported auth methods: JWT and Certificate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added both tests.

http_scheme="https",
source="trino-sqlalchemy",
session_properties={"query_max_run_time": "1d"},
http_headers={"trino": 1},
extra_credential=[
("[email protected]/my_role", "[email protected]/my_role"),
("[email protected]/my_role", "[email protected]/my_role")],
experimental_python_types=True,
client_tags=["1 @& /\"", "sql"],
verify=False,
),
),
(
make_url(trino_url(
user="user",
host="localhost",
roles={
"hive": "finance",
"system": "analyst",
}
)),
'trino://user@localhost:8080'
'?roles=%7B%22hive%22%3A+%22finance%22%2C+%22system%22%3A+%22analyst%22%7D&source=trino-sqlalchemy',
list(),
dict(
host="localhost",
port=8080,
catalog="system",
user="user",
roles={"hive": "finance", "system": "analyst"},
source="trino-sqlalchemy",
),
),
],
)
def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]):
def test_create_connect_args(
self,
url: URL,
generated_url: str,
expected_args: List[Any],
expected_kwargs: Dict[str, Any]
):
assert repr(url) == generated_url

actual_args, actual_kwargs = self.dialect.create_connect_args(url)

assert actual_args == expected_args
Expand Down
1 change: 1 addition & 0 deletions trino/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlalchemy.dialects import registry
from .util import _url as URL # noqa

registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect")
33 changes: 19 additions & 14 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from ast import literal_eval
from textwrap import dedent
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
from urllib.parse import unquote_plus

from sqlalchemy import exc, sql
from sqlalchemy.engine.base import Connection
Expand Down Expand Up @@ -80,49 +80,54 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any

db_parts = (url.database or "system").split("/")
if len(db_parts) == 1:
kwargs["catalog"] = db_parts[0]
kwargs["catalog"] = unquote_plus(db_parts[0])
elif len(db_parts) == 2:
kwargs["catalog"] = db_parts[0]
kwargs["schema"] = db_parts[1]
kwargs["catalog"] = unquote_plus(db_parts[0])
kwargs["schema"] = unquote_plus(db_parts[1])
else:
raise ValueError(f"Unexpected database format {url.database}")

if url.username:
kwargs["user"] = url.username
kwargs["user"] = unquote_plus(url.username)

if url.password:
if not url.username:
raise ValueError("Username is required when specify password in connection URL")
kwargs["http_scheme"] = "https"
kwargs["auth"] = BasicAuthentication(url.username, url.password)
kwargs["auth"] = BasicAuthentication(unquote_plus(url.username), unquote_plus(url.password))

if "access_token" in url.query:
kwargs["http_scheme"] = "https"
kwargs["auth"] = JWTAuthentication(url.query["access_token"])
kwargs["auth"] = JWTAuthentication(unquote_plus(url.query["access_token"]))

if "cert" and "key" in url.query:
kwargs["http_scheme"] = "https"
kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key'])
kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key']))

if "source" in url.query:
kwargs["source"] = url.query["source"]
kwargs["source"] = unquote_plus(url.query["source"])
else:
kwargs["source"] = "trino-sqlalchemy"

if "session_properties" in url.query:
kwargs["session_properties"] = json.loads(url.query["session_properties"])
kwargs["session_properties"] = json.loads(unquote_plus(url.query["session_properties"]))

if "http_headers" in url.query:
kwargs["http_headers"] = json.loads(url.query["http_headers"])
kwargs["http_headers"] = json.loads(unquote_plus(url.query["http_headers"]))

if "extra_credential" in url.query:
kwargs["extra_credential"] = literal_eval(url.query["extra_credential"])
kwargs["extra_credential"] = [
tuple(extra_credential) for extra_credential in json.loads(unquote_plus(url.query["extra_credential"]))
]

if "client_tags" in url.query:
kwargs["client_tags"] = json.loads(url.query["client_tags"])
kwargs["client_tags"] = json.loads(unquote_plus(url.query["client_tags"]))

if "experimental_python_types" in url.query:
kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"])
kwargs["experimental_python_types"] = json.loads(unquote_plus(url.query["experimental_python_types"]))

if "verify" in url.query:
kwargs["verify"] = json.loads(unquote_plus(url.query["verify"]))

if "roles" in url.query:
kwargs["roles"] = json.loads(url.query["roles"])
Expand Down
Loading