15
15
import kafka .errors as Errors
16
16
from kafka .future import Future
17
17
from kafka .protocol .api import RequestHeader
18
+ from kafka .protocol .admin import SaslHandShakeRequest , SaslHandShakeResponse
18
19
from kafka .protocol .commit import GroupCoordinatorResponse
19
20
from kafka .protocol .types import Int32
20
21
from kafka .version import __version__
@@ -48,7 +49,7 @@ class ConnectionStates(object):
48
49
CONNECTING = '<connecting>'
49
50
HANDSHAKE = '<handshake>'
50
51
CONNECTED = '<connected>'
51
-
52
+ AUTHENTICATING = '<authenticating>'
52
53
53
54
InFlightRequest = collections .namedtuple ('InFlightRequest' ,
54
55
['request' , 'response_type' , 'correlation_id' , 'future' , 'timestamp' ])
@@ -73,7 +74,11 @@ class BrokerConnection(object):
73
74
'ssl_password' : None ,
74
75
'api_version' : (0 , 8 , 2 ), # default to most restrictive
75
76
'state_change_callback' : lambda conn : True ,
77
+ 'sasl_mechanism' : 'PLAIN' ,
78
+ 'sasl_plain_username' : None ,
79
+ 'sasl_plain_password' : None
76
80
}
81
+ SASL_MECHANISMS = ('PLAIN' ,)
77
82
78
83
def __init__ (self , host , port , afi , ** configs ):
79
84
self .host = host
@@ -96,11 +101,19 @@ def __init__(self, host, port, afi, **configs):
96
101
(socket .SOL_SOCKET , socket .SO_SNDBUF ,
97
102
self .config ['send_buffer_bytes' ]))
98
103
104
+ if self .config ['security_protocol' ] in ('SASL_PLAINTEXT' , 'SASL_SSL' ):
105
+ assert self .config ['sasl_mechanism' ] in self .SASL_MECHANISMS , (
106
+ 'sasl_mechanism must be in ' + self .SASL_MECHANISMS )
107
+ if self .config ['sasl_mechanism' ] == 'PLAIN' :
108
+ assert self .config ['sasl_plain_username' ] is not None , 'sasl_plain_username required for PLAIN sasl'
109
+ assert self .config ['sasl_plain_password' ] is not None , 'sasl_plain_password required for PLAIN sasl'
110
+
99
111
self .state = ConnectionStates .DISCONNECTED
100
112
self ._sock = None
101
113
self ._ssl_context = None
102
114
if self .config ['ssl_context' ] is not None :
103
115
self ._ssl_context = self .config ['ssl_context' ]
116
+ self ._sasl_auth_future = None
104
117
self ._rbuffer = io .BytesIO ()
105
118
self ._receiving = False
106
119
self ._next_payload_bytes = 0
@@ -188,6 +201,8 @@ def connect(self):
188
201
if self .config ['security_protocol' ] in ('SSL' , 'SASL_SSL' ):
189
202
log .debug ('%s: initiating SSL handshake' , str (self ))
190
203
self .state = ConnectionStates .HANDSHAKE
204
+ elif self .config ['security_protocol' ] == 'SASL_PLAINTEXT' :
205
+ self .state = ConnectionStates .AUTHENTICATING
191
206
else :
192
207
self .state = ConnectionStates .CONNECTED
193
208
self .config ['state_change_callback' ](self )
@@ -211,6 +226,16 @@ def connect(self):
211
226
if self .state is ConnectionStates .HANDSHAKE :
212
227
if self ._try_handshake ():
213
228
log .debug ('%s: completed SSL handshake.' , str (self ))
229
+ if self .config ['security_protocol' ] == 'SASL_SSL' :
230
+ self .state = ConnectionStates .AUTHENTICATING
231
+ else :
232
+ self .state = ConnectionStates .CONNECTED
233
+ self .config ['state_change_callback' ](self )
234
+
235
+ if self .state is ConnectionStates .AUTHENTICATING :
236
+ assert self .config ['security_protocol' ] in ('SASL_PLAINTEXT' , 'SASL_SSL' )
237
+ if self ._try_authenticate ():
238
+ log .info ('%s: Authenticated as %s' , str (self ), self .config ['sasl_plain_username' ])
214
239
self .state = ConnectionStates .CONNECTED
215
240
self .config ['state_change_callback' ](self )
216
241
@@ -273,6 +298,75 @@ def _try_handshake(self):
273
298
274
299
return False
275
300
301
+ def _try_authenticate (self ):
302
+ assert self .config ['api_version' ] is None or self .config ['api_version' ] >= (0 , 10 )
303
+
304
+ if self ._sasl_auth_future is None :
305
+ # Build a SaslHandShakeRequest message
306
+ request = SaslHandShakeRequest [0 ](self .config ['sasl_mechanism' ])
307
+ future = Future ()
308
+ sasl_response = self ._send (request )
309
+ sasl_response .add_callback (self ._handle_sasl_handshake_response , future )
310
+ sasl_response .add_errback (lambda f , e : f .failure (e ), future )
311
+ self ._sasl_auth_future = future
312
+ self ._recv ()
313
+ if self ._sasl_auth_future .failed ():
314
+ raise self ._sasl_auth_future .exception # pylint: disable-msg=raising-bad-type
315
+ return self ._sasl_auth_future .succeeded ()
316
+
317
+ def _handle_sasl_handshake_response (self , future , response ):
318
+ error_type = Errors .for_code (response .error_code )
319
+ if error_type is not Errors .NoError :
320
+ error = error_type (self )
321
+ self .close (error = error )
322
+ return future .failure (error_type (self ))
323
+
324
+ if self .config ['sasl_mechanism' ] == 'PLAIN' :
325
+ return self ._try_authenticate_plain (future )
326
+ else :
327
+ return future .failure (
328
+ Errors .UnsupportedSaslMechanismError (
329
+ 'kafka-python does not support SASL mechanism %s' %
330
+ self .config ['sasl_mechanism' ]))
331
+
332
+ def _try_authenticate_plain (self , future ):
333
+ if self .config ['security_protocol' ] == 'SASL_PLAINTEXT' :
334
+ log .warning ('%s: Sending username and password in the clear' , str (self ))
335
+
336
+ data = b''
337
+ try :
338
+ self ._sock .setblocking (True )
339
+ # Send PLAIN credentials per RFC-4616
340
+ msg = bytes ('\0 ' .join ([self .config ['sasl_plain_username' ],
341
+ self .config ['sasl_plain_username' ],
342
+ self .config ['sasl_plain_password' ]]).encode ('utf-8' ))
343
+ size = Int32 .encode (len (msg ))
344
+ self ._sock .sendall (size + msg )
345
+
346
+ # The server will send a zero sized message (that is Int32(0)) on success.
347
+ # The connection is closed on failure
348
+ while len (data ) < 4 :
349
+ fragment = self ._sock .recv (4 - len (data ))
350
+ if not fragment :
351
+ log .error ('%s: Authentication failed for user %s' , self , self .config ['sasl_plain_username' ])
352
+ error = Errors .AuthenticationFailedError (
353
+ 'Authentication failed for user {0}' .format (
354
+ self .config ['sasl_plain_username' ]))
355
+ future .failure (error )
356
+ raise error
357
+ data += fragment
358
+ self ._sock .setblocking (False )
359
+ except (AssertionError , ConnectionError ) as e :
360
+ log .exception ("%s: Error receiving reply from server" , self )
361
+ error = Errors .ConnectionError ("%s: %s" % (str (self ), e ))
362
+ future .failure (error )
363
+ self .close (error = error )
364
+
365
+ if data != b'\x00 \x00 \x00 \x00 ' :
366
+ return future .failure (Errors .AuthenticationFailedError ())
367
+
368
+ return future .success (True )
369
+
276
370
def blacked_out (self ):
277
371
"""
278
372
Return true if we are disconnected from the given node and can't
@@ -292,7 +386,8 @@ def connecting(self):
292
386
"""Returns True if still connecting (this may encompass several
293
387
different states, such as SSL handshake, authorization, etc)."""
294
388
return self .state in (ConnectionStates .CONNECTING ,
295
- ConnectionStates .HANDSHAKE )
389
+ ConnectionStates .HANDSHAKE ,
390
+ ConnectionStates .AUTHENTICATING )
296
391
297
392
def disconnected (self ):
298
393
"""Return True iff socket is closed"""
@@ -337,6 +432,10 @@ def send(self, request, expect_response=True):
337
432
return future .failure (Errors .ConnectionError (str (self )))
338
433
elif not self .can_send_more ():
339
434
return future .failure (Errors .TooManyInFlightRequests (str (self )))
435
+ return self ._send (request , expect_response = expect_response )
436
+
437
+ def _send (self , request , expect_response = True ):
438
+ future = Future ()
340
439
correlation_id = self ._next_correlation_id ()
341
440
header = RequestHeader (request ,
342
441
correlation_id = correlation_id ,
@@ -385,7 +484,7 @@ def recv(self):
385
484
Return response if available
386
485
"""
387
486
assert not self ._processing , 'Recursion not supported'
388
- if not self .connected ():
487
+ if not self .connected () and not self . state is ConnectionStates . AUTHENTICATING :
389
488
log .warning ('%s cannot recv: socket not connected' , self )
390
489
# If requests are pending, we should close the socket and
391
490
# fail all the pending request futures
@@ -405,6 +504,9 @@ def recv(self):
405
504
self .config ['request_timeout_ms' ]))
406
505
return None
407
506
507
+ return self ._recv ()
508
+
509
+ def _recv (self ):
408
510
# Not receiving is the state of reading the payload header
409
511
if not self ._receiving :
410
512
try :
@@ -452,7 +554,7 @@ def recv(self):
452
554
# enough data to read the full bytes_to_read
453
555
# but if the socket is disconnected, we will get empty data
454
556
# without an exception raised
455
- if not data :
557
+ if bytes_to_read and not data :
456
558
log .error ('%s: socket disconnected' , self )
457
559
self .close (error = Errors .ConnectionError ('socket disconnected' ))
458
560
return None
0 commit comments