Skip to content

Commit 2539a2c

Browse files
committed
Execute should block until at least one row is received
1 parent 3f07f68 commit 2539a2c

File tree

5 files changed

+51
-21
lines changed

5 files changed

+51
-21
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def test_execute_many_without_params(trino_connection):
153153
cur = trino_connection.cursor()
154154
cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)")
155155
cur.fetchall()
156-
cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
157156
with pytest.raises(TrinoUserError) as e:
157+
cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
158158
cur.fetchall()
159159
assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value)
160160

@@ -883,13 +883,12 @@ def test_transaction_autocommit(trino_connection_in_autocommit):
883883
with trino_connection_in_autocommit as connection:
884884
connection.start_transaction()
885885
cur = connection.cursor()
886-
cur.execute(
887-
"""
888-
CREATE TABLE memory.default.nation
889-
AS SELECT * from tpch.tiny.nation
890-
""")
891-
892886
with pytest.raises(TrinoUserError) as transaction_error:
887+
cur.execute(
888+
"""
889+
CREATE TABLE memory.default.nation
890+
AS SELECT * from tpch.tiny.nation
891+
""")
893892
cur.fetchall()
894893
assert "Catalog only supports writes using autocommit: memory" \
895894
in str(transaction_error.value)

tests/unit/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def sample_post_response_data():
3737
"""
3838

3939
yield {
40-
"nextUri": "coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1",
40+
"nextUri": "https://coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1",
4141
"id": "20210817_140827_00000_arvdv",
4242
"taskDownloadUris": [],
43-
"infoUri": "http://coordinator:8080/query.html?20210817_140827_00000_arvdv",
43+
"infoUri": "https://coordinator:8080/query.html?20210817_140827_00000_arvdv",
4444
"stats": {
4545
"scheduled": False,
4646
"runningSplits": 0,

tests/unit/test_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -892,9 +892,7 @@ def test_trino_result_response_headers():
892892
'X-Trino-Fake-2': 'two',
893893
})
894894

895-
result = TrinoResult(
896-
query=mock_trino_query,
897-
)
895+
result = TrinoResult(query=mock_trino_query, rows=[])
898896
assert result.response_headers == mock_trino_query.response_headers
899897

900898

tests/unit/test_dbapi.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,28 @@ def test_http_session_is_defaulted_when_not_specified(mock_client):
5151

5252

5353
@httprettified
54-
def test_token_retrieved_once_per_auth_instance(sample_post_response_data):
54+
def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sample_get_response_data):
5555
token = str(uuid.uuid4())
5656
challenge_id = str(uuid.uuid4())
5757

5858
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
5959
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
6060

6161
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
62+
get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data)
6263

63-
# bind post statement
64+
# bind post statement to submit query
6465
httpretty.register_uri(
6566
method=httpretty.POST,
6667
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
6768
body=post_statement_callback)
6869

70+
# bind get statement for result retrieval
71+
httpretty.register_uri(
72+
method=httpretty.GET,
73+
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
74+
body=get_statement_callback)
75+
6976
# bind get token
7077
get_token_callback = GetTokenCallback(token_server, token)
7178
httpretty.register_uri(
@@ -108,21 +115,29 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data):
108115

109116

110117
@httprettified
111-
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data):
118+
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data,
119+
sample_get_response_data):
112120
token = str(uuid.uuid4())
113121
challenge_id = str(uuid.uuid4())
114122

115123
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
116124
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
117125

118126
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
127+
get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data)
119128

120-
# bind post statement
129+
# bind post statement to submit query
121130
httpretty.register_uri(
122131
method=httpretty.POST,
123132
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
124133
body=post_statement_callback)
125134

135+
# bind get statement for result retrieval
136+
httpretty.register_uri(
137+
method=httpretty.GET,
138+
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
139+
body=get_statement_callback)
140+
126141
# bind get token
127142
get_token_callback = GetTokenCallback(token_server, token)
128143
httpretty.register_uri(
@@ -166,21 +181,28 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post
166181

167182

168183
@httprettified
169-
def test_token_retrieved_once_when_multithreaded(sample_post_response_data):
184+
def test_token_retrieved_once_when_multithreaded(sample_post_response_data, sample_get_response_data):
170185
token = str(uuid.uuid4())
171186
challenge_id = str(uuid.uuid4())
172187

173188
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
174189
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
175190

176191
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
192+
get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data)
177193

178-
# bind post statement
194+
# bind post statement to submit query
179195
httpretty.register_uri(
180196
method=httpretty.POST,
181197
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
182198
body=post_statement_callback)
183199

200+
# bind get statement for result retrieval
201+
httpretty.register_uri(
202+
method=httpretty.GET,
203+
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
204+
body=get_statement_callback)
205+
184206
# bind get token
185207
get_token_callback = GetTokenCallback(token_server, token)
186208
httpretty.register_uri(

trino/client.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -592,12 +592,20 @@ class TrinoResult(object):
592592
https://docs.python.org/3/library/stdtypes.html#generator-types
593593
"""
594594

595-
def __init__(self, query, rows=None):
595+
def __init__(self, query, rows: List[Any]):
596596
self._query = query
597597
# Initial rows from the first POST request
598598
self._rows = rows
599599
self._rownumber = 0
600600

601+
@property
602+
def rows(self):
603+
return self._rows
604+
605+
@rows.setter
606+
def rows(self, rows):
607+
self._rows = rows
608+
601609
@property
602610
def rownumber(self) -> int:
603611
return self._rownumber
@@ -650,7 +658,7 @@ def columns(self):
650658
while not self._columns and not self.finished and not self.cancelled:
651659
# Columns are not returned immediately after query is submitted.
652660
# Continue fetching data until columns information is available and push fetched rows into buffer.
653-
self._result._rows += self.fetch()
661+
self._result.rows += self.fetch()
654662
return self._columns
655663

656664
@property
@@ -695,8 +703,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
695703
self._finished = True
696704

697705
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
698-
699706
self._result = TrinoResult(self, rows)
707+
708+
# Execute should block until at least one row is received
709+
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
710+
self._result.rows += self.fetch()
700711
return self._result
701712

702713
def _update_state(self, status):

0 commit comments

Comments
 (0)