Skip to content

Commit d73d169

Browse files
committed
Merge pull request #196 from dpkp/reinit_connection_error
Improve KafkaConnection with more tests
2 parents e289336 + f862774 commit d73d169

File tree

2 files changed

+177
-40
lines changed

2 files changed

+177
-40
lines changed

kafka/conn.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,28 +67,38 @@ def __repr__(self):
6767
###################
6868

6969
def _raise_connection_error(self):
70-
self._dirty = True
70+
# Cleanup socket if we have one
71+
if self._sock:
72+
self.close()
73+
74+
# And then raise
7175
raise ConnectionError("Kafka @ {0}:{1} went away".format(self.host, self.port))
7276

7377
def _read_bytes(self, num_bytes):
7478
bytes_left = num_bytes
7579
responses = []
7680

7781
log.debug("About to read %d bytes from Kafka", num_bytes)
78-
if self._dirty:
82+
83+
# Make sure we have a connection
84+
if not self._sock:
7985
self.reinit()
8086

8187
while bytes_left:
88+
8289
try:
8390
data = self._sock.recv(min(bytes_left, 4096))
91+
92+
# Receiving empty string from recv signals
93+
# that the socket is in error. we will never get
94+
# more data from this socket
95+
if data == '':
96+
raise socket.error('Not enough data to read message -- did server kill socket?')
97+
8498
except socket.error:
8599
log.exception('Unable to receive data from Kafka')
86100
self._raise_connection_error()
87101

88-
if data == '':
89-
log.error("Not enough data to read this response")
90-
self._raise_connection_error()
91-
92102
bytes_left -= len(data)
93103
log.debug("Read %d/%d bytes from Kafka", num_bytes - bytes_left, num_bytes)
94104
responses.append(data)
@@ -102,26 +112,34 @@ def _read_bytes(self, num_bytes):
102112
# TODO multiplex socket communication to allow for multi-threaded clients
103113

104114
def send(self, request_id, payload):
105-
"Send a request to Kafka"
115+
"""
116+
Send a request to Kafka
117+
param: request_id -- can be any int (used only for debug logging...)
118+
param: payload -- an encoded kafka packet (see KafkaProtocol)
119+
"""
120+
106121
log.debug("About to send %d bytes to Kafka, request %d" % (len(payload), request_id))
122+
123+
# Make sure we have a connection
124+
if not self._sock:
125+
self.reinit()
126+
107127
try:
108-
if self._dirty:
109-
self.reinit()
110-
sent = self._sock.sendall(payload)
111-
if sent is not None:
112-
self._raise_connection_error()
128+
self._sock.sendall(payload)
113129
except socket.error:
114130
log.exception('Unable to send payload to Kafka')
115131
self._raise_connection_error()
116132

117133
def recv(self, request_id):
118134
"""
119-
Get a response from Kafka
135+
Get a response packet from Kafka
136+
param: request_id -- can be any int (only used for debug logging...)
137+
returns encoded kafka packet response from server as type str
120138
"""
121139
log.debug("Reading response %d from Kafka" % request_id)
140+
122141
# Read the size off of the header
123142
resp = self._read_bytes(4)
124-
125143
(size,) = struct.unpack('>i', resp)
126144

127145
# Read the remainder of the response
@@ -132,22 +150,46 @@ def copy(self):
132150
"""
133151
Create an inactive copy of the connection object
134152
A reinit() has to be done on the copy before it can be used again
153+
return a new KafkaConnection object
135154
"""
136155
c = copy.deepcopy(self)
137156
c._sock = None
138157
return c
139158

140159
def close(self):
141160
"""
142-
Close this connection
161+
Shutdown and close the connection socket
143162
"""
163+
log.debug("Closing socket connection for %s:%d" % (self.host, self.port))
144164
if self._sock:
165+
# Call shutdown to be a good TCP client
166+
# But expect an error if the socket has already been
167+
# closed by the server
168+
try:
169+
self._sock.shutdown(socket.SHUT_RDWR)
170+
except socket.error:
171+
pass
172+
173+
# Closing the socket should always succeed
145174
self._sock.close()
175+
self._sock = None
176+
else:
177+
log.debug("No socket found to close!")
146178

147179
def reinit(self):
148180
"""
149181
Re-initialize the socket connection
182+
close current socket (if open)
183+
and start a fresh connection
184+
raise ConnectionError on error
150185
"""
151-
self.close()
152-
self._sock = socket.create_connection((self.host, self.port), self.timeout)
153-
self._dirty = False
186+
log.debug("Reinitializing socket connection for %s:%d" % (self.host, self.port))
187+
188+
if self._sock:
189+
self.close()
190+
191+
try:
192+
self._sock = socket.create_connection((self.host, self.port), self.timeout)
193+
except socket.error:
194+
log.exception('Unable to connect to kafka broker at %s:%d' % (self.host, self.port))
195+
self._raise_connection_error()

test/test_conn.py

Lines changed: 117 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,52 @@
1-
import os
2-
import random
1+
import socket
32
import struct
3+
4+
import mock
45
import unittest2
5-
import kafka.conn
6+
7+
from kafka.common import *
8+
from kafka.conn import *
69

