Skip to content

ASGI lifecycle events #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a2ad67c
Yield a disconnect on the second receive call
tonybaloney Jan 11, 2022
d8fc213
Merge remote-tracking branch 'upstream/dev' into dev
tonybaloney May 23, 2022
b9cd37a
Merge branch 'Azure:dev' into dev
tonybaloney Oct 9, 2022
ad52ad5
Merge remote-tracking branch 'upstream/dev' into dev
tonybaloney Nov 1, 2022
e6c925c
Integrate ASGI startup and shutdown lifecycle events
tonybaloney Jul 18, 2023
be4dffd
Merge branch 'dev' into asgi_startup_lifecycles
tonybaloney Jul 18, 2023
734e2ac
Code cleanup
tonybaloney Jul 18, 2023
2083a48
Support startup and shutdown events. Capture failures. Use asyncio ev…
tonybaloney Jul 18, 2023
be386e5
Use debug
tonybaloney Jul 18, 2023
1efb812
Lint code and fix some tests
tonybaloney Jul 18, 2023
fc1f829
Run the shutdown as a task
tonybaloney Jul 18, 2023
1c8b7de
Use a type variable for the receive queue
tonybaloney Jul 18, 2023
ee1fadd
Direct imports
tonybaloney Jul 25, 2023
fb457ec
Initialize event loops
tonybaloney Jul 25, 2023
ca836de
Initialize queues and signals inside coroutines to ensure the same ev…
tonybaloney Jul 25, 2023
502b2ca
Test all sequences and events
tonybaloney Jul 25, 2023
246afd2
Verify sequence runtime error is thrown
tonybaloney Jul 25, 2023
c88123e
Check startup was called before calling destroy in __del__ magic
tonybaloney Jul 25, 2023
d4788fe
Will no longer execute requests if the startup failed, will retry.
tonybaloney Jul 25, 2023
60ae50f
Test more failure scenarios
tonybaloney Jul 25, 2023
d86c622
Merge branch 'dev' into asgi_startup_lifecycles
tonybaloney Jul 31, 2023
fe58fdf
Merge branch 'dev' into asgi_startup_lifecycles
tonybaloney Aug 7, 2023
7c129cf
Merge branch 'dev' into asgi_startup_lifecycles
tonybaloney Aug 31, 2023
1b88383
Merge branch 'dev' into asgi_startup_lifecycles
gavin-aguiar Sep 1, 2023
b94205a
Merge branch 'dev' into asgi_startup_lifecycles
tonybaloney Oct 24, 2023
8113f1b
Merge branch 'dev' into asgi_startup_lifecycles
vrdmr Nov 14, 2023
f083cf8
Merge branch 'dev' into asgi_startup_lifecycles
gavin-aguiar Nov 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 84 additions & 2 deletions azure/functions/_http_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,23 @@
from typing import Dict, List, Tuple, Optional, Any, Union
import logging
import asyncio
from asyncio import Event, Queue
from warnings import warn
from wsgiref.headers import Headers

from ._abc import Context
from ._http import HttpRequest, HttpResponse
from ._http_wsgi import WsgiRequest

ASGI_VERSION = "2.1"
ASGI_SPEC_VERSION = "2.1"


class AsgiRequest(WsgiRequest):
def __init__(self, func_req: HttpRequest,
func_ctx: Optional[Context] = None):
self.asgi_version = "2.1"
self.asgi_spec_version = "2.1"
self.asgi_version = ASGI_VERSION
self.asgi_spec_version = ASGI_SPEC_VERSION
self._headers = func_req.headers
super().__init__(func_req, func_ctx)

Expand Down Expand Up @@ -153,6 +157,11 @@ def __init__(self, app):

self._app = app
self.main = self._handle
self.state = {}
self.lifespan_receive_queue: Optional[Queue] = None
self.lifespan_startup_event: Optional[Event] = None
self.lifespan_shutdown_event: Optional[Event] = None
self._startup_succeeded = False

