Skip to content

Commit 2b2c72f

Browse files
committed
Use callbacks for sasl handshake request / response
1 parent 6b801a8 commit 2b2c72f

File tree

1 file changed

+62
-59
lines changed

1 file changed

+62
-59
lines changed

kafka/conn.py

Lines changed: 62 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,11 @@ class BrokerConnection(object):
7474
'ssl_password': None,
7575
'api_version': (0, 8, 2), # default to most restrictive
7676
'state_change_callback': lambda conn: True,
77-
'sasl_mechanism': None,
77+
'sasl_mechanism': 'PLAIN',
7878
'sasl_plain_username': None,
7979
'sasl_plain_password': None
8080
}
81+
SASL_MECHANISMS = ('PLAIN',)
8182

8283
def __init__(self, host, port, afi, **configs):
8384
self.host = host
@@ -100,11 +101,19 @@ def __init__(self, host, port, afi, **configs):
100101
(socket.SOL_SOCKET, socket.SO_SNDBUF,
101102
self.config['send_buffer_bytes']))
102103

104+
if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
105+
assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, (
106+
'sasl_mechanism must be in ' + self.SASL_MECHANISMS)
107+
if self.config['sasl_mechanism'] == 'PLAIN':
108+
assert self.config['sasl_plain_username'] is not None, 'sasl_plain_username required for PLAIN sasl'
109+
assert self.config['sasl_plain_password'] is not None, 'sasl_plain_password required for PLAIN sasl'
110+
103111
self.state = ConnectionStates.DISCONNECTED
104112
self._sock = None
105113
self._ssl_context = None
106114
if self.config['ssl_context'] is not None:
107115
self._ssl_context = self.config['ssl_context']
116+
self._sasl_auth_future = None
108117
self._rbuffer = io.BytesIO()
109118
self._receiving = False
110119
self._next_payload_bytes = 0
@@ -224,8 +233,9 @@ def connect(self):
224233
self.config['state_change_callback'](self)
225234

226235
if self.state is ConnectionStates.AUTHENTICATING:
236+
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
227237
if self._try_authenticate():
228-
log.debug('%s: Authenticated as %s', str(self), self.config['sasl_plain_username'])
238+
log.info('%s: Authenticated as %s', str(self), self.config['sasl_plain_username'])
229239
self.state = ConnectionStates.CONNECTED
230240
self.config['state_change_callback'](self)
231241

@@ -289,58 +299,44 @@ def _try_handshake(self):
289299
return False
290300

291301
def _try_authenticate(self):
292-
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
293-
294-
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
295-
log.warning('%s: Sending username and password in the clear', str(self))
296-
297-
# Build a SaslHandShakeRequest message
298-
correlation_id = self._next_correlation_id()
299-
request = SaslHandShakeRequest[0](self.config['sasl_mechanism'])
300-
header = RequestHeader(request,
301-
correlation_id=correlation_id,
302-
client_id=self.config['client_id'])
303-
304-
message = b''.join([header.encode(), request.encode()])
305-
size = Int32.encode(len(message))
306-
307-
# Attempt to send it over our socket
308-
try:
309-
self._sock.setblocking(True)
310-
self._sock.sendall(size + message)
311-
self._sock.setblocking(False)
312-
except (AssertionError, ConnectionError) as e:
313-
log.exception("Error sending %s to %s", request, self)
314-
error = Errors.ConnectionError("%s: %s" % (str(self), e))
302+
assert self.config['api_version'] >= (0, 10) or self.config['api_version'] is None
303+
304+
if self._sasl_auth_future is None:
305+
# Build a SaslHandShakeRequest message
306+
request = SaslHandShakeRequest[0](self.config['sasl_mechanism'])
307+
future = Future()
308+
sasl_response = self._send(request)
309+
sasl_response.add_callback(self._handle_sasl_handshake_response, future)
310+
sasl_response.add_errback(lambda f, e: f.failure(e), future)
311+
self._sasl_auth_future = future
312+
self._recv()
313+
if self._sasl_auth_future.failed():
314+
raise self._sasl_auth_future.exception
315+
return self._sasl_auth_future.succeeded()
316+
317+
def _handle_sasl_handshake_response(self, future, response):
318+
error_type = Errors.for_code(response.error_code)
319+
if error_type is not Errors.NoError:
320+
error = error_type(self)
315321
self.close(error=error)
316-
return False
317-
318-
future = Future()
319-
ifr = InFlightRequest(request=request,
320-
correlation_id=correlation_id,
321-
response_type=request.RESPONSE_TYPE,
322-
future=future,
323-
timestamp=time.time())
324-
self.in_flight_requests.append(ifr)
325-
326-
# Listen for a reply and check that the server supports the PLAIN mechanism
327-
response = None
328-
while not response:
329-
response = self.recv()
330-
331-
if not response.error_code is 0:
332-
raise Errors.for_code(response.error_code)
322+
return future.failure(error_type(self))
333323