710
class ConnTest(unittest2.TestCase):
11+
def setUp(self):
12+
self.config = {
13+
'host': 'localhost',
14+
'port': 9090,
15+
'request_id': 0,
16+
'payload': 'test data',
17+
'payload2': 'another packet'
18+
}
19+
20+
# Mocking socket.create_connection will cause _sock to always be a
21+
# MagicMock()
22+
patcher = mock.patch('socket.create_connection', spec=True)
23+
self.MockCreateConn = patcher.start()
24+
self.addCleanup(patcher.stop)
25+
26+
# Also mock socket.sendall() to appear successful
27+
socket.create_connection().sendall.return_value = None
28+
29+
# And mock socket.recv() to return two payloads, then '', then raise
30+
# Note that this currently ignores the num_bytes parameter to sock.recv()
31+
payload_size = len(self.config['payload'])
32+
payload2_size = len(self.config['payload2'])
33+
socket.create_connection().recv.side_effect = [
34+
struct.pack('>i', payload_size),
35+
struct.pack('>%ds' % payload_size, self.config['payload']),
36+
struct.pack('>i', payload2_size),
37+
struct.pack('>%ds' % payload2_size, self.config['payload2']),
38+
''
39+
]
40+
41+
# Create a connection object
42+
self.conn = KafkaConnection(self.config['host'], self.config['port'])
43+
44+
# Reset any mock counts caused by __init__
45+
socket.create_connection.reset_mock()
46+
847
def test_collect_hosts__happy_path(self):
948
hosts = "localhost:1234,localhost"
10-
results = kafka.conn.collect_hosts(hosts)
49+
results = collect_hosts(hosts)
1150

1251
self.assertEqual(set(results), set([
1352
('localhost', 1234),
@@ -20,7 +59,7 @@ def test_collect_hosts__string_list(self):
2059
'localhost',
2160
]
2261

23-
results = kafka.conn.collect_hosts(hosts)
62+
results = collect_hosts(hosts)
2463

2564
self.assertEqual(set(results), set([
2665
('localhost', 1234),
@@ -29,41 +68,97 @@ def test_collect_hosts__string_list(self):
2968

3069
def test_collect_hosts__with_spaces(self):
3170
hosts = "localhost:1234, localhost"
32-
results = kafka.conn.collect_hosts(hosts)
71+
results = collect_hosts(hosts)
3372

3473
self.assertEqual(set(results), set([
3574
('localhost', 1234),
3675
('localhost', 9092),
3776
]))
3877

39-
@unittest2.skip("Not Implemented")
4078
def test_send(self):
41-
pass
79+
self.conn.send(self.config['request_id'], self.config['payload'])
80+
self.conn._sock.sendall.assert_called_with(self.config['payload'])
81+
82+
def test_init_creates_socket_connection(self):
83+
KafkaConnection(self.config['host'], self.config['port'])
84+
socket.create_connection.assert_called_with((self.config['host'], self.config['port']), DEFAULT_SOCKET_TIMEOUT_SECONDS)
85+
86+
def test_init_failure_raises_connection_error(self):
87+
88+
def raise_error(*args):
89+
raise socket.error
90+
91+
assert socket.create_connection is self.MockCreateConn
92+
socket.create_connection.side_effect=raise_error
93+
with self.assertRaises(ConnectionError):
94+
KafkaConnection(self.config['host'], self.config['port'])
4295

43-
@unittest2.skip("Not Implemented")
4496
def test_send__reconnects_on_dirty_conn(self):
45-
pass
4697

47-
@unittest2.skip("Not Implemented")
98+
# Dirty the connection
99+
try:
100+
self.conn._raise_connection_error()
101+
except ConnectionError:
102+
pass
103+
104+
# Now test that sending attempts to reconnect
105+
self.assertEqual(socket.create_connection.call_count, 0)
106+
self.conn.send(self.config['request_id'], self.config['payload'])
107+
self.assertEqual(socket.create_connection.call_count, 1)
108+
48109
def test_send__failure_sets_dirty_connection(self):
49-
pass
50110

51-
@unittest2.skip("Not Implemented")
111+
def raise_error(*args):
112+
raise socket.error
113+
114+
assert isinstance(self.conn._sock, mock.Mock)
115+
self.conn._sock.sendall.side_effect=raise_error
116+
try:
117+
self.conn.send(self.config['request_id'], self.config['payload'])
118+
except ConnectionError:
119+
self.assertIsNone(self.conn._sock)
120+
52121
def test_recv(self):
53-
pass
54122

55-
@unittest2.skip("Not Implemented")
123+
self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload'])
124+
56125
def test_recv__reconnects_on_dirty_conn(self):
57-
pass
58126

59-
@unittest2.skip("Not Implemented")
127+
# Dirty the connection
128+
try:
129+
self.conn._raise_connection_error()
130+
except ConnectionError:
131+
pass
132+
133+
# Now test that recv'ing attempts to reconnect
134+
self.assertEqual(socket.create_connection.call_count, 0)
135+
self.conn.recv(self.config['request_id'])
136+
self.assertEqual(socket.create_connection.call_count, 1)
137+
60138
def test_recv__failure_sets_dirty_connection(self):
61-
pass
62139

63-
@unittest2.skip("Not Implemented")
140+
def raise_error(*args):
141+
raise socket.error
142+
143+
# test that recv'ing attempts to reconnect
144+
assert isinstance(self.conn._sock, mock.Mock)
145+
self.conn._sock.recv.side_effect=raise_error
146+
try:
147+
self.conn.recv(self.config['request_id'])
148+
except ConnectionError:
149+
self.assertIsNone(self.conn._sock)
150+
64151
def test_recv__doesnt_consume_extra_data_in_stream(self):
65-
pass
66152

67-
@unittest2.skip("Not Implemented")
153+
# Here just test that each call to recv will return a single payload
154+
self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload'])
155+
self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload2'])
156+
68157
def test_close__object_is_reusable(self):
69-
pass
158+
159+
# test that sending to a closed connection
160+
# will re-connect and send data to the socket
161+
self.conn.close()
162+
self.conn.send(self.config['request_id'], self.config['payload'])
163+
self.assertEqual(socket.create_connection.call_count, 1)
164+
self.conn._sock.sendall.assert_called_with(self.config['payload'])

0 commit comments

Comments
 (0)