def handle(self, req: HttpRequest, context: Optional[Context] = None):
"""Deprecated. Please use handle_async instead:
Expand Down Expand Up @@ -203,3 +212,76 @@ async def _handle_async(self, req, context):
scope,
req.get_body())
return asgi_response.to_func_response()

async def _lifespan_receive(self):
if not self.lifespan_receive_queue:
raise RuntimeError("notify_startup() must be called first.")
return await self.lifespan_receive_queue.get()

async def _lifespan_send(self, message):
logging.debug("Received lifespan message %s.", message)
if not self.lifespan_startup_event or not self.lifespan_shutdown_event:
raise RuntimeError("notify_startup() must be called first.")
if message["type"] == "lifespan.startup.complete":
self.lifespan_startup_event.set()
self._startup_succeeded = True
elif message["type"] == "lifespan.shutdown.complete":
self.lifespan_shutdown_event.set()
elif message["type"] == "lifespan.startup.failed":
self.lifespan_startup_event.set()
self._startup_succeeded = False
if message.get("message"):
self._logger.error("Failed ASGI startup with message '%s'.",
message["message"])
else:
self._logger.error("Failed ASGI startup event.")
elif message["type"] == "lifespan.shutdown.failed":
self.lifespan_shutdown_event.set()
if message.get("message"):
self._logger.error("Failed ASGI shutdown with message '%s'.",
message["message"])
else:
self._logger.error("Failed ASGI shutdown event.")

async def _lifespan_main(self):
scope = {
"type": "lifespan",
"asgi.version": ASGI_VERSION,
"asgi.spec_version": ASGI_SPEC_VERSION,
"state": self.state,
}
if not self.lifespan_startup_event or not self.lifespan_shutdown_event:
raise RuntimeError("notify_startup() must be called first.")
try:
await self._app(scope, self._lifespan_receive, self._lifespan_send)
finally:
self.lifespan_startup_event.set()
self.lifespan_shutdown_event.set()

async def notify_startup(self):
"""Notify the ASGI app that the server has started."""
self._logger.debug("Notifying ASGI app of startup.")

# Initialize signals and queues
if not self.lifespan_receive_queue:
self.lifespan_receive_queue = Queue()
if not self.lifespan_startup_event:
self.lifespan_startup_event = Event()
if not self.lifespan_shutdown_event:
self.lifespan_shutdown_event = Event()

startup_event = {"type": "lifespan.startup"}
await self.lifespan_receive_queue.put(startup_event)
task = asyncio.create_task(self._lifespan_main()) # NOQA
await self.lifespan_startup_event.wait()
return self._startup_succeeded

async def notify_shutdown(self):
"""Notify the ASGI app that the server is shutting down."""
if not self.lifespan_receive_queue or not self.lifespan_shutdown_event:
raise RuntimeError("notify_startup() must be called first.")

self._logger.debug("Notifying ASGI app of shutdown.")
shutdown_event = {"type": "lifespan.shutdown"}
await self.lifespan_receive_queue.put(shutdown_event)
await self.lifespan_shutdown_event.wait()
15 changes: 14 additions & 1 deletion azure/functions/decorators/function_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import abc
import asyncio
import json
import logging
from abc import ABC
Expand Down Expand Up @@ -2774,7 +2775,13 @@ def __init__(self, app,
on the request in order to invoke the function.
"""
super().__init__(auth_level=http_auth_level)
self._add_http_app(AsgiMiddleware(app))
self.middleware = AsgiMiddleware(app)
self._add_http_app(self.middleware)
self.startup_task_done = False

def __del__(self):
if self.startup_task_done:
asyncio.run(self.middleware.notify_shutdown())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious why for shutdown, we dont raise runtimeerror if shutdown does not succeed (similar to startup)?


