Skip to content

Commit ac4cbe3

Browse files
tonybaloneygavin-aguiarvrdmr
authored
ASGI lifecycle events (#187)
* Yield a disconnect on the second receive call * Integrate ASGI startup and shutdown lifecycle events * Code cleanup * Support startup and shutdown events. Capture failures. Use asyncio events to prevent deadlocks * Use debug * Lint code and fix some tests * Run the shutdown as a task * Use a type variable for the receive queue * Direct imports * Initialize event loops * Initialize queues and signals inside coroutines to ensure the same event loops are used. Check call sequence at runtime * Test all sequences and events * Verify sequence runtime error is thrown * Check startup was called before calling destroy in __del__ magic * Will no longer execute requests if the startup failed, will retry. * Test more failure scenarios --------- Co-authored-by: gavin-aguiar <[email protected]> Co-authored-by: Varad Meru <[email protected]>
1 parent b42035a commit ac4cbe3

File tree

3 files changed

+233
-57
lines changed

3 files changed

+233
-57
lines changed

azure/functions/_http_asgi.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,23 @@
44
from typing import Dict, List, Tuple, Optional, Any, Union
55
import logging
66
import asyncio
7+
from asyncio import Event, Queue
78
from warnings import warn
89
from wsgiref.headers import Headers
910

1011
from ._abc import Context
1112
from ._http import HttpRequest, HttpResponse
1213
from ._http_wsgi import WsgiRequest
1314

15+
ASGI_VERSION = "2.1"
16+
ASGI_SPEC_VERSION = "2.1"
17+
1418

1519
class AsgiRequest(WsgiRequest):
1620
def __init__(self, func_req: HttpRequest,
1721
func_ctx: Optional[Context] = None):
18-
self.asgi_version = "2.1"
19-
self.asgi_spec_version = "2.1"
22+
self.asgi_version = ASGI_VERSION
23+
self.asgi_spec_version = ASGI_SPEC_VERSION
2024
self._headers = func_req.headers
2125
super().__init__(func_req, func_ctx)
2226

@@ -153,6 +157,11 @@ def __init__(self, app):
153157

154158
self._app = app
155159
self.main = self._handle
160+
self.state = {}
161+
self.lifespan_receive_queue: Optional[Queue] = None
162+
self.lifespan_startup_event: Optional[Event] = None
163+
self.lifespan_shutdown_event: Optional[Event] = None
164+
self._startup_succeeded = False
156165

157166
def handle(self, req: HttpRequest, context: Optional[Context] = None):
158167
"""Deprecated. Please use handle_async instead:
@@ -203,3 +212,76 @@ async def _handle_async(self, req, context):
203212
scope,
204213
req.get_body())
205214
return asgi_response.to_func_response()
215+
216+
async def _lifespan_receive(self):
217+
if not self.lifespan_receive_queue:
218+
raise RuntimeError("notify_startup() must be called first.")
219+
return await self.lifespan_receive_queue.get()
220+
221+
async def _lifespan_send(self, message):
222+
logging.debug("Received lifespan message %s.", message)
223+
if not self.lifespan_startup_event or not self.lifespan_shutdown_event:
224+
raise RuntimeError("notify_startup() must be called first.")
225+
if message["type"] == "lifespan.startup.complete":
226+
self.lifespan_startup_event.set()
227+
self._startup_succeeded = True
228+
elif message["type"] == "lifespan.shutdown.complete":
229+
self.lifespan_shutdown_event.set()
230+
elif message["type"] == "lifespan.startup.failed":
231+
self.lifespan_startup_event.set()
232+
self._startup_succeeded = False
233+
if message.get("message"):
234+
self._logger.error("Failed ASGI startup with message '%s'.",
235+
message["message"])
236+
else:
237+
self._logger.error("Failed ASGI startup event.")
238+
elif message["type"] == "lifespan.shutdown.failed":
239+
self.lifespan_shutdown_event.set()
240+
if message.get("message"):
241+
self._logger.error("Failed ASGI shutdown with message '%s'.",
242+
message["message"])
243+
else:
244+
self._logger.error("Failed ASGI shutdown event.")
245+
246+
async def _lifespan_main(self):
247+
scope = {
248+
"type": "lifespan",
249+
"asgi.version": ASGI_VERSION,
250+
"asgi.spec_version": ASGI_SPEC_VERSION,
251+
"state": self.state,
252+
}
253+
if not self.lifespan_startup_event or not self.lifespan_shutdown_event:
254+
raise RuntimeError("notify_startup() must be called first.")
255+
try:
256+
await self._app(scope, self._lifespan_receive, self._lifespan_send)
257+
finally:
258+
self.lifespan_startup_event.set()
259+
self.lifespan_shutdown_event.set()
260+
261+
async def notify_startup(self):
262+
"""Notify the ASGI app that the server has started."""
263+
self._logger.debug("Notifying ASGI app of startup.")
264+
265+
# Initialize signals and queues
266+
if not self.lifespan_receive_queue:
267+
self.lifespan_receive_queue = Queue()
268+
if not self.lifespan_startup_event:
269+
self.lifespan_startup_event = Event()
270+
if not self.lifespan_shutdown_event:
271+
self.lifespan_shutdown_event = Event()
272+
273+
startup_event = {"type": "lifespan.startup"}
274+
await self.lifespan_receive_queue.put(startup_event)
275+
task = asyncio.create_task(self._lifespan_main()) # NOQA
276+
await self.lifespan_startup_event.wait()
277+
return self._startup_succeeded
278+
279+
async def notify_shutdown(self):
280+
"""Notify the ASGI app that the server is shutting down."""
281+
if not self.lifespan_receive_queue or not self.lifespan_shutdown_event:
282+
raise RuntimeError("notify_startup() must be called first.")
283+
284+
self._logger.debug("Notifying ASGI app of shutdown.")
285+
shutdown_event = {"type": "lifespan.shutdown"}
286+
await self.lifespan_receive_queue.put(shutdown_event)
287+
await self.lifespan_shutdown_event.wait()

azure/functions/decorators/function_app.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33
import abc
4+
import asyncio
45
import json
56
import logging
67
from abc import ABC
@@ -2774,7 +2775,13 @@ def __init__(self, app,
27742775
on the request in order to invoke the function.
27752776
"""
27762777
super().__init__(auth_level=http_auth_level)
2777-
self._add_http_app(AsgiMiddleware(app))
2778+
self.middleware = AsgiMiddleware(app)
2779+
self._add_http_app(self.middleware)
2780+
self.startup_task_done = False
2781+
2782+
def __del__(self):
2783+
if self.startup_task_done:
2784+
asyncio.run(self.middleware.notify_shutdown())
27782785

27792786
def _add_http_app(self,
27802787
http_middleware: Union[
@@ -2797,6 +2804,12 @@ def _add_http_app(self,
27972804
auth_level=self.auth_level,
27982805
route="/{*route}")
27992806
async def http_app_func(req: HttpRequest, context: Context):
2807+
if not self.startup_task_done:
2808+
success = await asgi_middleware.notify_startup()
2809+
if not success:
2810+
raise RuntimeError("ASGI middleware startup failed.")
2811+
self.startup_task_done = True
2812+
28002813
return await asgi_middleware.handle_async(req, context)
28012814

28022815

tests/test_http_asgi.py

Lines changed: 135 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from azure.functions._http_asgi import (
1111
AsgiMiddleware
1212
)
13+
import pytest
1314

1415

1516
class MockAsgiApplication:
@@ -18,71 +19,109 @@ class MockAsgiApplication:
1819
response_headers = [
1920
[b"content-type", b"text/plain"],
2021
]
22+
startup_called = False
23+
shutdown_called = False
24+
25+
def __init__(self, fail_startup=False, fail_shutdown=False):
26+
self.fail_startup = fail_startup
27+
self.fail_shutdown = fail_shutdown
2128

2229
async def __call__(self, scope, receive, send):
2330
self.received_scope = scope
24-
# Verify against ASGI specification
25-
assert scope['type'] == 'http'
26-
assert isinstance(scope['type'], str)
2731

32+
# Verify against ASGI specification
2833
assert scope['asgi.spec_version'] in ['2.0', '2.1']
2934
assert isinstance(scope['asgi.spec_version'], str)
3035

3136
assert scope['asgi.version'] in ['2.0', '2.1', '2.2']
3237
assert isinstance(scope['asgi.version'], str)
3338

34-
assert scope['http_version'] in ['1.0', '1.1', '2']
35-
assert isinstance(scope['http_version'], str)
36-
37-
assert scope['method'] in ['POST', 'GET', 'PUT', 'DELETE', 'PATCH']
38-
assert isinstance(scope['method'], str)
39-
40-
assert scope['scheme'] in ['http', 'https']
41-
assert isinstance(scope['scheme'], str)
42-
43-
assert isinstance(scope['path'], str)
44-
assert isinstance(scope['raw_path'], bytes)
45-
assert isinstance(scope['query_string'], bytes)
46-
assert isinstance(scope['root_path'], str)
47-
48-
assert hasattr(scope['headers'], '__iter__')
49-
for k, v in scope['headers']:
50-
assert isinstance(k, bytes)
51-
assert isinstance(v, bytes)
52-
53-
assert scope['client'] is None or hasattr(scope['client'], '__iter__')
54-
if scope['client']:
55-
assert len(scope['client']) == 2
56-
assert isinstance(scope['client'][0], str)
57-
assert isinstance(scope['client'][1], int)
58-
59-
assert scope['server'] is None or hasattr(scope['server'], '__iter__')
60-
if scope['server']:
61-
assert len(scope['server']) == 2
62-
assert isinstance(scope['server'][0], str)
63-
assert isinstance(scope['server'][1], int)
64-
65-
self.received_request = await receive()
66-
assert self.received_request['type'] == 'http.request'
67-
assert isinstance(self.received_request['body'], bytes)
68-
assert isinstance(self.received_request['more_body'], bool)
69-
70-
await send(
71-
{
72-
"type": "http.response.start",
73-
"status": self.response_code,
74-
"headers": self.response_headers,
75-
}
76-
)
77-
await send(
78-
{
79-
"type": "http.response.body",
80-
"body": self.response_body,
81-
}
82-
)
39+
assert isinstance(scope['type'], str)
8340

84-
self.next_request = await receive()
85-
assert self.next_request['type'] == 'http.disconnect'
41+
if scope['type'] == 'lifespan':
42+
self.startup_called = True
43+
startup_message = await receive()
44+
assert startup_message['type'] == 'lifespan.startup'
45+
if self.fail_startup:
46+
if isinstance(self.fail_startup, str):
47+
await send({
48+
"type": "lifespan.startup.failed",
49+
"message": self.fail_startup})
50+
else:
51+
await send({"type": "lifespan.startup.failed"})
52+
else:
53+
await send({"type": "lifespan.startup.complete"})
54+
shutdown_message = await receive()
55+
assert shutdown_message['type'] == 'lifespan.shutdown'
56+
if self.fail_shutdown:
57+
if isinstance(self.fail_shutdown, str):
58+
await send({
59+
"type": "lifespan.shutdown.failed",
60+
"message": self.fail_shutdown})
61+
else:
62+
await send({"type": "lifespan.shutdown.failed"})
63+
else:
64+
await send({"type": "lifespan.shutdown.complete"})
65+
66+
self.shutdown_called = True
67+
68+
elif scope['type'] == 'http':
69+
assert scope['http_version'] in ['1.0', '1.1', '2']
70+
assert isinstance(scope['http_version'], str)
71+
72+
assert scope['method'] in ['POST', 'GET', 'PUT', 'DELETE', 'PATCH']
73+
assert isinstance(scope['method'], str)
74+
75+
assert scope['scheme'] in ['http', 'https']
76+
assert isinstance(scope['scheme'], str)
77+
78+
assert isinstance(scope['path'], str)
79+
assert isinstance(scope['raw_path'], bytes)
80+
assert isinstance(scope['query_string'], bytes)
81+
assert isinstance(scope['root_path'], str)
82+
83+
assert hasattr(scope['headers'], '__iter__')
84+
for k, v in scope['headers']:
85+
assert isinstance(k, bytes)
86+
assert isinstance(v, bytes)
87+
88+
assert scope['client'] is None or hasattr(scope['client'],
89+
'__iter__')
90+
if scope['client']:
91+
assert len(scope['client']) == 2
92+
assert isinstance(scope['client'][0], str)
93+
assert isinstance(scope['client'][1], int)
94+
95+
assert scope['server'] is None or hasattr(scope['server'],
96+
'__iter__')
97+
if scope['server']:
98+
assert len(scope['server']) == 2
99+
assert isinstance(scope['server'][0], str)
100+
assert isinstance(scope['server'][1], int)
101+
102+
self.received_request = await receive()
103+
assert self.received_request['type'] == 'http.request'
104+
assert isinstance(self.received_request['body'], bytes)
105+
assert isinstance(self.received_request['more_body'], bool)
106+
107+
await send(
108+
{
109+
"type": "http.response.start",
110+
"status": self.response_code,
111+
"headers": self.response_headers,
112+
}
113+
)
114+
await send(
115+
{
116+
"type": "http.response.body",
117+
"body": self.response_body,
118+
}
119+
)
120+
121+
self.next_request = await receive()
122+
assert self.next_request['type'] == 'http.disconnect'
123+
else:
124+
raise AssertionError(f"unexpected type {scope['type']}")
86125

87126

88127
class TestHttpAsgiMiddleware(unittest.TestCase):
@@ -221,3 +260,45 @@ async def main(req, context):
221260
# Verify asserted
222261
self.assertEqual(response.status_code, 200)
223262
self.assertEqual(response.get_body(), test_body)
263+
264+
def test_function_app_lifecycle_events(self):
265+
mock_app = MockAsgiApplication()
266+
middleware = AsgiMiddleware(mock_app)
267+
asyncio.get_event_loop().run_until_complete(
268+
middleware.notify_startup()
269+
)
270+
assert mock_app.startup_called
271+
272+
asyncio.get_event_loop().run_until_complete(
273+
middleware.notify_shutdown()
274+
)
275+
assert mock_app.shutdown_called
276+
277+
def test_function_app_lifecycle_events_with_failures(self):
278+
apps = [
279+
MockAsgiApplication(False, True),
280+
MockAsgiApplication(True, False),
281+
MockAsgiApplication(True, True),
282+
MockAsgiApplication("bork", False),
283+
MockAsgiApplication(False, "bork"),
284+
MockAsgiApplication("bork", "bork"),
285+
]
286+
for mock_app in apps:
287+
middleware = AsgiMiddleware(mock_app)
288+
asyncio.get_event_loop().run_until_complete(
289+
middleware.notify_startup()
290+
)
291+
assert mock_app.startup_called
292+
293+
asyncio.get_event_loop().run_until_complete(
294+
middleware.notify_shutdown()
295+
)
296+
assert mock_app.shutdown_called
297+
298+
def test_calling_shutdown_without_startup_errors(self):
299+
mock_app = MockAsgiApplication()
300+
middleware = AsgiMiddleware(mock_app)
301+
with pytest.raises(RuntimeError):
302+
asyncio.get_event_loop().run_until_complete(
303+
middleware.notify_shutdown()
304+
)

0 commit comments

Comments
 (0)