diff --git a/azure/functions/_http_asgi.py b/azure/functions/_http_asgi.py index 9f350c07..7be3101c 100644 --- a/azure/functions/_http_asgi.py +++ b/azure/functions/_http_asgi.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Optional, Any, Union import logging import asyncio +from asyncio import Event, Queue from warnings import warn from wsgiref.headers import Headers @@ -11,12 +12,15 @@ from ._http import HttpRequest, HttpResponse from ._http_wsgi import WsgiRequest +ASGI_VERSION = "2.1" +ASGI_SPEC_VERSION = "2.1" + class AsgiRequest(WsgiRequest): def __init__(self, func_req: HttpRequest, func_ctx: Optional[Context] = None): - self.asgi_version = "2.1" - self.asgi_spec_version = "2.1" + self.asgi_version = ASGI_VERSION + self.asgi_spec_version = ASGI_SPEC_VERSION self._headers = func_req.headers super().__init__(func_req, func_ctx) @@ -153,6 +157,11 @@ def __init__(self, app): self._app = app self.main = self._handle + self.state = {} + self.lifespan_receive_queue: Optional[Queue] = None + self.lifespan_startup_event: Optional[Event] = None + self.lifespan_shutdown_event: Optional[Event] = None + self._startup_succeeded = False def handle(self, req: HttpRequest, context: Optional[Context] = None): """Deprecated. Please use handle_async instead: @@ -203,3 +212,76 @@ async def _handle_async(self, req, context): scope, req.get_body()) return asgi_response.to_func_response() + + async def _lifespan_receive(self): + if not self.lifespan_receive_queue: + raise RuntimeError("notify_startup() must be called first.") + return await self.lifespan_receive_queue.get() + + async def _lifespan_send(self, message): + logging.debug("Received lifespan message %s.", message) + if not self.lifespan_startup_event or not self.lifespan_shutdown_event: + raise RuntimeError("notify_startup() must be called first.") + if message["type"] == "lifespan.startup.complete": + self.lifespan_startup_event.set() + self._startup_succeeded = True + elif message["type"] == "lifespan.shutdown.complete": + self.lifespan_shutdown_event.set() + elif message["type"] == "lifespan.startup.failed": + self.lifespan_startup_event.set() + self._startup_succeeded = False + if message.get("message"): + self._logger.error("Failed ASGI startup with message '%s'.", + message["message"]) + else: + self._logger.error("Failed ASGI startup event.") + elif message["type"] == "lifespan.shutdown.failed": + self.lifespan_shutdown_event.set() + if message.get("message"): + self._logger.error("Failed ASGI shutdown with message '%s'.", + message["message"]) + else: + self._logger.error("Failed ASGI shutdown event.") + + async def _lifespan_main(self): + scope = { + "type": "lifespan", + "asgi.version": ASGI_VERSION, + "asgi.spec_version": ASGI_SPEC_VERSION, + "state": self.state, + } + if not self.lifespan_startup_event or not self.lifespan_shutdown_event: + raise RuntimeError("notify_startup() must be called first.") + try: + await self._app(scope, self._lifespan_receive, self._lifespan_send) + finally: + self.lifespan_startup_event.set() + self.lifespan_shutdown_event.set() + + async def notify_startup(self): + """Notify the ASGI app that the server has started.""" + self._logger.debug("Notifying ASGI app of startup.") + + # Initialize signals and queues + if not self.lifespan_receive_queue: + self.lifespan_receive_queue = Queue() + if not self.lifespan_startup_event: + self.lifespan_startup_event = Event() + if not self.lifespan_shutdown_event: + self.lifespan_shutdown_event = Event() + + startup_event = {"type": "lifespan.startup"} + await self.lifespan_receive_queue.put(startup_event) + task = asyncio.create_task(self._lifespan_main()) # NOQA + await self.lifespan_startup_event.wait() + return self._startup_succeeded + + async def notify_shutdown(self): + """Notify the ASGI app that the server is shutting down.""" + if not self.lifespan_receive_queue or not self.lifespan_shutdown_event: + raise RuntimeError("notify_startup() must be called first.") + + self._logger.debug("Notifying ASGI app of shutdown.") + shutdown_event = {"type": "lifespan.shutdown"} + await self.lifespan_receive_queue.put(shutdown_event) + await self.lifespan_shutdown_event.wait() diff --git a/azure/functions/decorators/function_app.py b/azure/functions/decorators/function_app.py index 95efeb6b..fa7a6314 100644 --- a/azure/functions/decorators/function_app.py +++ b/azure/functions/decorators/function_app.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import abc +import asyncio import json import logging from abc import ABC @@ -2774,7 +2775,13 @@ def __init__(self, app, on the request in order to invoke the function. """ super().__init__(auth_level=http_auth_level) - self._add_http_app(AsgiMiddleware(app)) + self.middleware = AsgiMiddleware(app) + self._add_http_app(self.middleware) + self.startup_task_done = False + + def __del__(self): + if self.startup_task_done: + asyncio.run(self.middleware.notify_shutdown()) def _add_http_app(self, http_middleware: Union[ @@ -2797,6 +2804,12 @@ def _add_http_app(self, auth_level=self.auth_level, route="/{*route}") async def http_app_func(req: HttpRequest, context: Context): + if not self.startup_task_done: + success = await asgi_middleware.notify_startup() + if not success: + raise RuntimeError("ASGI middleware startup failed.") + self.startup_task_done = True + return await asgi_middleware.handle_async(req, context) diff --git a/tests/test_http_asgi.py b/tests/test_http_asgi.py index 073817da..7125cd80 100644 --- a/tests/test_http_asgi.py +++ b/tests/test_http_asgi.py @@ -10,6 +10,7 @@ from azure.functions._http_asgi import ( AsgiMiddleware ) +import pytest class MockAsgiApplication: @@ -18,71 +19,109 @@ class MockAsgiApplication: response_headers = [ [b"content-type", b"text/plain"], ] + startup_called = False + shutdown_called = False + + def __init__(self, fail_startup=False, fail_shutdown=False): + self.fail_startup = fail_startup + self.fail_shutdown = fail_shutdown async def __call__(self, scope, receive, send): self.received_scope = scope - # Verify against ASGI specification - assert scope['type'] == 'http' - assert isinstance(scope['type'], str) + # Verify against ASGI specification assert scope['asgi.spec_version'] in ['2.0', '2.1'] assert isinstance(scope['asgi.spec_version'], str) assert scope['asgi.version'] in ['2.0', '2.1', '2.2'] assert isinstance(scope['asgi.version'], str) - assert scope['http_version'] in ['1.0', '1.1', '2'] - assert isinstance(scope['http_version'], str) - - assert scope['method'] in ['POST', 'GET', 'PUT', 'DELETE', 'PATCH'] - assert isinstance(scope['method'], str) - - assert scope['scheme'] in ['http', 'https'] - assert isinstance(scope['scheme'], str) - - assert isinstance(scope['path'], str) - assert isinstance(scope['raw_path'], bytes) - assert isinstance(scope['query_string'], bytes) - assert isinstance(scope['root_path'], str) - - assert hasattr(scope['headers'], '__iter__') - for k, v in scope['headers']: - assert isinstance(k, bytes) - assert isinstance(v, bytes) - - assert scope['client'] is None or hasattr(scope['client'], '__iter__') - if scope['client']: - assert len(scope['client']) == 2 - assert isinstance(scope['client'][0], str) - assert isinstance(scope['client'][1], int) - - assert scope['server'] is None or hasattr(scope['server'], '__iter__') - if scope['server']: - assert len(scope['server']) == 2 - assert isinstance(scope['server'][0], str) - assert isinstance(scope['server'][1], int) - - self.received_request = await receive() - assert self.received_request['type'] == 'http.request' - assert isinstance(self.received_request['body'], bytes) - assert isinstance(self.received_request['more_body'], bool) - - await send( - { - "type": "http.response.start", - "status": self.response_code, - "headers": self.response_headers, - } - ) - await send( - { - "type": "http.response.body", - "body": self.response_body, - } - ) + assert isinstance(scope['type'], str) - self.next_request = await receive() - assert self.next_request['type'] == 'http.disconnect' + if scope['type'] == 'lifespan': + self.startup_called = True + startup_message = await receive() + assert startup_message['type'] == 'lifespan.startup' + if self.fail_startup: + if isinstance(self.fail_startup, str): + await send({ + "type": "lifespan.startup.failed", + "message": self.fail_startup}) + else: + await send({"type": "lifespan.startup.failed"}) + else: + await send({"type": "lifespan.startup.complete"}) + shutdown_message = await receive() + assert shutdown_message['type'] == 'lifespan.shutdown' + if self.fail_shutdown: + if isinstance(self.fail_shutdown, str): + await send({ + "type": "lifespan.shutdown.failed", + "message": self.fail_shutdown}) + else: + await send({"type": "lifespan.shutdown.failed"}) + else: + await send({"type": "lifespan.shutdown.complete"}) + + self.shutdown_called = True + + elif scope['type'] == 'http': + assert scope['http_version'] in ['1.0', '1.1', '2'] + assert isinstance(scope['http_version'], str) + + assert scope['method'] in ['POST', 'GET', 'PUT', 'DELETE', 'PATCH'] + assert isinstance(scope['method'], str) + + assert scope['scheme'] in ['http', 'https'] + assert isinstance(scope['scheme'], str) + + assert isinstance(scope['path'], str) + assert isinstance(scope['raw_path'], bytes) + assert isinstance(scope['query_string'], bytes) + assert isinstance(scope['root_path'], str) + + assert hasattr(scope['headers'], '__iter__') + for k, v in scope['headers']: + assert isinstance(k, bytes) + assert isinstance(v, bytes) + + assert scope['client'] is None or hasattr(scope['client'], + '__iter__') + if scope['client']: + assert len(scope['client']) == 2 + assert isinstance(scope['client'][0], str) + assert isinstance(scope['client'][1], int) + + assert scope['server'] is None or hasattr(scope['server'], + '__iter__') + if scope['server']: + assert len(scope['server']) == 2 + assert isinstance(scope['server'][0], str) + assert isinstance(scope['server'][1], int) + + self.received_request = await receive() + assert self.received_request['type'] == 'http.request' + assert isinstance(self.received_request['body'], bytes) + assert isinstance(self.received_request['more_body'], bool) + + await send( + { + "type": "http.response.start", + "status": self.response_code, + "headers": self.response_headers, + } + ) + await send( + { + "type": "http.response.body", + "body": self.response_body, + } + ) + + self.next_request = await receive() + assert self.next_request['type'] == 'http.disconnect' + else: + raise AssertionError(f"unexpected type {scope['type']}") class TestHttpAsgiMiddleware(unittest.TestCase): @@ -221,3 +260,45 @@ async def main(req, context): # Verify asserted self.assertEqual(response.status_code, 200) self.assertEqual(response.get_body(), test_body) + + def test_function_app_lifecycle_events(self): + mock_app = MockAsgiApplication() + middleware = AsgiMiddleware(mock_app) + asyncio.get_event_loop().run_until_complete( + middleware.notify_startup() + ) + assert mock_app.startup_called + + asyncio.get_event_loop().run_until_complete( + middleware.notify_shutdown() + ) + assert mock_app.shutdown_called + + def test_function_app_lifecycle_events_with_failures(self): + apps = [ + MockAsgiApplication(False, True), + MockAsgiApplication(True, False), + MockAsgiApplication(True, True), + MockAsgiApplication("bork", False), + MockAsgiApplication(False, "bork"), + MockAsgiApplication("bork", "bork"), + ] + for mock_app in apps: + middleware = AsgiMiddleware(mock_app) + asyncio.get_event_loop().run_until_complete( + middleware.notify_startup() + ) + assert mock_app.startup_called + + asyncio.get_event_loop().run_until_complete( + middleware.notify_shutdown() + ) + assert mock_app.shutdown_called + + def test_calling_shutdown_without_startup_errors(self): + mock_app = MockAsgiApplication() + middleware = AsgiMiddleware(mock_app) + with pytest.raises(RuntimeError): + asyncio.get_event_loop().run_until_complete( + middleware.notify_shutdown() + )