diff --git a/azure/functions/__init__.py b/azure/functions/__init__.py index 0998ee7b..f191e540 100644 --- a/azure/functions/__init__.py +++ b/azure/functions/__init__.py @@ -4,6 +4,7 @@ from ._cosmosdb import Document, DocumentList # NoQA from ._http import HttpRequest # NoQA from ._http import HttpResponse # NoQA +from ._http_wsgi import WsgiMiddleware # NoQA from ._queue import QueueMessage # NoQA from ._servicebus import ServiceBusMessage # NoQA from .meta import get_binding_registry # NoQA @@ -34,6 +35,7 @@ 'EventHubEvent', 'HttpRequest', 'HttpResponse', + 'WsgiMiddleware', 'InputStream', 'QueueMessage', 'ServiceBusMessage', diff --git a/azure/functions/_http_wsgi.py b/azure/functions/_http_wsgi.py new file mode 100644 index 00000000..2eede7ba --- /dev/null +++ b/azure/functions/_http_wsgi.py @@ -0,0 +1,174 @@ +from typing import Callable, Dict, List, Optional, Any +from io import BytesIO, StringIO +from os import linesep +from urllib.parse import urlparse +from wsgiref.headers import Headers + +from ._abc import Context +from ._http import HttpRequest, HttpResponse +from ._thirdparty.werkzeug._compat import string_types, wsgi_encoding_dance + + +class WsgiRequest: + _environ_cache: Dict[str, Any] = None + + def __init__(self, + func_req: HttpRequest, + func_ctx: Optional[Context] = None): + url = urlparse(func_req.url) + func_req_body = func_req.get_body() or b'' + + # Convert function request headers to lowercase header + self._lowercased_headers = { + k.lower(): v for k, v in func_req.headers.items() + } + + # Implement interfaces for PEP 3333 environ + self.request_method = getattr(func_req, 'method', None) + self.script_name = '' + self.path_info = getattr(url, 'path', None) + self.query_string = getattr(url, 'query', None) + self.content_type = self._lowercased_headers.get('content-type') + self.content_length = str(len(func_req_body)) + self.server_name = getattr(url, 'hostname', None) + self.server_port = str(self._get_port(url, self._lowercased_headers)) + self.server_protocol = 'HTTP/1.1' + + # Propagate http request headers into HTTP_ environ + self._http_environ: Dict[str, str] = self._get_http_headers( + func_req.headers + ) + + # Wsgi environ + self.wsgi_version = (1, 0) + self.wsgi_url_scheme = url.scheme + self.wsgi_input = BytesIO(func_req_body) + self.wsgi_multithread = False + self.wsgi_multiprocess = False + self.wsgi_run_once = False + + # Azure Functions context + self.af_function_directory = getattr(func_ctx, + 'function_directory', None) + self.af_function_name = getattr(func_ctx, 'function_name', None) + self.af_invocation_id = getattr(func_ctx, 'invocation_id', None) + + def to_environ(self, errors_buffer: StringIO) -> Dict[str, Any]: + if self._environ_cache is not None: + return self._environ_cache + + environ = { + 'REQUEST_METHOD': self.request_method, + 'SCRIPT_NAME': self.script_name, + 'PATH_INFO': self.path_info, + 'QUERY_STRING': self.query_string, + 'CONTENT_TYPE': self.content_type, + 'CONTENT_LENGTH': self.content_length, + 'SERVER_NAME': self.server_name, + 'SERVER_PORT': self.server_port, + 'SERVER_PROTOCOL': self.server_protocol, + 'wsgi.version': self.wsgi_version, + 'wsgi.url_scheme': self.wsgi_url_scheme, + 'wsgi.input': self.wsgi_input, + 'wsgi.errors': errors_buffer, + 'wsgi.multithread': self.wsgi_multithread, + 'wsgi.multiprocess': self.wsgi_multiprocess, + 'wsgi.run_once': self.wsgi_run_once, + 'azure_functions.function_directory': self.af_function_directory, + 'azure_functions.function_name': self.af_function_name, + 'azure_functions.invocation_id': self.af_invocation_id + } + environ.update(self._http_environ) + + # Ensure WSGI string fits in IOS-8859-1 code points + for k, v in environ.items(): + if isinstance(v, string_types): + environ[k] = wsgi_encoding_dance(v) + + # Remove None values + self._environ_cache = { + k: v for k, v in environ.items() if v is not None + } + return self._environ_cache + + def _get_port(self, parsed_url, lowercased_headers: Dict[str, str]) -> int: + port: int = 80 + if lowercased_headers.get('x-forwarded-port'): + return int(lowercased_headers['x-forwarded-port']) + elif getattr(parsed_url, 'port', None): + return parsed_url.port + elif parsed_url.scheme == 'https': + return 443 + return port + + def _get_http_headers(self, + func_headers: Dict[str, str]) -> Dict[str, str]: + # Content-Type -> HTTP_CONTENT_TYPE + return {f'HTTP_{k.upper().replace("-", "_")}': v for k, v in + func_headers.items()} + + +class WsgiResponse: + def __init__(self): + self._status = '' + self._status_code = 0 + self._headers = {} + self._buffer: List[bytes] = [] + + @classmethod + def from_app(cls, app, environ) -> 'WsgiResponse': + res = cls() + res._buffer = [x or b'' for x in app(environ, res._start_response)] + return res + + def to_func_response(self) -> HttpResponse: + lowercased_headers = {k.lower(): v for k, v in self._headers.items()} + return HttpResponse( + body=b''.join(self._buffer), + status_code=self._status_code, + headers=self._headers, + mimetype=lowercased_headers.get('content-type'), + charset=lowercased_headers.get('content-encoding') + ) + + # PEP 3333 start response implementation + def _start_response(self, status: str, response_headers: List[Any]): + self._status = status + self._headers = Headers(response_headers) + self._status_code = int(self._status.split(' ')[0]) # 200 OK + + +class WsgiMiddleware: + def __init__(self, app): + self._app = app + self._wsgi_error_buffer = StringIO() + + # Usage + # main = func.WsgiMiddleware(app).main + @property + def main(self) -> Callable[[HttpRequest, Optional[Context]], HttpResponse]: + return self._handle + + # Usage + # return func.WsgiMiddlewawre(app).handle(req, context) + def handle(self, + req: HttpRequest, + context: Optional[Context] = None) -> HttpResponse: + return self._handle(req, context) + + def _handle(self, + req: HttpRequest, + context: Context) -> HttpResponse: + wsgi_request = WsgiRequest(req, context) + environ = wsgi_request.to_environ(self._wsgi_error_buffer) + wsgi_response = WsgiResponse.from_app(self._app, environ) + self._handle_errors() + return wsgi_response.to_func_response() + + def _handle_errors(self): + if self._wsgi_error_buffer.tell() > 0: + self._wsgi_error_buffer.seek(0) + error_message = linesep.join( + self._wsgi_error_buffer.readline() + ) + raise Exception(error_message) diff --git a/tests/test_http_wsgi.py b/tests/test_http_wsgi.py new file mode 100644 index 00000000..c340e50f --- /dev/null +++ b/tests/test_http_wsgi.py @@ -0,0 +1,240 @@ +import unittest +from io import StringIO, BytesIO + +import azure.functions as func +from azure.functions._http_wsgi import ( + WsgiRequest, + WsgiResponse, + WsgiMiddleware +) + + +class WsgiException(Exception): + def __init__(self, message=''): + self.message = message + + +class TestHttpWsgi(unittest.TestCase): + + def test_request_general_environ_conversion(self): + func_request = self._generate_func_request() + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + self.assertEqual(environ['REQUEST_METHOD'], 'POST') + self.assertEqual(environ['SCRIPT_NAME'], '') + self.assertEqual(environ['PATH_INFO'], '/api/http') + self.assertEqual(environ['QUERY_STRING'], 'firstname=rt') + self.assertEqual(environ['CONTENT_TYPE'], 'application/json') + self.assertEqual(environ['CONTENT_LENGTH'], + str(len(b'{ "lastname": "tsang" }'))) + self.assertEqual(environ['SERVER_NAME'], 'function.azurewebsites.net') + self.assertEqual(environ['SERVER_PORT'], '443') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') + + def test_request_wsgi_environ_conversion(self): + func_request = self._generate_func_request() + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + self.assertEqual(environ['wsgi.version'], (1, 0)) + self.assertEqual(environ['wsgi.url_scheme'], 'https') + + self.assertIsInstance(environ['wsgi.input'], BytesIO) + bytes_io: BytesIO = environ['wsgi.input'] + bytes_io.seek(0) + self.assertEqual(bytes_io.read(), b'{ "lastname": "tsang" }') + + self.assertIsInstance(environ['wsgi.errors'], StringIO) + string_io: StringIO = environ['wsgi.errors'] + string_io.seek(0) + self.assertEqual(string_io.read(), '') + + self.assertEqual(environ['wsgi.multithread'], False) + self.assertEqual(environ['wsgi.multiprocess'], False) + self.assertEqual(environ['wsgi.run_once'], False) + + def test_request_http_environ_conversion(self): + func_request = self._generate_func_request() + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + self.assertEqual(environ['HTTP_X_MS_SITE_RESTRICTED_TOKEN'], 'xmsrt') + + def test_request_has_no_query_param(self): + func_request = self._generate_func_request( + url="https://function.azurewebsites.net", + params=None) + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + self.assertEqual(environ['QUERY_STRING'], '') + + def test_request_has_no_body(self): + func_request = self._generate_func_request(body=None) + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + self.assertEqual(environ['CONTENT_LENGTH'], str(0)) + + self.assertIsInstance(environ['wsgi.input'], BytesIO) + bytes_io: BytesIO = environ['wsgi.input'] + bytes_io.seek(0) + self.assertEqual(bytes_io.read(), b'') + + def test_request_has_no_headers(self): + func_request = self._generate_func_request(headers=None) + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + self.assertNotIn('CONTENT_TYPE', environ) + + def test_request_protocol_by_header(self): + func_request = self._generate_func_request(headers={ + "x-forwarded-port": "8081" + }) + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + self.assertEqual(environ['SERVER_PORT'], str(8081)) + self.assertEqual(environ['wsgi.url_scheme'], 'https') + + def test_request_protocol_by_scheme(self): + func_request = self._generate_func_request(url="http://a.b.com") + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + self.assertEqual(environ['SERVER_PORT'], str(80)) + self.assertEqual(environ['wsgi.url_scheme'], 'http') + + def test_request_parse_function_context(self): + func_request = self._generate_func_request() + func_context = self._generate_func_context() + error_buffer = StringIO() + environ = WsgiRequest(func_request, + func_context).to_environ(error_buffer) + self.assertEqual(environ['azure_functions.invocation_id'], + '123e4567-e89b-12d3-a456-426655440000') + self.assertEqual(environ['azure_functions.function_name'], + 'httptrigger') + self.assertEqual(environ['azure_functions.function_directory'], + '/home/roger/wwwroot/httptrigger') + + def test_response_from_wsgi_app(self): + app = self._generate_wsgi_app() + func_request = self._generate_func_request() + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + + wsgi_response: WsgiResponse = WsgiResponse.from_app(app, environ) + func_response: func.HttpResponse = wsgi_response.to_func_response() + + self.assertEqual(func_response.mimetype, 'text/plain') + self.assertEqual(func_response.charset, 'utf-8') + self.assertEqual(func_response.headers['Content-Type'], 'text/plain') + self.assertEqual(func_response.status_code, 200) + self.assertEqual(func_response.get_body(), b'sample string') + + def test_response_no_body(self): + app = self._generate_wsgi_app(response_body=None) + func_request = self._generate_func_request() + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + + wsgi_response: WsgiResponse = WsgiResponse.from_app(app, environ) + func_response: func.HttpResponse = wsgi_response.to_func_response() + self.assertEqual(func_response.get_body(), b'') + + def test_response_no_headers(self): + app = self._generate_wsgi_app(response_headers=None) + func_request = self._generate_func_request() + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + + wsgi_response: WsgiResponse = WsgiResponse.from_app(app, environ) + func_response: func.HttpResponse = wsgi_response.to_func_response() + self.assertEqual(func_response.headers, {}) + + def test_response_with_exception(self): + app = self._generate_wsgi_app( + exception=WsgiException(message='wsgi excpt')) + func_request = self._generate_func_request() + error_buffer = StringIO() + environ = WsgiRequest(func_request).to_environ(error_buffer) + + with self.assertRaises(WsgiException) as e: + wsgi_response = WsgiResponse.from_app(app, environ) + wsgi_response.to_func_response() + + self.assertEqual(e.exception.message, 'wsgi excpt') + + def test_middleware_handle(self): + app = self._generate_wsgi_app() + func_request = self._generate_func_request() + func_response = WsgiMiddleware(app).handle(func_request) + self.assertEqual(func_response.status_code, 200) + + def _generate_func_request( + self, + method="POST", + url="https://function.azurewebsites.net/api/http?firstname=rt", + headers={ + "Content-Type": "application/json", + "x-ms-site-restricted-token": "xmsrt" + }, + params={ + "firstname": "roger" + }, + route_params={}, + body=b'{ "lastname": "tsang" }' + ) -> func.HttpRequest: + return func.HttpRequest( + method=method, + url=url, + headers=headers, + params=params, + route_params=route_params, + body=body + ) + + def _generate_func_context( + self, + invocation_id='123e4567-e89b-12d3-a456-426655440000', + function_name='httptrigger', + function_directory='/home/roger/wwwroot/httptrigger' + ) -> func.Context: + class MockContext(func.Context): + def __init__(self, ii, fn, fd): + self._invocation_id = ii + self._function_name = fn + self._function_directory = fd + + @property + def invocation_id(self): + return self._invocation_id + + @property + def function_name(self): + return self._function_name + + @property + def function_directory(self): + return self._function_directory + + return MockContext(invocation_id, function_name, function_directory) + + def _generate_wsgi_app(self, + status='200 OK', + response_headers=[('Content-Type', 'text/plain')], + response_body=b'sample string', + exception: WsgiException = None): + class MockWsgiApp: + _status = status + _response_headers = response_headers + _response_body = response_body + _exception = exception + + def __init__(self, environ, start_response): + self._environ = environ + self._start_response = start_response + + def __iter__(self): + if self._exception is not None: + raise self._exception + + self._start_response(self._status, self._response_headers) + yield self._response_body + + return MockWsgiApp