Skip to content

Commit ad0d76b

Browse files
ulisesojedaebyhr
authored andcommitted
Add support for SET ROLE statement
1 parent 748d197 commit ad0d76b

File tree

4 files changed

+125
-1
lines changed

4 files changed

+125
-1
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,3 +1034,37 @@ def test_use_catalog(run_trino):
10341034
result = cur.fetchall()
10351035
assert result[0][0] == 'tpch'
10361036
assert result[0][1] == 'sf1'
1037+
1038+
1039+
@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role")
1040+
def test_set_role_trino_higher_351(run_trino):
1041+
_, host, port = run_trino
1042+
1043+
trino_connection = trino.dbapi.Connection(
1044+
host=host, port=port, user="test", catalog="tpch"
1045+
)
1046+
cur = trino_connection.cursor()
1047+
cur.execute('SHOW TABLES FROM information_schema')
1048+
cur.fetchall()
1049+
assert cur._request._client_session.role is None
1050+
1051+
cur.execute("SET ROLE ALL")
1052+
cur.fetchall()
1053+
assert cur._request._client_session.role == "system=ALL"
1054+
1055+
1056+
@pytest.mark.skipif(trino_version() != '351', reason="Trino 351 returns the role for the current catalog")
1057+
def test_set_role_trino_351(run_trino):
1058+
_, host, port = run_trino
1059+
1060+
trino_connection = trino.dbapi.Connection(
1061+
host=host, port=port, user="test", catalog="tpch"
1062+
)
1063+
cur = trino_connection.cursor()
1064+
cur.execute('SHOW TABLES FROM information_schema')
1065+
cur.fetchall()
1066+
assert cur._request._client_session.role is None
1067+
1068+
cur.execute("SET ROLE ALL")
1069+
cur.fetchall()
1070+
assert cur._request._client_session.role == "tpch=ALL"

tests/unit/test_client.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,10 @@ def assert_headers(headers):
9292
assert headers[constants.HEADER_SOURCE] == source
9393
assert headers[constants.HEADER_USER] == user
9494
assert headers[constants.HEADER_SESSION] == ""
95+
assert headers[constants.HEADER_ROLE] is None
9596
assert headers[accept_encoding_header] == accept_encoding_value
9697
assert headers[client_info_header] == client_info_value
97-
assert len(headers.keys()) == 8
98+
assert len(headers.keys()) == 9
9899

99100
req.post("URL")
100101
_, post_kwargs = post.call_args
@@ -1002,3 +1003,71 @@ def __call__(self):
10021003
with_retry(FailerUntil(2).__call__)()
10031004
with pytest.raises(SomeException):
10041005
with_retry(FailerUntil(3).__call__)()
1006+
1007+
1008+
def assert_headers_with_role(headers, role):
1009+
assert headers[constants.HEADER_USER] == "test_user"
1010+
assert headers[constants.HEADER_ROLE] == role
1011+
assert len(headers.keys()) == 7
1012+
1013+
1014+
def test_request_headers_role_hive_all(mock_get_and_post):
1015+
get, post = mock_get_and_post
1016+
req = TrinoRequest(
1017+
host="coordinator",
1018+
port=8080,
1019+
client_session=ClientSession(
1020+
user="test_user",
1021+
role="hive=ALL",
1022+
),
1023+
)
1024+
1025+
req.post("URL")
1026+
_, post_kwargs = post.call_args
1027+
assert_headers_with_role(post_kwargs["headers"], "hive=ALL")
1028+
1029+
req.get("URL")
1030+
_, get_kwargs = get.call_args
1031+
assert_headers_with_role(post_kwargs["headers"], "hive=ALL")
1032+
1033+
1034+
def test_request_headers_role_admin(mock_get_and_post):
1035+
get, post = mock_get_and_post
1036+
1037+
req = TrinoRequest(
1038+
host="coordinator",
1039+
port=8080,
1040+
client_session=ClientSession(
1041+
user="test_user",
1042+
role="admin",
1043+
),
1044+
)
1045+
1046+
req.post("URL")
1047+
_, post_kwargs = post.call_args
1048+
assert_headers_with_role(post_kwargs["headers"], "admin")
1049+
1050+
req.get("URL")
1051+
_, get_kwargs = get.call_args
1052+
assert_headers_with_role(post_kwargs["headers"], "admin")
1053+
1054+
1055+
def test_request_headers_role_empty(mock_get_and_post):
1056+
get, post = mock_get_and_post
1057+
1058+
req = TrinoRequest(
1059+
host="coordinator",
1060+
port=8080,
1061+
client_session=ClientSession(
1062+
user="test_user",
1063+
role="",
1064+
),
1065+
)
1066+
1067+
req.post("URL")
1068+
_, post_kwargs = post.call_args
1069+
assert_headers_with_role(post_kwargs["headers"], "")
1070+
1071+
req.get("URL")
1072+
_, get_kwargs = get.call_args
1073+
assert_headers_with_role(post_kwargs["headers"], "")

trino/client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class ClientSession(object):
9898
:param extra_credential: extra credentials. as list of ``(key, value)``
9999
tuples.
100100
:param client_tags: Client tags as list of strings.
101+
:param role: role for the current session. Some connectors do not
102+
support role management. See connector documentation for more details.
101103
"""
102104

103105
def __init__(
@@ -111,6 +113,7 @@ def __init__(
111113
transaction_id: str = None,
112114
extra_credential: List[Tuple[str, str]] = None,
113115
client_tags: List[str] = None,
116+
role: str = None,
114117
):
115118
self._user = user
116119
self._catalog = catalog
@@ -121,6 +124,7 @@ def __init__(
121124
self._transaction_id = transaction_id
122125
self._extra_credential = extra_credential
123126
self._client_tags = client_tags
127+
self._role = role
124128
self._object_lock = threading.Lock()
125129

126130
@property
@@ -192,6 +196,16 @@ def __setstate__(self, state):
192196
self.__dict__.update(state)
193197
self._object_lock = threading.Lock()
194198

199+
@property
200+
def role(self):
201+
with self._object_lock:
202+
return self._role
203+
204+
@role.setter
205+
def role(self, role):
206+
with self._object_lock:
207+
self._role = role
208+
195209

196210
def get_header_values(headers, header):
197211
return [val.strip() for val in headers[header].split(",")]
@@ -368,6 +382,7 @@ def http_headers(self) -> Dict[str, str]:
368382
headers[constants.HEADER_SCHEMA] = self._client_session.schema
369383
headers[constants.HEADER_SOURCE] = self._client_session.source
370384
headers[constants.HEADER_USER] = self._client_session.user
385+
headers[constants.HEADER_ROLE] = self._client_session.role
371386
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
372387
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)
373388

@@ -538,6 +553,9 @@ def process(self, http_response) -> TrinoStatus:
538553
if constants.HEADER_SET_SCHEMA in http_response.headers:
539554
self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA]
540555

556+
if constants.HEADER_SET_ROLE in http_response.headers:
557+
self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE]
558+
541559
self._next_uri = response.get("nextUri")
542560

543561
return TrinoStatus(

trino/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
HEADER_SET_SESSION = "X-Trino-Set-Session"
4040
HEADER_CLEAR_SESSION = "X-Trino-Clear-Session"
4141

42+
HEADER_ROLE = "X-Trino-Role"
43+
HEADER_SET_ROLE = "X-Trino-Set-Role"
44+
4245
HEADER_STARTED_TRANSACTION = "X-Trino-Started-Transaction-Id"
4346
HEADER_TRANSACTION = "X-Trino-Transaction-Id"
4447

0 commit comments

Comments
 (0)