def _add_http_app(self,
http_middleware: Union[
Expand All @@ -2797,6 +2804,12 @@ def _add_http_app(self,
auth_level=self.auth_level,
route="/{*route}")
async def http_app_func(req: HttpRequest, context: Context):
if not self.startup_task_done:
success = await asgi_middleware.notify_startup()
if not success:
raise RuntimeError("ASGI middleware startup failed.")
self.startup_task_done = True

return await asgi_middleware.handle_async(req, context)


Expand Down
189 changes: 135 additions & 54 deletions tests/test_http_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from azure.functions._http_asgi import (
AsgiMiddleware
)
import pytest


class MockAsgiApplication:
Expand All @@ -18,71 +19,109 @@ class MockAsgiApplication:
response_headers = [
[b"content-type", b"text/plain"],
]
startup_called = False
shutdown_called = False

def __init__(self, fail_startup=False, fail_shutdown=False):
self.fail_startup = fail_startup
self.fail_shutdown = fail_shutdown

async def __call__(self, scope, receive, send):
self.received_scope = scope
# Verify against ASGI specification
assert scope['type'] == 'http'
assert isinstance(scope['type'], str)

# Verify against ASGI specification
assert scope['asgi.spec_version'] in ['2.0', '2.1']
assert isinstance(scope['asgi.spec_version'], str)

assert scope['asgi.version'] in ['2.0', '2.1', '2.2']
assert isinstance(scope['asgi.version'], str)

assert scope['http_version'] in ['1.0', '1.1', '2']
assert isinstance(scope['http_version'], str)

assert scope['method'] in ['POST', 'GET', 'PUT', 'DELETE', 'PATCH']
assert isinstance(scope['method'], str)

assert scope['scheme'] in ['http', 'https']
assert isinstance(scope['scheme'], str)

assert isinstance(scope['path'], str)
assert isinstance(scope['raw_path'], bytes)
assert isinstance(scope['query_string'], bytes)
assert isinstance(scope['root_path'], str)

assert hasattr(scope['headers'], '__iter__')
for k, v in scope['headers']:
assert isinstance(k, bytes)
assert isinstance(v, bytes)

assert scope['client'] is None or hasattr(scope['client'], '__iter__')
if scope['client']:
assert len(scope['client']) == 2
assert isinstance(scope['client'][0], str)
assert isinstance(scope['client'][1], int)

assert scope['server'] is None or hasattr(scope['server'], '__iter__')
if scope['server']:
assert len(scope['server']) == 2
assert isinstance(scope['server'][0], str)
assert isinstance(scope['server'][1], int)

self.received_request = await receive()
assert self.received_request['type'] == 'http.request'
assert isinstance(self.received_request['body'], bytes)
assert isinstance(self.received_request['more_body'], bool)

await send(
{
"type": "http.response.start",
"status": self.response_code,
"headers": self.response_headers,
}
)
await send(
{
"type": "http.response.body",
"body": self.response_body,
}
)
assert isinstance(scope['type'], str)

self.next_request = await receive()
assert self.next_request['type'] == 'http.disconnect'
if scope['type'] == 'lifespan':
self.startup_called = True
startup_message = await receive()
assert startup_message['type'] == 'lifespan.startup'
if self.fail_startup:
if isinstance(self.fail_startup, str):
await send({
"type": "lifespan.startup.failed",
"message": self.fail_startup})
else:
await send({"type": "lifespan.startup.failed"})
else:
await send({"type": "lifespan.startup.complete"})
shutdown_message = await receive()
assert shutdown_message['type'] == 'lifespan.shutdown'
if self.fail_shutdown:
if isinstance(self.fail_shutdown, str):
await send({
"type": "lifespan.shutdown.failed",
"message": self.fail_shutdown})
else:
await send({"type": "lifespan.shutdown.failed"})
else:
await send({"type": "lifespan.shutdown.complete"})

self.shutdown_called = True

elif scope['type'] == 'http':
assert scope['http_version'] in ['1.0', '1.1', '2']
assert isinstance(scope['http_version'], str)

assert scope['method'] in ['POST', 'GET', 'PUT', 'DELETE', 'PATCH']
assert isinstance(scope['method'], str)

assert scope['scheme'] in ['http', 'https']
assert isinstance(scope['scheme'], str)

assert isinstance(scope['path'], str)
assert isinstance(scope['raw_path'], bytes)
assert isinstance(scope['query_string'], bytes)
assert isinstance(scope['root_path'], str)

assert hasattr(scope['headers'], '__iter__')
for k, v in scope['headers']:
assert isinstance(k, bytes)
assert isinstance(v, bytes)

assert scope['client'] is None or hasattr(scope['client'],
'__iter__')
if scope['client']:
assert len(scope['client']) == 2
assert isinstance(scope['client'][0], str)
assert isinstance(scope['client'][1], int)

assert scope['server'] is None or hasattr(scope['server'],
'__iter__')
if scope['server']:
assert len(scope['server']) == 2
assert isinstance(scope['server'][0], str)
assert isinstance(scope['server'][1], int)

self.received_request = await receive()
assert self.received_request['type'] == 'http.request'
assert isinstance(self.received_request['body'], bytes)
assert isinstance(self.received_request['more_body'], bool)

await send(
{
"type": "http.response.start",
"status": self.response_code,
"headers": self.response_headers,
}
)
await send(
{
"type": "http.response.body",
"body": self.response_body,
}
)

self.next_request = await receive()
assert self.next_request['type'] == 'http.disconnect'
else:
raise AssertionError(f"unexpected type {scope['type']}")


class TestHttpAsgiMiddleware(unittest.TestCase):
Expand Down Expand Up @@ -221,3 +260,45 @@ async def main(req, context):
# Verify asserted
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_body(), test_body)

def test_function_app_lifecycle_events(self):
mock_app = MockAsgiApplication()
middleware = AsgiMiddleware(mock_app)
asyncio.get_event_loop().run_until_complete(
middleware.notify_startup()
)
assert mock_app.startup_called

asyncio.get_event_loop().run_until_complete(
middleware.notify_shutdown()
)
assert mock_app.shutdown_called

def test_function_app_lifecycle_events_with_failures(self):
apps = [
MockAsgiApplication(False, True),
MockAsgiApplication(True, False),
MockAsgiApplication(True, True),
MockAsgiApplication("bork", False),
MockAsgiApplication(False, "bork"),
MockAsgiApplication("bork", "bork"),
]
for mock_app in apps:
middleware = AsgiMiddleware(mock_app)
asyncio.get_event_loop().run_until_complete(
middleware.notify_startup()
)
assert mock_app.startup_called

asyncio.get_event_loop().run_until_complete(
middleware.notify_shutdown()
)
assert mock_app.shutdown_called

def test_calling_shutdown_without_startup_errors(self):
mock_app = MockAsgiApplication()
middleware = AsgiMiddleware(mock_app)
with pytest.raises(RuntimeError):
asyncio.get_event_loop().run_until_complete(
middleware.notify_shutdown()
)