334-
if not self.config['sasl_mechanism'] in response.enabled_mechanisms:
335-
raise Errors.AuthenticationMethodNotSupported(self.config['sasl_mechanism'] + " is not supported by broker")
324+
if self.config['sasl_mechanism'] == 'PLAIN':
325+
return self._try_authenticate_plain(future)
326+
else:
327+
return future.failure(
328+
Errors.UnsupportedSaslMechanismError(
329+
'kafka-python does not support SASL mechanism %s' %
330+
self.config['sasl_mechanism']))
336331

337-
return self._try_authenticate_plain()
332+
def _try_authenticate_plain(self, future):
333+
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
334+
log.warning('%s: Sending username and password in the clear', str(self))
338335

339-
def _try_authenticate_plain(self):
340336
data = b''
341337
try:
342338
self._sock.setblocking(True)
343-
# Send our credentials
339+
# Send PLAIN credentials per RFC-4616
344340
msg = bytes('\0'.join([self.config['sasl_plain_username'],
345341
self.config['sasl_plain_username'],
346342
self.config['sasl_plain_password']]).encode('utf-8'))
@@ -351,26 +347,26 @@ def _try_authenticate_plain(self):
351347
# The connection is closed on failure
352348
received_bytes = 0
353349
while received_bytes < 4:
354-
data = data + self._sock.recv(4 - received_bytes)
355-
received_bytes = received_bytes + len(data)
350+
data += self._sock.recv(4 - received_bytes)
351+
received_bytes += len(data)
356352
if not data:
357353
log.error('%s: Authentication failed for user %s', self, self.config['sasl_plain_username'])
358-
self.close(error=Errors.ConnectionError('Authentication failed'))
359-
raise Errors.AuthenticationFailedError('Authentication failed for user {}'.format(self.config['sasl_plain_username']))
354+
error = Errors.AuthenticationFailedError(
355+
'Authentication failed for user {0}'.format(
356+
self.config['sasl_plain_username']))
357+
future.failure(error)
358+
raise error
360359
self._sock.setblocking(False)
361360
except (AssertionError, ConnectionError) as e:
362361
log.exception("%s: Error receiving reply from server", self)
363362
error = Errors.ConnectionError("%s: %s" % (str(self), e))
363+
future.failure(error)
364364
self.close(error=error)
365-
return False
366365

367-
with io.BytesIO() as buffer:
368-
buffer.write(data)
369-
buffer.seek(0)
370-
if not Int32.decode(buffer) == 0:
371-
raise Errors.KafkaError('Expected a zero sized reply after sending credentials')
366+
if data != '\x00\x00\x00\x00':
367+
return future.failure(Errors.AuthenticationFailedError())
372368

373-
return True
369+
return future.success(True)
374370

375371
def blacked_out(self):
376372
"""
@@ -437,6 +433,10 @@ def send(self, request, expect_response=True):
437433
return future.failure(Errors.ConnectionError(str(self)))
438434
elif not self.can_send_more():
439435
return future.failure(Errors.TooManyInFlightRequests(str(self)))
436+
return self._send(request, expect_response=expect_response)
437+
438+
def _send(self, request, expect_response=True):
439+
future = Future()
440440
correlation_id = self._next_correlation_id()
441441
header = RequestHeader(request,
442442
correlation_id=correlation_id,
@@ -505,6 +505,9 @@ def recv(self):
505505
self.config['request_timeout_ms']))
506506
return None
507507

508+
return self._recv()
509+
510+
def _recv(self):
508511
# Not receiving is the state of reading the payload header
509512
if not self._receiving:
510513
try:
@@ -552,7 +555,7 @@ def recv(self):
552555
# enough data to read the full bytes_to_read
553556
# but if the socket is disconnected, we will get empty data
554557
# without an exception raised
555-
if not data:
558+
if bytes_to_read and not data:
556559
log.error('%s: socket disconnected', self)
557560
self.close(error=Errors.ConnectionError('socket disconnected'))
558561
return None

0 commit comments

Comments
 (0)