Skip to content

Commit e9fa66c

Browse files
mdesmetebyhr
authored andcommitted
Move retry logic to client module
1 parent 951ad82 commit e9fa66c

File tree

4 files changed

+119
-139
lines changed

4 files changed

+119
-139
lines changed

tests/unit/test_client.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
SERVER_ADDRESS
2929
from trino import constants
3030
from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
31-
from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession
31+
from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession, _DelayExponential, _retry_with, \
32+
_RetryWithExponentialBackoff
3233

3334

3435
@mock.patch("trino.client.TrinoRequest.http")
@@ -947,3 +948,57 @@ def json(self):
947948

948949
# Validate the result is an instance of TrinoResult
949950
assert isinstance(result, TrinoResult)
951+
952+
953+
def test_delay_exponential_without_jitter():
954+
max_delay = 1200.0
955+
get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay)
956+
results = [
957+
10.0,
958+
20.0,
959+
40.0,
960+
80.0,
961+
160.0,
962+
320.0,
963+
640.0,
964+
max_delay, # rather than 1280.0
965+
max_delay, # rather than 2560.0
966+
]
967+
for i, result in enumerate(results, start=1):
968+
assert get_delay(i) == result
969+
970+
971+
def test_delay_exponential_with_jitter():
972+
max_delay = 120.0
973+
get_delay = _DelayExponential(base=10, jitter=False, max_delay=max_delay)
974+
for i in range(10):
975+
assert get_delay(i) <= max_delay
976+
977+
978+
class SomeException(Exception):
979+
pass
980+
981+
982+
def test_retry_with():
983+
max_attempts = 3
984+
with_retry = _retry_with(
985+
handle_retry=_RetryWithExponentialBackoff(),
986+
handled_exceptions=[SomeException],
987+
conditions={},
988+
max_attempts=max_attempts,
989+
)
990+
991+
class FailerUntil(object):
992+
def __init__(self, until=1):
993+
self.attempt = 0
994+
self._until = until
995+
996+
def __call__(self):
997+
self.attempt += 1
998+
if self.attempt > self._until:
999+
return
1000+
raise SomeException(self.attempt)
1001+
1002+
with_retry(FailerUntil(2).__call__)()
1003+
with pytest.raises(SomeException):
1004+
with_retry(FailerUntil(3).__call__)()

tests/unit/test_exceptions.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

trino/client.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@
3434
"""
3535

3636
import copy
37+
import functools
3738
import os
39+
import random
3840
import re
3941
import threading
42+
import time
4043
import urllib.parse
4144
from datetime import datetime, timedelta, timezone
4245
from decimal import Decimal
@@ -227,6 +230,34 @@ def __repr__(self):
227230
)
228231

229232

233+
class _DelayExponential(object):
234+
def __init__(
235+
self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours
236+
):
237+
self._base = base
238+
self._exponent = exponent
239+
self._jitter = jitter
240+
self._max_delay = max_delay
241+
242+
def __call__(self, attempt):
243+
delay = float(self._base) * (self._exponent ** attempt)
244+
if self._jitter:
245+
delay *= random.random()
246+
delay = min(float(self._max_delay), delay)
247+
return delay
248+
249+
250+
class _RetryWithExponentialBackoff(object):
251+
def __init__(
252+
self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours
253+
):
254+
self._get_delay = _DelayExponential(base, exponent, jitter, max_delay)
255+
256+
def retry(self, func, args, kwargs, err, attempt):
257+
delay = self._get_delay(attempt)
258+
time.sleep(delay)
259+
260+
230261
class TrinoRequest(object):
231262
"""
232263
Manage the HTTP requests of a Trino query.
@@ -286,7 +317,7 @@ def __init__(
286317
redirect_handler: Any = None,
287318
max_attempts: int = MAX_ATTEMPTS,
288319
request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT,
289-
handle_retry=exceptions.RetryWithExponentialBackoff(),
320+
handle_retry=_RetryWithExponentialBackoff(),
290321
verify: bool = True,
291322
) -> None:
292323
self._client_session = client_session
@@ -383,9 +414,9 @@ def max_attempts(self, value) -> None:
383414
self._delete = self._http_session.delete
384415
return
385416

386-
with_retry = exceptions.retry_with(
417+
with_retry = _retry_with(
387418
self._handle_retry,
388-
exceptions=self._exceptions,
419+
handled_exceptions=self._exceptions,
389420
conditions=(
390421
# need retry when there is no exception but the status code is 502, 503, or 504
391422
lambda response: getattr(response, "status_code", None)
@@ -779,3 +810,32 @@ def cancelled(self) -> bool:
779810
@property
780811
def response_headers(self):
781812
return self._response_headers
813+
814+
815+
def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
816+
def wrapper(func):
817+
@functools.wraps(func)
818+
def decorated(*args, **kwargs):
819+
error = None
820+
result = None
821+
for attempt in range(1, max_attempts + 1):
822+
try:
823+
result = func(*args, **kwargs)
824+
if any(guard(result) for guard in conditions):
825+
handle_retry.retry(func, args, kwargs, None, attempt)
826+
continue
827+
return result
828+
except Exception as err:
829+
error = err
830+
if any(isinstance(err, exc) for exc in handled_exceptions):
831+
handle_retry.retry(func, args, kwargs, err, attempt)
832+
continue
833+
break
834+
logger.info("failed after %s attempts", attempt)
835+
if error is not None:
836+
raise error
837+
return result
838+
839+
return decorated
840+
841+
return wrapper

trino/exceptions.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@
1616
"""
1717

1818

19-
import functools
20-
import random
21-
import time
22-
2319
import trino.logging
2420

2521
logger = trino.logging.get_logger(__name__)
@@ -120,63 +116,6 @@ class TrinoUserError(TrinoQueryError):
120116
pass
121117

122118

123-
def retry_with(handle_retry, exceptions, conditions, max_attempts):
124-
def wrapper(func):
125-
@functools.wraps(func)
126-
def decorated(*args, **kwargs):
127-
error = None
128-
result = None
129-
for attempt in range(1, max_attempts + 1):
130-
try:
131-
result = func(*args, **kwargs)
132-
if any(guard(result) for guard in conditions):
133-
handle_retry.retry(func, args, kwargs, None, attempt)
134-
continue
135-
return result
136-
except Exception as err:
137-
error = err
138-
if any(isinstance(err, exc) for exc in exceptions):
139-
handle_retry.retry(func, args, kwargs, err, attempt)
140-
continue
141-
break
142-
logger.info("failed after %s attempts", attempt)
143-
if error is not None:
144-
raise error
145-
return result
146-
147-
return decorated
148-
149-
return wrapper
150-
151-
152-
class DelayExponential(object):
153-
def __init__(
154-
self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours
155-
):
156-
self._base = base
157-
self._exponent = exponent
158-
self._jitter = jitter
159-
self._max_delay = max_delay
160-
161-
def __call__(self, attempt):
162-
delay = float(self._base) * (self._exponent ** attempt)
163-
if self._jitter:
164-
delay *= random.random()
165-
delay = min(float(self._max_delay), delay)
166-
return delay
167-
168-
169-
class RetryWithExponentialBackoff(object):
170-
def __init__(
171-
self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours
172-
):
173-
self._get_delay = DelayExponential(base, exponent, jitter, max_delay)
174-
175-
def retry(self, func, args, kwargs, err, attempt):
176-
delay = self._get_delay(attempt)
177-
time.sleep(delay)
178-
179-
180119
# PEP 249
181120
class Error(Exception):
182121
pass

0 commit comments

Comments
 (0)