@@ -74,10 +74,11 @@ class BrokerConnection(object):
74
74
'ssl_password' : None ,
75
75
'api_version' : (0 , 8 , 2 ), # default to most restrictive
76
76
'state_change_callback' : lambda conn : True ,
77
- 'sasl_mechanism' : None ,
77
+ 'sasl_mechanism' : 'PLAIN' ,
78
78
'sasl_plain_username' : None ,
79
79
'sasl_plain_password' : None
80
80
}
81
+ SASL_MECHANISMS = ('PLAIN' ,)
81
82
82
83
def __init__ (self , host , port , afi , ** configs ):
83
84
self .host = host
@@ -100,11 +101,19 @@ def __init__(self, host, port, afi, **configs):
100
101
(socket .SOL_SOCKET , socket .SO_SNDBUF ,
101
102
self .config ['send_buffer_bytes' ]))
102
103
104
+ if self .config ['security_protocol' ] in ('SASL_PLAINTEXT' , 'SASL_SSL' ):
105
+ assert self .config ['sasl_mechanism' ] in self .SASL_MECHANISMS , (
106
+ 'sasl_mechanism must be in ' + self .SASL_MECHANISMS )
107
+ if self .config ['sasl_mechanism' ] == 'PLAIN' :
108
+ assert self .config ['sasl_plain_username' ] is not None , 'sasl_plain_username required for PLAIN sasl'
109
+ assert self .config ['sasl_plain_password' ] is not None , 'sasl_plain_password required for PLAIN sasl'
110
+
103
111
self .state = ConnectionStates .DISCONNECTED
104
112
self ._sock = None
105
113
self ._ssl_context = None
106
114
if self .config ['ssl_context' ] is not None :
107
115
self ._ssl_context = self .config ['ssl_context' ]
116
+ self ._sasl_auth_future = None
108
117
self ._rbuffer = io .BytesIO ()
109
118
self ._receiving = False
110
119
self ._next_payload_bytes = 0
@@ -224,8 +233,9 @@ def connect(self):
224
233
self .config ['state_change_callback' ](self )
225
234
226
235
if self .state is ConnectionStates .AUTHENTICATING :
236
+ assert self .config ['security_protocol' ] in ('SASL_PLAINTEXT' , 'SASL_SSL' )
227
237
if self ._try_authenticate ():
228
- log .debug ('%s: Authenticated as %s' , str (self ), self .config ['sasl_plain_username' ])
238
+ log .info ('%s: Authenticated as %s' , str (self ), self .config ['sasl_plain_username' ])
229
239
self .state = ConnectionStates .CONNECTED
230
240
self .config ['state_change_callback' ](self )
231
241
@@ -289,58 +299,44 @@ def _try_handshake(self):
289
299
return False
290
300
291
301
def _try_authenticate (self ):
292
- assert self .config ['security_protocol' ] in ('SASL_PLAINTEXT' , 'SASL_SSL' )
293
-
294
- if self .config ['security_protocol' ] == 'SASL_PLAINTEXT' :
295
- log .warning ('%s: Sending username and password in the clear' , str (self ))
296
-
297
- # Build a SaslHandShakeRequest message
298
- correlation_id = self ._next_correlation_id ()
299
- request = SaslHandShakeRequest [0 ](self .config ['sasl_mechanism' ])
300
- header = RequestHeader (request ,
301
- correlation_id = correlation_id ,
302
- client_id = self .config ['client_id' ])
303
-
304
- message = b'' .join ([header .encode (), request .encode ()])
305
- size = Int32 .encode (len (message ))
306
-
307
- # Attempt to send it over our socket
308
- try :
309
- self ._sock .setblocking (True )
310
- self ._sock .sendall (size + message )
311
- self ._sock .setblocking (False )
312
- except (AssertionError , ConnectionError ) as e :
313
- log .exception ("Error sending %s to %s" , request , self )
314
- error = Errors .ConnectionError ("%s: %s" % (str (self ), e ))
302
+ assert self .config ['api_version' ] >= (0 , 10 ) or self .config ['api_version' ] is None
303
+
304
+ if self ._sasl_auth_future is None :
305
+ # Build a SaslHandShakeRequest message
306
+ request = SaslHandShakeRequest [0 ](self .config ['sasl_mechanism' ])
307
+ future = Future ()
308
+ sasl_response = self ._send (request )
309
+ sasl_response .add_callback (self ._handle_sasl_handshake_response , future )
310
+ sasl_response .add_errback (lambda f , e : f .failure (e ), future )
311
+ self ._sasl_auth_future = future
312
+ self ._recv ()
313
+ if self ._sasl_auth_future .failed ():
314
+ raise self ._sasl_auth_future .exception
315
+ return self ._sasl_auth_future .succeeded ()
316
+
317
+ def _handle_sasl_handshake_response (self , future , response ):
318
+ error_type = Errors .for_code (response .error_code )
319
+ if error_type is not Errors .NoError :
320
+ error = error_type (self )
315
321
self .close (error = error )
316
- return False
317
-
318
- future = Future ()
319
- ifr = InFlightRequest (request = request ,
320
- correlation_id = correlation_id ,
321
- response_type = request .RESPONSE_TYPE ,
322
- future = future ,
323
- timestamp = time .time ())
324
- self .in_flight_requests .append (ifr )
325
-
326
- # Listen for a reply and check that the server supports the PLAIN mechanism
327
- response = None
328
- while not response :
329
- response = self .recv ()
330
-
331
- if not response .error_code is 0 :
332
- raise Errors .for_code (response .error_code )
322
+ return future .failure (error_type (self ))
333
323
334
- if not self .config ['sasl_mechanism' ] in response .enabled_mechanisms :
335
- raise Errors .AuthenticationMethodNotSupported (self .config ['sasl_mechanism' ] + " is not supported by broker" )
324
+ if self .config ['sasl_mechanism' ] == 'PLAIN' :
325
+ return self ._try_authenticate_plain (future )
326
+ else :
327
+ return future .failure (
328
+ Errors .UnsupportedSaslMechanismError (
329
+ 'kafka-python does not support SASL mechanism %s' %
330
+ self .config ['sasl_mechanism' ]))
336
331
337
- return self ._try_authenticate_plain ()
332
+ def _try_authenticate_plain (self , future ):
333
+ if self .config ['security_protocol' ] == 'SASL_PLAINTEXT' :
334
+ log .warning ('%s: Sending username and password in the clear' , str (self ))
338
335
339
- def _try_authenticate_plain (self ):
340
336
data = b''
341
337
try :
342
338
self ._sock .setblocking (True )
343
- # Send our credentials
339
+ # Send PLAIN credentials per RFC-4616
344
340
msg = bytes ('\0 ' .join ([self .config ['sasl_plain_username' ],
345
341
self .config ['sasl_plain_username' ],
346
342
self .config ['sasl_plain_password' ]]).encode ('utf-8' ))
@@ -351,26 +347,26 @@ def _try_authenticate_plain(self):
351
347
# The connection is closed on failure
352
348
received_bytes = 0
353
349
while received_bytes < 4 :
354
- data = data + self ._sock .recv (4 - received_bytes )
355
- received_bytes = received_bytes + len (data )
350
+ data += self ._sock .recv (4 - received_bytes )
351
+ received_bytes += len (data )
356
352
if not data :
357
353
log .error ('%s: Authentication failed for user %s' , self , self .config ['sasl_plain_username' ])
358
- self .close (error = Errors .ConnectionError ('Authentication failed' ))
359
- raise Errors .AuthenticationFailedError ('Authentication failed for user {}' .format (self .config ['sasl_plain_username' ]))
354
+ error = Errors .AuthenticationFailedError (
355
+ 'Authentication failed for user {0}' .format (
356
+ self .config ['sasl_plain_username' ]))
357
+ future .failure (error )
358
+ raise error
360
359
self ._sock .setblocking (False )
361
360
except (AssertionError , ConnectionError ) as e :
362
361
log .exception ("%s: Error receiving reply from server" , self )
363
362
error = Errors .ConnectionError ("%s: %s" % (str (self ), e ))
363
+ future .failure (error )
364
364
self .close (error = error )
365
- return False
366
365
367
- with io .BytesIO () as buffer :
368
- buffer .write (data )
369
- buffer .seek (0 )
370
- if not Int32 .decode (buffer ) == 0 :
371
- raise Errors .KafkaError ('Expected a zero sized reply after sending credentials' )
366
+ if data != '\x00 \x00 \x00 \x00 ' :
367
+ return future .failure (Errors .AuthenticationFailedError ())
372
368
373
- return True
369
+ return future . success ( True )
374
370
375
371
def blacked_out (self ):
376
372
"""
@@ -437,6 +433,10 @@ def send(self, request, expect_response=True):
437
433
return future .failure (Errors .ConnectionError (str (self )))
438
434
elif not self .can_send_more ():
439
435
return future .failure (Errors .TooManyInFlightRequests (str (self )))
436
+ return self ._send (request , expect_response = expect_response )
437
+
438
+ def _send (self , request , expect_response = True ):
439
+ future = Future ()
440
440
correlation_id = self ._next_correlation_id ()
441
441
header = RequestHeader (request ,
442
442
correlation_id = correlation_id ,
@@ -505,6 +505,9 @@ def recv(self):
505
505
self .config ['request_timeout_ms' ]))
506
506
return None
507
507
508
+ return self ._recv ()
509
+
510
+ def _recv (self ):
508
511
# Not receiving is the state of reading the payload header
509
512
if not self ._receiving :
510
513
try :
@@ -552,7 +555,7 @@ def recv(self):
552
555
# enough data to read the full bytes_to_read
553
556
# but if the socket is disconnected, we will get empty data
554
557
# without an exception raised
555
- if not data :
558
+ if bytes_to_read and not data :
556
559
log .error ('%s: socket disconnected' , self )
557
560
self .close (error = Errors .ConnectionError ('socket disconnected' ))
558
561
return None
0 commit comments