@@ -13,21 +13,35 @@ def setUp(self):
13
13
'host' : 'localhost' ,
14
14
'port' : 9090 ,
15
15
'request_id' : 0 ,
16
- 'payload' : 'test data'
16
+ 'payload' : 'test data' ,
17
+ 'payload2' : 'another packet'
17
18
}
18
19
19
20
# Mocking socket.create_connection will cause _sock to always be a
20
21
# MagicMock()
21
22
patcher = mock .patch ('socket.create_connection' , spec = True )
22
23
self .MockCreateConn = patcher .start ()
24
+ self .addCleanup (patcher .stop )
23
25
24
26
# Also mock socket.sendall() to appear successful
25
- self .MockCreateConn ().sendall .return_value = None
26
- self .addCleanup (patcher .stop )
27
+ socket .create_connection ().sendall .return_value = None
28
+
29
+ # And mock socket.recv() to return two payloads, then '', then raise
30
+ # Note that this currently ignores the num_bytes parameter to sock.recv()
31
+ payload_size = len (self .config ['payload' ])
32
+ payload2_size = len (self .config ['payload2' ])
33
+ socket .create_connection ().recv .side_effect = [
34
+ struct .pack ('>i' , payload_size ),
35
+ struct .pack ('>%ds' % payload_size , self .config ['payload' ]),
36
+ struct .pack ('>i' , payload2_size ),
37
+ struct .pack ('>%ds' % payload2_size , self .config ['payload2' ]),
38
+ ''
39
+ ]
27
40
28
- # And mock socket.recv() to return the payload
29
- self .MockCreateConn ().recv .return_value = self .config ['payload' ]
41
+ # Create a connection object
30
42
self .conn = KafkaConnection (self .config ['host' ], self .config ['port' ])
43
+
44
+ # Reset any mock counts caused by __init__
31
45
socket .create_connection .reset_mock ()
32
46
33
47
def test_collect_hosts__happy_path (self ):
@@ -92,17 +106,6 @@ def test_send__reconnects_on_dirty_conn(self):
92
106
self .conn .send (self .config ['request_id' ], self .config ['payload' ])
93
107
self .assertEqual (socket .create_connection .call_count , 1 )
94
108
95
- # A second way to dirty it...
96
- self .conn .close ()
97
-
98
- # Reset the socket call counts
99
- socket .create_connection .reset_mock ()
100
- self .assertEqual (socket .create_connection .call_count , 0 )
101
-
102
- # Now test that sending attempts to reconnect
103
- self .conn .send (self .config ['request_id' ], self .config ['payload' ])
104
- self .assertEqual (socket .create_connection .call_count , 1 )
105
-
106
109
def test_send__failure_sets_dirty_connection (self ):
107
110
108
111
def raise_error (* args ):
@@ -117,21 +120,7 @@ def raise_error(*args):
117
120
118
121
def test_recv (self ):
119
122
120
- # A function to mock _read_bytes
121
- self .conn ._mock_sent_size = False
122
- self .conn ._mock_data_sent = 0
123
- def mock_socket_recv (num_bytes ):
124
- if not self .conn ._mock_sent_size :
125
- assert num_bytes == 4
126
- self .conn ._mock_sent_size = True
127
- return struct .pack ('>i' , len (self .config ['payload' ]))
128
-
129
- recv_data = struct .pack ('>%ds' % num_bytes , self .config ['payload' ][self .conn ._mock_data_sent :self .conn ._mock_data_sent + num_bytes ])
130
- self .conn ._mock_data_sent += num_bytes
131
- return recv_data
132
-
133
- with mock .patch .object (self .conn , '_read_bytes' , new = mock_socket_recv ):
134
- self .assertEquals (self .conn .recv (self .config ['request_id' ]), self .config ['payload' ])
123
+ self .assertEquals (self .conn .recv (self .config ['request_id' ]), self .config ['payload' ])
135
124
136
125
def test_recv__reconnects_on_dirty_conn (self ):
137
126
@@ -143,18 +132,7 @@ def test_recv__reconnects_on_dirty_conn(self):
143
132
144
133
# Now test that recv'ing attempts to reconnect
145
134
self .assertEqual (socket .create_connection .call_count , 0 )
146
- self .conn ._read_bytes (len (self .config ['payload' ]))
147
- self .assertEqual (socket .create_connection .call_count , 1 )
148
-
149
- # A second way to dirty it...
150
- self .conn .close ()
151
-
152
- # Reset the socket call counts
153
- socket .create_connection .reset_mock ()
154
- self .assertEqual (socket .create_connection .call_count , 0 )
155
-
156
- # Now test that recv'ing attempts to reconnect
157
- self .conn ._read_bytes (len (self .config ['payload' ]))
135
+ self .conn .recv (self .config ['request_id' ])
158
136
self .assertEqual (socket .create_connection .call_count , 1 )
159
137
160
138
def test_recv__failure_sets_dirty_connection (self ):
@@ -171,24 +149,10 @@ def raise_error(*args):
171
149
self .assertIsNone (self .conn ._sock )
172
150
173
151
def test_recv__doesnt_consume_extra_data_in_stream (self ):
174
- data1 = self .config ['payload' ]
175
- size1 = len (data1 )
176
- encoded1 = struct .pack ('>i%ds' % size1 , size1 , data1 )
177
- data2 = "an extra payload"
178
- size2 = len (data2 )
179
- encoded2 = struct .pack ('>i%ds' % size2 , size2 , data2 )
180
-
181
- self .conn ._recv_buffer = encoded1
182
- self .conn ._recv_buffer += encoded2
183
-
184
- def mock_socket_recv (num_bytes ):
185
- data = self .conn ._recv_buffer [0 :num_bytes ]
186
- self .conn ._recv_buffer = self .conn ._recv_buffer [num_bytes :]
187
- return data
188
-
189
- with mock .patch .object (self .conn ._sock , 'recv' , new = mock_socket_recv ):
190
- self .assertEquals (self .conn .recv (self .config ['request_id' ]), self .config ['payload' ])
191
- self .assertEquals (str (self .conn ._recv_buffer ), encoded2 )
152
+
153
+ # Here just test that each call to recv will return a single payload
154
+ self .assertEquals (self .conn .recv (self .config ['request_id' ]), self .config ['payload' ])
155
+ self .assertEquals (self .conn .recv (self .config ['request_id' ]), self .config ['payload2' ])
192
156
193
157
def test_close__object_is_reusable (self ):
194
158
0 commit comments