1
1
import asyncio
2
2
import enum
3
3
import functools
4
+ import ssl
4
5
import os
5
6
from typing import Optional , Union
6
7
7
8
from .api import Api
9
+ from .const import Transport
8
10
from .exceptions import TarantoolDatabaseError , \
9
- ErrorCode , TarantoolError
11
+ ErrorCode , TarantoolError , SSLError
10
12
from .iproto import protocol
11
13
from .log import logger
12
14
from .stream import Stream
@@ -27,11 +29,13 @@ class ConnectionState(enum.IntEnum):
27
29
28
30
class Connection (Api ):
29
31
__slots__ = (
30
- '_host' , '_port' , '_username' , '_password' ,
31
- '_fetch_schema' , '_auto_refetch_schema' , '_initial_read_buffer_size' ,
32
- '_encoding' , '_connect_timeout' , '_reconnect_timeout' ,
33
- '_request_timeout' , '_ping_timeout' , '_loop' , '_state' , '_state_prev' ,
34
- '_transport' , '_protocol' ,
32
+ '_host' , '_port' , '_parameter_transport' , '_ssl_key_file' ,
33
+ '_ssl_cert_file' , '_ssl_ca_file' , '_ssl_ciphers' ,
34
+ '_username' , '_password' , '_fetch_schema' ,
35
+ '_auto_refetch_schema' , '_initial_read_buffer_size' ,
36
+ '_encoding' , '_connect_timeout' , '_ssl_handshake_timeout' ,
37
+ '_reconnect_timeout' , '_request_timeout' , '_ping_timeout' ,
38
+ '_loop' , '_state' , '_state_prev' , '_transport' , '_protocol' ,
35
39
'_disconnect_waiter' , '_reconnect_task' ,
36
40
'_connect_lock' , '_disconnect_lock' ,
37
41
'_ping_task' , '__create_task'
@@ -40,11 +44,17 @@ class Connection(Api):
40
44
def __init__ (self , * ,
41
45
host : str = '127.0.0.1' ,
42
46
port : Union [int , str ] = 3301 ,
47
+ transport : Optional [Transport ] = Transport .DEFAULT ,
48
+ ssl_key_file : Optional [str ] = None ,
49
+ ssl_cert_file : Optional [str ] = None ,
50
+ ssl_ca_file : Optional [str ] = None ,
51
+ ssl_ciphers : Optional [str ] = None ,
43
52
username : Optional [str ] = None ,
44
53
password : Optional [str ] = None ,
45
54
fetch_schema : bool = True ,
46
55
auto_refetch_schema : bool = True ,
47
56
connect_timeout : float = 3. ,
57
+ ssl_handshake_timeout : float = 3. ,
48
58
request_timeout : float = - 1. ,
49
59
reconnect_timeout : float = 1. / 3. ,
50
60
ping_timeout : float = 5. ,
@@ -78,6 +88,22 @@ def __init__(self, *,
78
88
:param port:
79
89
Tarantool port
80
90
(pass ``/path/to/sockfile`` to connect ot unix socket)
91
+ :param transport:
92
+ This parameter can be used to configure traffic encryption.
93
+ Pass ``asynctnt.Transport.SSL`` value to enable SSL
94
+ encryption (by default there is no encryption)
95
+ :param ssl_key_file:
96
+ A path to a private SSL key file.
97
+ Optional, mandatory if server uses CA file
98
+ :param ssl_cert_file:
99
+ A path to an SSL certificate file.
100
+ Optional, mandatory if server uses CA file
101
+ :param ssl_ca_file:
102
+ A path to a trusted certificate authorities (CA) file.
103
+ Optional
104
+ :param ssl_ciphers:
105
+ A colon-separated (:) list of SSL cipher suites
106
+ the connection can use. Optional
81
107
:param username:
82
108
Username to use for auth
83
109
(if ``None`` you are connected as a guest)
@@ -93,6 +119,10 @@ def __init__(self, *,
93
119
be checked by Tarantool, so no errors will occur
94
120
:param connect_timeout:
95
121
Time in seconds how long to wait for connecting to socket
122
+ :param ssl_handshake_timeout:
123
+ Time in seconds to wait for the TLS handshake to complete
124
+ before aborting the connection (used only for a TLS
125
+ connection)
96
126
:param request_timeout:
97
127
Request timeout (in seconds) for all requests
98
128
(by default there is no timeout)
@@ -116,6 +146,13 @@ def __init__(self, *,
116
146
super ().__init__ ()
117
147
self ._host = host
118
148
self ._port = port
149
+
150
+ self ._parameter_transport = transport
151
+ self ._ssl_key_file = ssl_key_file
152
+ self ._ssl_cert_file = ssl_cert_file
153
+ self ._ssl_ca_file = ssl_ca_file
154
+ self ._ssl_ciphers = ssl_ciphers
155
+
119
156
self ._username = username
120
157
self ._password = password
121
158
self ._fetch_schema = False if fetch_schema is None else fetch_schema
@@ -131,6 +168,7 @@ def __init__(self, *,
131
168
self ._encoding = encoding or 'utf-8'
132
169
133
170
self ._connect_timeout = connect_timeout
171
+ self ._ssl_handshake_timeout = ssl_handshake_timeout
134
172
self ._reconnect_timeout = reconnect_timeout or 0
135
173
self ._request_timeout = request_timeout
136
174
self ._ping_timeout = ping_timeout or 0
@@ -220,6 +258,54 @@ def protocol_factory(self,
220
258
on_connection_lost = self .connection_lost ,
221
259
loop = self ._loop )
222
260
261
+ def _create_ssl_context (self ):
262
+ try :
263
+ if hasattr (ssl , 'TLSVersion' ):
264
+ # Since python 3.7
265
+ context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
266
+ # Reset to default OpenSSL values.
267
+ context .check_hostname = False
268
+ context .verify_mode = ssl .CERT_NONE
269
+ # Require TLSv1.2, because other protocol versions don't seem
270
+ # to support the GOST cipher.
271
+ context .minimum_version = ssl .TLSVersion .TLSv1_2
272
+ context .maximum_version = ssl .TLSVersion .TLSv1_2
273
+ else :
274
+ # Deprecated, but it works for python < 3.7
275
+ context = ssl .SSLContext (ssl .PROTOCOL_TLSv1_2 )
276
+
277
+ if self ._ssl_cert_file :
278
+ # If the password argument is not specified and a password is
279
+ # required, OpenSSL’s built-in password prompting mechanism
280
+ # will be used to interactively prompt the user for a password.
281
+ #
282
+ # We should disable this behaviour, because a python
283
+ # application that uses the connector unlikely assumes
284
+ # interaction with a human + a Tarantool implementation does
285
+ # not support this at least for now.
286
+ def password_raise_error ():
287
+ raise SSLError ("a password for decrypting the private " +
288
+ "key is unsupported" )
289
+ context .load_cert_chain (certfile = self ._ssl_cert_file ,
290
+ keyfile = self ._ssl_key_file ,
291
+ password = password_raise_error )
292
+
293
+ if self ._ssl_ca_file :
294
+ context .load_verify_locations (cafile = self ._ssl_ca_file )
295
+ context .verify_mode = ssl .CERT_REQUIRED
296
+ # A Tarantool implementation does not check hostname. We don't
297
+ # do that too. As a result we don't set here:
298
+ # context.check_hostname = True
299
+
300
+ if self ._ssl_ciphers :
301
+ context .set_ciphers (self ._ssl_ciphers )
302
+
303
+ return context
304
+ except SSLError as e :
305
+ raise
306
+ except Exception as e :
307
+ raise SSLError (e )
308
+
223
309
async def _connect (self , return_exceptions : bool = True ):
224
310
if self ._loop is None :
225
311
self ._loop = get_running_loop ()
@@ -246,6 +332,12 @@ async def full_connect():
246
332
while True :
247
333
connected_fut = _create_future (self ._loop )
248
334
335
+ ssl_context = None
336
+ ssl_handshake_timeout = None
337
+ if self ._parameter_transport == Transport .SSL :
338
+ ssl_context = self ._create_ssl_context ()
339
+ ssl_handshake_timeout = self ._ssl_handshake_timeout
340
+
249
341
if self ._host .startswith ('unix/' ):
250
342
unix_path = self ._port
251
343
assert isinstance (unix_path , str ), \
@@ -260,13 +352,16 @@ async def full_connect():
260
352
conn = self ._loop .create_unix_connection (
261
353
functools .partial (self .protocol_factory ,
262
354
connected_fut ),
263
- unix_path
264
- )
355
+ unix_path ,
356
+ ssl = ssl_context ,
357
+ ssl_handshake_timeout = ssl_handshake_timeout )
265
358
else :
266
359
conn = self ._loop .create_connection (
267
360
functools .partial (self .protocol_factory ,
268
361
connected_fut ),
269
- self ._host , self ._port )
362
+ self ._host , self ._port ,
363
+ ssl = ssl_context ,
364
+ ssl_handshake_timeout = ssl_handshake_timeout )
270
365
271
366
tr , pr = await conn
272
367
@@ -337,6 +432,8 @@ async def full_connect():
337
432
338
433
if return_exceptions :
339
434
self ._reconnect_task = None
435
+ if isinstance (e , ssl .SSLError ):
436
+ e = SSLError (e )
340
437
raise e
341
438
342
439
logger .exception (e )
0 commit comments