@@ -84,6 +84,7 @@ def __init__(self, hosts, *args, sniff_on_start=False, **kwargs):
84
84
self .sniffing_task = None
85
85
self .loop = None
86
86
self ._async_init_called = False
87
+ self ._sniff_on_start_event = None # type: asyncio.Event
87
88
88
89
super (AsyncTransport , self ).__init__ (
89
90
* args , hosts = [], sniff_on_start = False , ** kwargs
@@ -112,14 +113,35 @@ async def _async_init(self):
112
113
self .loop = get_running_loop ()
113
114
self .kwargs ["loop" ] = self .loop
114
115
115
- # Now that we have a loop we can create all our HTTP connections
116
+ # Now that we have a loop we can create all our HTTP connections...
116
117
self .set_connections (self .hosts )
117
118
self .seed_connections = list (self .connection_pool .connections [:])
118
119
119
120
# ... and we can start sniffing in the background.
120
121
if self .sniffing_task is None and self .sniff_on_start :
121
- self .last_sniff = self .loop .time ()
122
- self .create_sniff_task (initial = True )
122
+
123
+ # Create an asyncio.Event for future calls to block on
124
+ # until the initial sniffing task completes.
125
+ self ._sniff_on_start_event = asyncio .Event ()
126
+
127
+ try :
128
+ self .last_sniff = self .loop .time ()
129
+ self .create_sniff_task (initial = True )
130
+
131
+ # Since this is the first one we wait for it to complete
132
+ # in case there's an error it'll get raised here.
133
+ await self .sniffing_task
134
+
135
+ # If the task gets cancelled here it likely means the
136
+ # transport got closed.
137
+ except asyncio .CancelledError :
138
+ pass
139
+
140
+ # Once we exit this section we want to unblock any _async_calls()
141
+ # that are blocking on our initial sniff attempt regardless of it
142
+ # was successful or not.
143
+ finally :
144
+ self ._sniff_on_start_event .set ()
123
145
124
146
async def _async_call (self ):
125
147
"""This method is called within any async method of AsyncTransport
@@ -130,6 +152,14 @@ async def _async_call(self):
130
152
self ._async_init_called = True
131
153
await self ._async_init ()
132
154
155
+ # If the initial sniff_on_start hasn't returned yet
156
+ # then we need to wait for node information to come back
157
+ # or for the task to be cancelled via AsyncTransport.close()
158
+ if self ._sniff_on_start_event and not self ._sniff_on_start_event .is_set ():
159
+ # This is already a no-op if the event is set but we try to
160
+ # avoid an 'await' by checking 'not event.is_set()' above first.
161
+ await self ._sniff_on_start_event .wait ()
162
+
133
163
if self .sniffer_timeout :
134
164
if self .loop .time () >= self .last_sniff + self .sniffer_timeout :
135
165
self .create_sniff_task ()
@@ -187,6 +217,12 @@ def _sniff_request(conn):
187
217
for t in done :
188
218
try :
189
219
_ , headers , node_info = t .result ()
220
+
221
+ # Lowercase all the header names for consistency in accessing them.
222
+ headers = {
223
+ header .lower (): value for header , value in headers .items ()
224
+ }
225
+
190
226
node_info = self .deserializer .loads (
191
227
node_info , headers .get ("content-type" )
192
228
)
@@ -212,6 +248,8 @@ async def sniff_hosts(self, initial=False):
212
248
"""
213
249
# Without a loop we can't do anything.
214
250
if not self .loop :
251
+ if initial :
252
+ raise RuntimeError ("Event loop not running on initial sniffing task" )
215
253
return
216
254
217
255
node_info = await self ._get_sniff_data (initial )
@@ -293,7 +331,7 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
293
331
connection = self .get_connection ()
294
332
295
333
try :
296
- status , headers , data = await connection .perform_request (
334
+ status , headers_response , data = await connection .perform_request (
297
335
method ,
298
336
url ,
299
337
params ,
@@ -302,6 +340,11 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
302
340
ignore = ignore ,
303
341
timeout = timeout ,
304
342
)
343
+
344
+ # Lowercase all the header names for consistency in accessing them.
345
+ headers_response = {
346
+ header .lower (): value for header , value in headers_response .items ()
347
+ }
305
348
except TransportError as e :
306
349
if method == "HEAD" and e .status_code == 404 :
307
350
return False
@@ -336,7 +379,9 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
336
379
return 200 <= status < 300
337
380
338
381
if data :
339
- data = self .deserializer .loads (data , headers .get ("content-type" ))
382
+ data = self .deserializer .loads (
383
+ data , headers_response .get ("content-type" )
384
+ )
340
385
return data
341
386
342
387
async def close (self ):
@@ -350,5 +395,6 @@ async def close(self):
350
395
except asyncio .CancelledError :
351
396
pass
352
397
self .sniffing_task = None
398
+
353
399
for connection in self .connection_pool .connections :
354
400
await connection .close ()
0 commit comments