diff --git a/.github/workflows/ci_consumption_workflow.yml b/.github/workflows/ci_consumption_workflow.yml index 907e3de4c..d4580c7f1 100644 --- a/.github/workflows/ci_consumption_workflow.yml +++ b/.github/workflows/ci_consumption_workflow.yml @@ -12,7 +12,6 @@ on: push: branches: [ dev, main, release/* ] pull_request: - branches: [ dev, main, release/* ] jobs: build: @@ -30,10 +29,26 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: Get Date + id: get-date + run: | + echo "todayDate=$(/bin/date -u "+%Y%m%d")" >> $GITHUB_ENV + shell: bash + - uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ env.pythonLocation }} + key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ env.todayDate }}-${{ matrix.python-version }} - name: Install dependencies + if: steps.cache-pip.outputs.cache-hit != 'true' run: | python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple -U azure-functions --pre python -m pip install -U -e .[dev] + if [[ "${{ matrix.python-version }}" != "3.7" ]]; then + python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple --pre -U -e .[test-http-v2] + fi + - name: Install worker + run: | python setup.py build - name: Running 3.7 Tests if: matrix.python-version == 3.7 diff --git a/.github/workflows/ci_e2e_workflow.yml b/.github/workflows/ci_e2e_workflow.yml index 11dc7a4d9..bbb20c88c 100644 --- a/.github/workflows/ci_e2e_workflow.yml +++ b/.github/workflows/ci_e2e_workflow.yml @@ -13,7 +13,6 @@ on: push: branches: [dev, main, release/*] pull_request: - branches: [dev, main, release/*] schedule: # Monday to Thursday 3 AM CST build # * is a special character in YAML so you have to quote this string @@ -39,7 +38,27 @@ jobs: uses: actions/setup-dotnet@v4 with: dotnet-version: "8.0.x" - - name: Install dependencies and the worker + - name: Get Date + id: get-date + run: | + echo "todayDate=$(/bin/date -u "+%Y%m%d")" >> $GITHUB_ENV + shell: bash + - uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ env.pythonLocation }} + key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ env.todayDate }}-${{ matrix.python-version }} + - name: Install dependencies + if: steps.cache-pip.outputs.cache-hit != 'true' + run: | + python -m pip install --upgrade pip + python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple -U azure-functions --pre + python -m pip install -U -e .[dev] + # Conditionally install test dependencies for Python 3.8 and later + if [[ "${{ matrix.python-version }}" != "3.7" ]]; then + python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple --pre -U -e .[test-http-v2] + fi + - name: Install worker run: | retry() { local -r -i max_attempts="$1"; shift @@ -58,10 +77,6 @@ jobs: done } - python -m pip install --upgrade pip - python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple -U azure-functions --pre - python -m pip install -U -e .[dev] - # Retry a couple times to avoid certificate issue retry 5 python setup.py build retry 5 python setup.py webhost --branch-name=dev diff --git a/.github/workflows/ci_ut_workflow.yml b/.github/workflows/ci_ut_workflow.yml index 002cac245..78e93eea4 100644 --- a/.github/workflows/ci_ut_workflow.yml +++ b/.github/workflows/ci_ut_workflow.yml @@ -15,8 +15,8 @@ on: # * is a special character in YAML so you have to quote this string - cron: "0 8 * * 1,2,3,4" push: - pull_request: branches: [ dev, main, release/* ] + pull_request: jobs: build: @@ -37,7 +37,27 @@ jobs: uses: actions/setup-dotnet@v4 with: dotnet-version: "8.0.x" - - name: Install dependencies and the worker + - name: Get Date + id: get-date + run: | + echo "todayDate=$(/bin/date -u "+%Y%m%d")" >> $GITHUB_ENV + shell: bash + - uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ env.pythonLocation }} + key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ env.todayDate }}-${{ matrix.python-version }} + - name: Install dependencies + if: steps.cache-pip.outputs.cache-hit != 'true' + run: | + python -m pip install --upgrade pip + python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple -U azure-functions --pre + python -m pip install -U -e .[dev] + # Conditionally install test dependencies for Python 3.8 and later + if [[ "${{ matrix.python-version }}" != "3.7" ]]; then + python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple --pre -U -e .[test-http-v2] + fi + - name: Install the worker run: | retry() { local -r -i max_attempts="$1"; shift @@ -55,11 +75,7 @@ jobs: fi done } - - python -m pip install --upgrade pip - python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple -U azure-functions --pre - python -m pip install -U -e .[dev] - + # Retry a couple times to avoid certificate issue retry 5 python setup.py build retry 5 python setup.py webhost --branch-name=dev @@ -70,14 +86,15 @@ jobs: AzureWebJobsStorage: ${{ secrets.LinuxStorageConnectionString310 }} # needed for installing azure-functions-durable while running setup.py ARCHIVE_WEBHOST_LOGS: ${{ github.event.inputs.archive_webhost_logging }} run: | - python -m pytest -q -n auto --dist loadfile --reruns 4 --instafail --cov=./azure_functions_worker --cov-report xml --cov-branch tests/unittests + python -m pytest -q -n auto --replay-record-dir=build/tests/replay --dist loadfile --reruns 4 --instafail --cov=./azure_functions_worker --cov-report xml --cov-branch tests/unittests - name: Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: file: ./coverage.xml # optional flags: unittests # optional name: codecov # optional fail_ci_if_error: false # optional (default = false) + token: ${{ secrets.CODECOV_TOKEN }} - name: Publish Logs to Artifact if: failure() uses: actions/upload-artifact@v4 @@ -85,3 +102,10 @@ jobs: name: Test WebHost Logs ${{ github.run_id }} ${{ matrix.python-version }} path: logs/*.log if-no-files-found: ignore + - name: Publish replays to Artifact + if: failure() + uses: actions/upload-artifact@v4 + with: + name: Test Replays ${{ github.run_id }} ${{ matrix.python-version }} + path: build/tests/replay + if-no-files-found: ignore diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 83e6f572f..d0923a8d5 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -15,8 +15,11 @@ name: Lint Code Base ############################# # Start the job on all push # ############################# -on: [ push, pull_request, workflow_dispatch ] - +on: + workflow_dispatch: + push: + branches: [ dev, main, release/* ] + pull_request: ############### # Set the Job # ############### diff --git a/azure_functions_worker/bindings/meta.py b/azure_functions_worker/bindings/meta.py index f7a810145..b7daf5666 100644 --- a/azure_functions_worker/bindings/meta.py +++ b/azure_functions_worker/bindings/meta.py @@ -3,11 +3,15 @@ import sys import typing -from .. import protos +from azure_functions_worker.constants import HTTP, HTTP_TRIGGER +from .. import protos from . import datumdef from . import generic from .shared_memory_data_transfer import SharedMemoryManager +from ..constants import BASE_EXT_SUPPORTED_PY_MINOR_VERSION, \ + PYTHON_ENABLE_INIT_INDEXING +from ..utils.common import is_envvar_true PB_TYPE = 'rpc_data' PB_TYPE_DATA = 'data' @@ -15,6 +19,37 @@ BINDING_REGISTRY = None +def _check_http_input_type_annotation(bind_name: str, pytype: type) -> bool: + if sys.version_info.minor >= BASE_EXT_SUPPORTED_PY_MINOR_VERSION and \ + is_envvar_true(PYTHON_ENABLE_INIT_INDEXING): + import azure.functions.extension.base as ext_base + if ext_base.HttpV2FeatureChecker.http_v2_enabled(): + return ext_base.RequestTrackerMeta.check_type(pytype) + + binding = get_binding(bind_name) + return binding.check_input_type_annotation(pytype) + + +def _check_http_output_type_annotation(bind_name: str, pytype: type) -> bool: + if sys.version_info.minor >= BASE_EXT_SUPPORTED_PY_MINOR_VERSION and \ + is_envvar_true(PYTHON_ENABLE_INIT_INDEXING): + import azure.functions.extension.base as ext_base + if ext_base.HttpV2FeatureChecker.http_v2_enabled(): + return ext_base.ResponseTrackerMeta.check_type(pytype) + + binding = get_binding(bind_name) + return binding.check_output_type_annotation(pytype) + + +INPUT_TYPE_CHECK_OVERRIDE_MAP = { + HTTP_TRIGGER: _check_http_input_type_annotation +} + +OUTPUT_TYPE_CHECK_OVERRIDE_MAP = { + HTTP: _check_http_output_type_annotation +} + + def load_binding_registry() -> None: func = sys.modules.get('azure.functions') @@ -43,11 +78,19 @@ def is_trigger_binding(bind_name: str) -> bool: def check_input_type_annotation(bind_name: str, pytype: type) -> bool: + global INPUT_TYPE_CHECK_OVERRIDE_MAP + if bind_name in INPUT_TYPE_CHECK_OVERRIDE_MAP: + return INPUT_TYPE_CHECK_OVERRIDE_MAP[bind_name](bind_name, pytype) + binding = get_binding(bind_name) return binding.check_input_type_annotation(pytype) def check_output_type_annotation(bind_name: str, pytype: type) -> bool: + global OUTPUT_TYPE_CHECK_OVERRIDE_MAP + if bind_name in OUTPUT_TYPE_CHECK_OVERRIDE_MAP: + return OUTPUT_TYPE_CHECK_OVERRIDE_MAP[bind_name](bind_name, pytype) + binding = get_binding(bind_name) return binding.check_output_type_annotation(pytype) diff --git a/azure_functions_worker/constants.py b/azure_functions_worker/constants.py index b6cc668b6..eea0193ec 100644 --- a/azure_functions_worker/constants.py +++ b/azure_functions_worker/constants.py @@ -10,6 +10,7 @@ WORKER_STATUS = "WorkerStatus" SHARED_MEMORY_DATA_TRANSFER = "SharedMemoryDataTransfer" FUNCTION_DATA_CACHE = "FunctionDataCache" +HTTP_URI = "HttpUri" # Platform Environment Variables AZURE_WEBJOBS_SCRIPT_ROOT = "AzureWebJobsScriptRoot" @@ -54,9 +55,25 @@ RETRY_POLICY = "retry_policy" # Paths -CUSTOMER_PACKAGES_PATH = "/home/site/wwwroot/.python_packages/lib/site-packages" +CUSTOMER_PACKAGES_PATH = "/home/site/wwwroot/.python_packages/lib/site" \ + "-packages" # Flag to index functions in handle init request PYTHON_ENABLE_INIT_INDEXING = "PYTHON_ENABLE_INIT_INDEXING" METADATA_PROPERTIES_WORKER_INDEXED = "worker_indexed" + +# HostNames +LOCAL_HOST = "127.0.0.1" + +# Header names +X_MS_INVOCATION_ID = "x-ms-invocation-id" + +# Trigger Names +HTTP_TRIGGER = "httpTrigger" + +# Output Names +HTTP = "http" + +# Base extension supported Python minor version +BASE_EXT_SUPPORTED_PY_MINOR_VERSION = 8 diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 2f94e21ba..161a74284 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -19,10 +19,9 @@ from datetime import datetime import grpc - from . import bindings, constants, functions, loader, protos from .bindings.shared_memory_data_transfer import SharedMemoryManager -from .constants import (PYTHON_ROLLBACK_CWD_PATH, +from .constants import (HTTP_TRIGGER, PYTHON_ROLLBACK_CWD_PATH, PYTHON_THREADPOOL_THREAD_COUNT, PYTHON_THREADPOOL_THREAD_COUNT_DEFAULT, PYTHON_THREADPOOL_THREAD_COUNT_MAX_37, @@ -31,8 +30,10 @@ PYTHON_SCRIPT_FILE_NAME, PYTHON_SCRIPT_FILE_NAME_DEFAULT, PYTHON_LANGUAGE_RUNTIME, PYTHON_ENABLE_INIT_INDEXING, - METADATA_PROPERTIES_WORKER_INDEXED) + METADATA_PROPERTIES_WORKER_INDEXED, + BASE_EXT_SUPPORTED_PY_MINOR_VERSION) from .extension import ExtensionManager +from .http_v2 import http_coordinator, initialize_http_server from .logging import disable_console_logging, enable_console_logging from .logging import (logger, error_logger, is_system_log_category, CONSOLE_LOG_PREFIX, format_exception) @@ -74,6 +75,7 @@ def __init__(self, loop: BaseEventLoop, host: str, port: int, self._functions = functions.Registry() self._shmem_mgr = SharedMemoryManager() self._old_task_factory = None + self._has_http_func = False # Used to store metadata returns self._function_metadata_result = None @@ -111,7 +113,8 @@ def get_sync_tp_workers_set(self): 3.9 scenarios (as we'll start passing only None by default), and we need to get that information. - Ref: concurrent.futures.thread.ThreadPoolExecutor.__init__._max_workers + Ref: concurrent.futures.thread.ThreadPoolExecutor.__init__ + ._max_workers """ return self._sync_call_tp._max_workers @@ -158,6 +161,7 @@ async def dispatch_forever(self): # sourcery skip: swap-if-expression log_level = logging.INFO if not is_envvar_true( PYTHON_ENABLE_DEBUG_LOGGING) else logging.DEBUG + root_logger.setLevel(log_level) root_logger.addHandler(logging_handler) logger.info('Switched to gRPC logging.') @@ -189,7 +193,8 @@ def stop(self) -> None: self._stop_sync_call_tp() - def on_logging(self, record: logging.LogRecord, formatted_msg: str) -> None: + def on_logging(self, record: logging.LogRecord, + formatted_msg: str) -> None: if record.levelno >= logging.CRITICAL: log_level = protos.RpcLog.Critical elif record.levelno >= logging.ERROR: @@ -309,13 +314,24 @@ async def _handle__worker_init_request(self, request): except Exception as ex: self._function_metadata_exception = ex + if sys.version_info.minor >= BASE_EXT_SUPPORTED_PY_MINOR_VERSION \ + and self._has_http_func: + from azure.functions.extension.base \ + import HttpV2FeatureChecker + + if HttpV2FeatureChecker.http_v2_enabled(): + capabilities[constants.HTTP_URI] = \ + initialize_http_server() + return protos.StreamingMessage( request_id=self.request_id, worker_init_response=protos.WorkerInitResponse( capabilities=capabilities, worker_metadata=self.get_worker_metadata(), result=protos.StatusResult( - status=protos.StatusResult.Success))) + status=protos.StatusResult.Success), + ), + ) async def _handle__worker_status_request(self, request): # Logging is not necessary in this request since the response is used @@ -477,6 +493,7 @@ async def _handle__function_load_request(self, request): async def _handle__invocation_request(self, request): invocation_time = datetime.utcnow() invoc_request = request.invocation_request + trigger_metadata = invoc_request.trigger_metadata invocation_id = invoc_request.invocation_id function_id = invoc_request.function_id @@ -508,12 +525,12 @@ async def _handle__invocation_request(self, request): logger.info(', '.join(function_invocation_logs)) args = {} + for pb in invoc_request.input_data: pb_type_info = fi.input_types[pb.name] + trigger_metadata = None if bindings.is_trigger_binding(pb_type_info.binding_name): trigger_metadata = invoc_request.trigger_metadata - else: - trigger_metadata = None args[pb.name] = bindings.from_incoming_proto( pb_type_info.binding_name, pb, @@ -521,7 +538,29 @@ async def _handle__invocation_request(self, request): pytype=pb_type_info.pytype, shmem_mgr=self._shmem_mgr) - fi_context = self._get_context(invoc_request, fi.name, fi.directory) + http_v2_enabled = False + if sys.version_info.minor >= \ + BASE_EXT_SUPPORTED_PY_MINOR_VERSION \ + and fi.trigger_metadata is not None \ + and fi.trigger_metadata.get('type') == HTTP_TRIGGER: + from azure.functions.extension.base import HttpV2FeatureChecker + http_v2_enabled = HttpV2FeatureChecker.http_v2_enabled() + + if http_v2_enabled: + http_request = await http_coordinator.get_http_request_async( + invocation_id) + + from azure.functions.extension.base import RequestTrackerMeta + route_params = {key: item.string for key, item + in trigger_metadata.items() if key not in [ + 'Headers', 'Query']} + + (RequestTrackerMeta.get_synchronizer() + .sync_route_params(http_request, route_params)) + args[fi.trigger_metadata.get('param_name')] = http_request + + fi_context = self._get_context(invoc_request, fi.name, + fi.directory) # Use local thread storage to store the invocation ID # for a customer's threads @@ -533,18 +572,30 @@ async def _handle__invocation_request(self, request): for name in fi.output_types: args[name] = bindings.Out() - if fi.is_async: - call_result = await self._run_async_func( - fi_context, fi.func, args - ) - else: - call_result = await self._loop.run_in_executor( - self._sync_call_tp, - self._run_sync_func, - invocation_id, fi_context, fi.func, args) - if call_result is not None and not fi.has_return: - raise RuntimeError(f'function {fi.name!r} without a $return ' - 'binding returned a non-None value') + call_result = None + call_error = None + try: + if fi.is_async: + call_result = \ + await self._run_async_func(fi_context, fi.func, args) + else: + call_result = await self._loop.run_in_executor( + self._sync_call_tp, + self._run_sync_func, + invocation_id, fi_context, fi.func, args) + + if call_result is not None and not fi.has_return: + raise RuntimeError( + f'function {fi.name!r} without a $return binding' + 'returned a non-None value') + except Exception as e: + call_error = e + raise + finally: + if http_v2_enabled: + http_coordinator.set_http_response( + invocation_id, call_result + if call_result is not None else call_error) output_data = [] cache_enabled = self._function_data_cache_enabled @@ -564,10 +615,12 @@ async def _handle__invocation_request(self, request): output_data.append(param_binding) return_value = None - if fi.return_type is not None: + if fi.return_type is not None and not http_v2_enabled: return_value = bindings.to_outgoing_proto( - fi.return_type.binding_name, call_result, - pytype=fi.return_type.pytype) + fi.return_type.binding_name, + call_result, + pytype=fi.return_type.pytype, + ) # Actively flush customer print() function to console sys.stdout.flush() @@ -638,6 +691,7 @@ async def _handle__function_environment_reload_request(self, request): # reload_customer_libraries call clears the registry bindings.load_binding_registry() + capabilities = {} if is_envvar_true(PYTHON_ENABLE_INIT_INDEXING): try: self.load_function_metadata( @@ -646,6 +700,16 @@ async def _handle__function_environment_reload_request(self, request): except Exception as ex: self._function_metadata_exception = ex + if sys.version_info.minor >= \ + BASE_EXT_SUPPORTED_PY_MINOR_VERSION and \ + self._has_http_func: + from azure.functions.extension.base \ + import HttpV2FeatureChecker + + if HttpV2FeatureChecker.http_v2_enabled(): + capabilities[constants.HTTP_URI] = \ + initialize_http_server() + # Change function app directory if getattr(func_env_reload_request, 'function_app_directory', None): @@ -653,7 +717,7 @@ async def _handle__function_environment_reload_request(self, request): func_env_reload_request.function_app_directory) success_response = protos.FunctionEnvironmentReloadResponse( - capabilities={}, + capabilities=capabilities, worker_metadata=self.get_worker_metadata(), result=protos.StatusResult( status=protos.StatusResult.Success)) @@ -674,8 +738,10 @@ async def _handle__function_environment_reload_request(self, request): def index_functions(self, function_path: str): indexed_functions = loader.index_function_app(function_path) - logger.info('Indexed function app and found %s functions', - len(indexed_functions)) + logger.info( + "Indexed function app and found %s functions", + len(indexed_functions) + ) if indexed_functions: fx_metadata_results = loader.process_indexed_function( @@ -684,6 +750,8 @@ def index_functions(self, function_path: str): indexed_function_logs: List[str] = [] for func in indexed_functions: + self._has_http_func = self._has_http_func or \ + func.is_http_function() function_log = "Function Name: {}, Function Binding: {}" \ .format(func.get_function_name(), [(binding.type, binding.name) for binding in @@ -734,7 +802,8 @@ async def _handle__close_shared_memory_resources_request(self, request): @staticmethod def _get_context(invoc_request: protos.InvocationRequest, name: str, directory: str) -> bindings.Context: - """ For more information refer: https://aka.ms/azfunc-invocation-context + """ For more information refer: + https://aka.ms/azfunc-invocation-context """ trace_context = bindings.TraceContext( invoc_request.trace_context.trace_parent, @@ -876,7 +945,6 @@ def gen(resp_queue): class AsyncLoggingHandler(logging.Handler): - def emit(self, record: LogRecord) -> None: # Since we disable console log after gRPC channel is initiated, # we should redirect all the messages into dispatcher. diff --git a/azure_functions_worker/functions.py b/azure_functions_worker/functions.py index f0926230c..c5f00e040 100644 --- a/azure_functions_worker/functions.py +++ b/azure_functions_worker/functions.py @@ -6,6 +6,8 @@ import typing import uuid +from azure_functions_worker.constants import HTTP_TRIGGER + from . import bindings as bindings_utils from . import protos from ._thirdparty import typing_inspect @@ -31,6 +33,8 @@ class FunctionInfo(typing.NamedTuple): output_types: typing.Mapping[str, ParamTypeInfo] return_type: typing.Optional[ParamTypeInfo] + trigger_metadata: typing.Optional[typing.Dict[str, typing.Any]] + class FunctionLoadError(RuntimeError): @@ -297,6 +301,19 @@ def add_func_to_registry_and_return_funcinfo(self, function, str, ParamTypeInfo], return_type: str): + http_trigger_param_name = next( + (input_type for input_type, type_info in input_types.items() + if type_info.binding_name == HTTP_TRIGGER), + None + ) + + trigger_metadata = None + if http_trigger_param_name is not None: + trigger_metadata = { + "type": HTTP_TRIGGER, + "param_name": http_trigger_param_name + } + function_info = FunctionInfo( func=function, name=function_name, @@ -307,7 +324,8 @@ def add_func_to_registry_and_return_funcinfo(self, function, has_return=has_explicit_return or has_implicit_return, input_types=input_types, output_types=output_types, - return_type=return_type) + return_type=return_type, + trigger_metadata=trigger_metadata) self._functions[function_id] = function_info return function_info diff --git a/azure_functions_worker/http_v2.py b/azure_functions_worker/http_v2.py new file mode 100644 index 000000000..91758d0a2 --- /dev/null +++ b/azure_functions_worker/http_v2.py @@ -0,0 +1,210 @@ +import abc +import asyncio +import importlib +import socket +from typing import Dict + +from azure_functions_worker.constants import X_MS_INVOCATION_ID, LOCAL_HOST +from azure_functions_worker.logging import logger + + +class BaseContextReference(abc.ABC): + def __init__(self, event_class, http_request=None, http_response=None, + function=None, fi_context=None, args=None, + http_trigger_param_name=None): + self._http_request = http_request + self._http_response = http_response + self._function = function + self._fi_context = fi_context + self._args = args + self._http_trigger_param_name = http_trigger_param_name + self._http_request_available_event = event_class() + self._http_response_available_event = event_class() + + @property + def http_request(self): + return self._http_request + + @http_request.setter + def http_request(self, value): + self._http_request = value + self._http_request_available_event.set() + + @property + def http_response(self): + return self._http_response + + @http_response.setter + def http_response(self, value): + self._http_response = value + self._http_response_available_event.set() + + @property + def function(self): + return self._function + + @function.setter + def function(self, value): + self._function = value + + @property + def fi_context(self): + return self._fi_context + + @fi_context.setter + def fi_context(self, value): + self._fi_context = value + + @property + def http_trigger_param_name(self): + return self._http_trigger_param_name + + @http_trigger_param_name.setter + def http_trigger_param_name(self, value): + self._http_trigger_param_name = value + + @property + def args(self): + return self._args + + @args.setter + def args(self, value): + self._args = value + + @property + def http_request_available_event(self): + return self._http_request_available_event + + @property + def http_response_available_event(self): + return self._http_response_available_event + + +class AsyncContextReference(BaseContextReference): + def __init__(self, http_request=None, http_response=None, function=None, + fi_context=None, args=None): + super().__init__(event_class=asyncio.Event, http_request=http_request, + http_response=http_response, + function=function, fi_context=fi_context, args=args) + self.is_async = True + + +class SingletonMeta(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +class HttpCoordinator(metaclass=SingletonMeta): + def __init__(self): + self._context_references: Dict[str, BaseContextReference] = {} + + def set_http_request(self, invoc_id, http_request): + if invoc_id not in self._context_references: + self._context_references[invoc_id] = AsyncContextReference() + context_ref = self._context_references.get(invoc_id) + context_ref.http_request = http_request + + def set_http_response(self, invoc_id, http_response): + if invoc_id not in self._context_references: + raise Exception("No context reference found for invocation " + f"{invoc_id}") + context_ref = self._context_references.get(invoc_id) + context_ref.http_response = http_response + + async def get_http_request_async(self, invoc_id): + if invoc_id not in self._context_references: + self._context_references[invoc_id] = AsyncContextReference() + + await asyncio.sleep(0) + await self._context_references.get( + invoc_id).http_request_available_event.wait() + return self._pop_http_request(invoc_id) + + async def await_http_response_async(self, invoc_id): + if invoc_id not in self._context_references: + raise Exception("No context reference found for invocation " + f"{invoc_id}") + await asyncio.sleep(0) + await self._context_references.get( + invoc_id).http_response_available_event.wait() + return self._pop_http_response(invoc_id) + + def _pop_http_request(self, invoc_id): + context_ref = self._context_references.get(invoc_id) + request = context_ref.http_request + if request is not None: + context_ref.http_request = None + return request + + raise Exception(f"No http request found for invocation {invoc_id}") + + def _pop_http_response(self, invoc_id): + context_ref = self._context_references.get(invoc_id) + response = context_ref.http_response + if response is not None: + context_ref.http_response = None + return response + raise Exception(f"No http response found for invocation {invoc_id}") + + +def get_unused_tcp_port(): + # Create a TCP socket + tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Bind it to a free port provided by the OS + tcp_socket.bind(("", 0)) + # Get the port number + port = tcp_socket.getsockname()[1] + # Close the socket + tcp_socket.close() + # Return the port number + return port + + +def initialize_http_server(): + from azure.functions.extension.base \ + import ModuleTrackerMeta, RequestTrackerMeta + + web_extension_mod_name = ModuleTrackerMeta.get_module() + extension_module = importlib.import_module(web_extension_mod_name) + web_app_class = extension_module.WebApp + web_server_class = extension_module.WebServer + + unused_port = get_unused_tcp_port() + + app = web_app_class() + request_type = RequestTrackerMeta.get_request_type() + + @app.route + async def catch_all(request: request_type): # type: ignore + invoc_id = request.headers.get(X_MS_INVOCATION_ID) + if invoc_id is None: + raise ValueError(f"Header {X_MS_INVOCATION_ID} not found") + logger.info(f'Received HTTP request for invocation {invoc_id}') + http_coordinator.set_http_request(invoc_id, request) + http_resp = \ + await http_coordinator.await_http_response_async(invoc_id) + + logger.info(f'Sending HTTP response for invocation {invoc_id}') + # if http_resp is an python exception, raise it + if isinstance(http_resp, Exception): + raise http_resp + + return http_resp + + web_server = web_server_class(LOCAL_HOST, unused_port, app) + web_server_run_task = web_server.serve() + + loop = asyncio.get_event_loop() + loop.create_task(web_server_run_task) + + web_server_address = f"http://{LOCAL_HOST}:{unused_port}" + logger.info(f'HTTP server starting on {web_server_address}') + + return web_server_address + + +http_coordinator = HttpCoordinator() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..56e02198d --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + serial: mark test as a serial test diff --git a/setup.py b/setup.py index a3970fe19..bf60c456a 100644 --- a/setup.py +++ b/setup.py @@ -79,7 +79,8 @@ ) else: INSTALL_REQUIRES.extend( - ("protobuf~=4.22.0", "grpcio-tools~=1.54.2", "grpcio~=1.54.2") + ("protobuf~=4.22.0", "grpcio-tools~=1.54.2", "grpcio~=1.54.2", + "azure-functions-extension-base") ) EXTRA_REQUIRES = { @@ -101,6 +102,7 @@ "pytest-randomly", "pytest-instafail", "pytest-rerunfailures", + "pytest-replay", "ptvsd", "python-dotenv", "plotly", @@ -109,7 +111,8 @@ "pandas", "numpy", "pre-commit" - ] + ], + "test-http-v2": ["azure-functions-extension-fastapi", "ujson", "orjson"] } diff --git a/tests/consumption_tests/test_linux_consumption.py b/tests/consumption_tests/test_linux_consumption.py index 195c1591c..9c5be090f 100644 --- a/tests/consumption_tests/test_linux_consumption.py +++ b/tests/consumption_tests/test_linux_consumption.py @@ -336,6 +336,44 @@ def test_reload_variables_after_oom_error(self): self.assertNotIn("Failure Exception: ModuleNotFoundError", logs) + @skipIf(sys.version_info.minor != 10, + "This is testing only for python310") + def test_http_v2_fastapi_streaming_upload_download(self): + """ + A function app with init indexing enabled + """ + with LinuxConsumptionWebHostController(_DEFAULT_HOST_VERSION, + self._py_version) as ctrl: + ctrl.assign_container(env={ + "AzureWebJobsStorage": self._storage, + "SCM_RUN_FROM_PACKAGE": + self._get_blob_url("HttpV2FastApiStreaming"), + PYTHON_ENABLE_INIT_INDEXING: "true", + PYTHON_ISOLATE_WORKER_DEPENDENCIES: "1" + }) + + def generate_random_bytes_stream(): + """Generate a stream of random bytes.""" + yield b'streaming' + yield b'testing' + yield b'response' + yield b'is' + yield b'returned' + + req = Request('POST', + f'{ctrl.url}/api/http_v2_fastapi_streaming', + data=generate_random_bytes_stream()) + resp = ctrl.send_request(req) + self.assertEqual(resp.status_code, 200) + + streamed_data = b'' + for chunk in resp.iter_content(chunk_size=1024): + if chunk: + streamed_data += chunk + + self.assertEqual( + streamed_data, b'streamingtestingresponseisreturned') + def _get_blob_url(self, scenario_name: str) -> str: return ( f'https://pythonworker{self._py_shortform}sa.blob.core.windows.net/' diff --git a/tests/endtoend/http_functions/http_functions_v2/fastapi/file_name/main.py b/tests/endtoend/http_functions/http_functions_v2/fastapi/file_name/main.py new file mode 100644 index 000000000..c9718fef5 --- /dev/null +++ b/tests/endtoend/http_functions/http_functions_v2/fastapi/file_name/main.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from datetime import datetime +import logging +import time + +import azure.functions as func + +from azure.functions.extension.fastapi import Request, Response + +app = func.FunctionApp(http_auth_level=func.AuthLevel.ANONYMOUS) + + +@app.route(route="default_template") +async def default_template(req: Request) -> Response: + logging.info('Python HTTP trigger function processed a request.') + + name = req.query_params.get('name') + if not name: + try: + req_body = await req.json() + except ValueError: + pass + else: + name = req_body.get('name') + + if name: + return Response( + f"Hello, {name}. This HTTP triggered function " + f"executed successfully.") + else: + return Response( + "This HTTP triggered function executed successfully. " + "Pass a name in the query string or in the request body for a" + " personalized response.", + status_code=200 + ) + + +@app.route(route="http_func") +def http_func(req: Request) -> Response: + time.sleep(1) + + current_time = datetime.now().strftime("%H:%M:%S") + return Response(f"{current_time}") diff --git a/tests/endtoend/http_functions/http_functions_v2/fastapi/function_app.py b/tests/endtoend/http_functions/http_functions_v2/fastapi/function_app.py new file mode 100644 index 000000000..b82e0baee --- /dev/null +++ b/tests/endtoend/http_functions/http_functions_v2/fastapi/function_app.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from datetime import datetime +import logging +import time + +import azure.functions as func +from azure.functions.extension.fastapi import Request, Response, \ + StreamingResponse, HTMLResponse, \ + UJSONResponse, ORJSONResponse, FileResponse + +app = func.FunctionApp(http_auth_level=func.AuthLevel.ANONYMOUS) + + +@app.route(route="default_template") +async def default_template(req: Request) -> Response: + logging.info('Python HTTP trigger function processed a request.') + + name = req.query_params.get('name') + if not name: + try: + req_body = await req.json() + except ValueError: + pass + else: + name = req_body.get('name') + + if name: + return Response( + f"Hello, {name}. This HTTP triggered function " + f"executed successfully.") + else: + return Response( + "This HTTP triggered function executed successfully. " + "Pass a name in the query string or in the request body for a" + " personalized response.", + status_code=200 + ) + + +@app.route(route="http_func") +def http_func(req: Request) -> Response: + time.sleep(1) + + current_time = datetime.now().strftime("%H:%M:%S") + return Response(f"{current_time}") + + +@app.route(route="upload_data_stream") +async def upload_data_stream(req: Request) -> Response: + # Define a list to accumulate the streaming data + data_chunks = [] + + async def process_stream(): + async for chunk in req.stream(): + # Append each chunk of streaming data to the list + data_chunks.append(chunk) + + await process_stream() + + # Concatenate the data chunks to form the complete data + complete_data = b"".join(data_chunks) + + # Return the complete data as the response + return Response(content=complete_data, status_code=200) + + +@app.route(route="return_streaming") +async def return_streaming(req: Request) -> StreamingResponse: + async def content(): + yield b"First chunk\n" + yield b"Second chunk\n" + return StreamingResponse(content()) + + +@app.route(route="return_html") +def return_html(req: Request) -> HTMLResponse: + html_content = "

Hello, World!

" + return HTMLResponse(content=html_content, status_code=200) + + +@app.route(route="return_ujson") +def return_ujson(req: Request) -> UJSONResponse: + return UJSONResponse(content={"message": "Hello, World!"}, status_code=200) + + +@app.route(route="return_orjson") +def return_orjson(req: Request) -> ORJSONResponse: + return ORJSONResponse(content={"message": "Hello, World!"}, status_code=200) + + +@app.route(route="return_file") +def return_file(req: Request) -> FileResponse: + return FileResponse("function_app.py") diff --git a/tests/endtoend/test_http_functions.py b/tests/endtoend/test_http_functions.py index 9d6f9fdf6..f3a8a6e07 100644 --- a/tests/endtoend/test_http_functions.py +++ b/tests/endtoend/test_http_functions.py @@ -1,7 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import concurrent import os +import sys import typing +import unittest +from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch import requests @@ -220,18 +224,160 @@ def tearDownClass(cls): super().tearDownClass() -class TestHttpFunctionsV2WithInitIndexing(TestHttpFunctionsStein): - +@unittest.skipIf(sys.version_info.minor <= 7, "Skipping tests <= Python 3.7") +class TestHttpFunctionsV2FastApiWithInitIndexing( + TestHttpFunctionsWithInitIndexing): @classmethod - def setUpClass(cls): - os.environ[PYTHON_ENABLE_INIT_INDEXING] = "1" - super().setUpClass() + def get_script_dir(cls): + return testutils.E2E_TESTS_FOLDER / 'http_functions' / \ + 'http_functions_v2' / \ + 'fastapi' - @classmethod - def tearDownClass(cls): - # Remove the PYTHON_SCRIPT_FILE_NAME environment variable - os.environ.pop(PYTHON_ENABLE_INIT_INDEXING) - super().tearDownClass() + @testutils.retryable_test(3, 5) + def test_return_streaming(self): + """Test if the return_streaming function returns a streaming + response""" + root_url = self.webhost._addr + streaming_url = f'{root_url}/api/return_streaming' + r = requests.get( + streaming_url, timeout=REQUEST_TIMEOUT_SEC, stream=True) + self.assertTrue(r.ok) + # Validate streaming content + expected_content = [b"First chunk\n", b"Second chunk\n"] + received_content = [] + for chunk in r.iter_content(chunk_size=1024): + if chunk: + received_content.append(chunk) + self.assertEqual(received_content, expected_content) + + @testutils.retryable_test(3, 5) + def test_return_streaming_concurrently(self): + """Test if the return_streaming function returns a streaming + response concurrently""" + root_url = self.webhost._addr + streaming_url = f'{root_url}/return_streaming' + + # Function to make a streaming request and validate content + def make_request(): + r = requests.get(streaming_url, timeout=REQUEST_TIMEOUT_SEC, + stream=True) + self.assertTrue(r.ok) + expected_content = [b"First chunk\n", b"Second chunk\n"] + received_content = [] + for chunk in r.iter_content(chunk_size=1024): + if chunk: + received_content.append(chunk) + self.assertEqual(received_content, expected_content) + + # Make concurrent requests + with ThreadPoolExecutor(max_workers=2) as executor: + executor.map(make_request, range(2)) + + @testutils.retryable_test(3, 5) + def test_return_html(self): + """Test if the return_html function returns an HTML response""" + root_url = self.webhost._addr + html_url = f'{root_url}/api/return_html' + r = requests.get(html_url, timeout=REQUEST_TIMEOUT_SEC) + self.assertTrue(r.ok) + self.assertEqual(r.headers['content-type'], + 'text/html; charset=utf-8') + # Validate HTML content + expected_html = "

Hello, World!

" + self.assertEqual(r.text, expected_html) + + @testutils.retryable_test(3, 5) + def test_return_ujson(self): + """Test if the return_ujson function returns a UJSON response""" + root_url = self.webhost._addr + ujson_url = f'{root_url}/api/return_ujson' + r = requests.get(ujson_url, timeout=REQUEST_TIMEOUT_SEC) + self.assertTrue(r.ok) + self.assertEqual(r.headers['content-type'], 'application/json') + self.assertEqual(r.text, '{"message":"Hello, World!"}') + + @testutils.retryable_test(3, 5) + def test_return_orjson(self): + """Test if the return_orjson function returns an ORJSON response""" + root_url = self.webhost._addr + orjson_url = f'{root_url}/api/return_orjson' + r = requests.get(orjson_url, timeout=REQUEST_TIMEOUT_SEC) + self.assertTrue(r.ok) + self.assertEqual(r.headers['content-type'], 'application/json') + self.assertEqual(r.text, '{"message":"Hello, World!"}') + + @testutils.retryable_test(3, 5) + def test_return_file(self): + """Test if the return_file function returns a file response""" + root_url = self.webhost._addr + file_url = f'{root_url}/api/return_file' + r = requests.get(file_url, timeout=REQUEST_TIMEOUT_SEC) + self.assertTrue(r.ok) + self.assertIn('@app.route(route="default_template")', r.text) + + @testutils.retryable_test(3, 5) + def test_upload_data_stream(self): + """Test if the upload_data_stream function receives streaming data + and returns the complete data""" + root_url = self.webhost._addr + upload_url = f'{root_url}/api/upload_data_stream' + + # Define the streaming data + data_chunks = [b"First chunk\n", b"Second chunk\n"] + + # Define a function to simulate streaming by reading from an + # iterator + def stream_data(data_chunks): + for chunk in data_chunks: + yield chunk + + # Send a POST request with streaming data + r = requests.post(upload_url, data=stream_data(data_chunks)) + + # Assert that the request was successful + self.assertTrue(r.ok) + + # Assert that the response content matches the concatenation of + # all data chunks + complete_data = b"".join(data_chunks) + self.assertEqual(r.content, complete_data) + + @testutils.retryable_test(3, 5) + def test_upload_data_stream_concurrently(self): + """Test if the upload_data_stream function receives streaming data + and returns the complete data""" + root_url = self.webhost._addr + upload_url = f'{root_url}/api/upload_data_stream' + + # Define the streaming data + data_chunks = [b"First chunk\n", b"Second chunk\n"] + + # Define a function to simulate streaming by reading from an + # iterator + def stream_data(data_chunks): + for chunk in data_chunks: + yield chunk + + # Define the number of concurrent requests + num_requests = 5 + + # Define a function to send a single request + def send_request(): + r = requests.post(upload_url, data=stream_data(data_chunks)) + return r.ok, r.content + + # Send multiple requests concurrently + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(send_request) for _ in + range(num_requests)] + + # Assert that all requests were successful and the response + # contents are correct + for future in concurrent.futures.as_completed(futures): + ok, content = future.result() + self.assertTrue(ok) + complete_data = b"".join(data_chunks) + self.assertEqual(content, complete_data) class TestUserThreadLoggingHttpFunctions(testutils.WebHostTestCase): diff --git a/tests/unittests/dispatcher_functions/http_v2/fastapi/function_app.py b/tests/unittests/dispatcher_functions/http_v2/fastapi/function_app.py new file mode 100644 index 000000000..f202890de --- /dev/null +++ b/tests/unittests/dispatcher_functions/http_v2/fastapi/function_app.py @@ -0,0 +1,10 @@ +from azure.functions.extension.fastapi import Request, Response +import azure.functions as func + + +app = func.FunctionApp(http_auth_level=func.AuthLevel.ANONYMOUS) + + +@app.route(route="http_trigger") +def http_trigger(req: Request) -> Response: + return Response("ok") diff --git a/tests/unittests/http_functions/http_v2_functions/fastapi/function_app.py b/tests/unittests/http_functions/http_v2_functions/fastapi/function_app.py new file mode 100644 index 000000000..9830f572e --- /dev/null +++ b/tests/unittests/http_functions/http_v2_functions/fastapi/function_app.py @@ -0,0 +1,433 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import asyncio +import hashlib +import logging +import sys +import time +from urllib.request import urlopen +from azure.functions.extension.fastapi import Request, Response, \ + HTMLResponse, RedirectResponse +import azure.functions as func +from pydantic import BaseModel + +app = func.FunctionApp(http_auth_level=func.AuthLevel.ANONYMOUS) + +logger = logging.getLogger("my-function") + + +class Item(BaseModel): + name: str + description: str + + +@app.route(route="no_type_hint") +def no_type_hint(req): + return 'no_type_hint' + + +@app.route(route="return_int") +def return_int(req) -> int: + return 1000 + + +@app.route(route="return_float") +def return_float(req) -> float: + return 1000.0 + + +@app.route(route="return_bool") +def return_bool(req) -> bool: + return True + + +@app.route(route="return_dict") +def return_dict(req) -> dict: + return {"key": "value"} + + +@app.route(route="return_list") +def return_list(req): + return ["value1", "value2"] + + +@app.route(route="return_pydantic_model") +def return_pydantic_model(req) -> Item: + item = Item(name="item1", description="description1") + return item + + +@app.route(route="return_pydantic_model_with_missing_fields") +def return_pydantic_model_with_missing_fields(req) -> Item: + item = Item(name="item1") + return item + + +@app.route(route="accept_json") +async def accept_json(req: Request): + return await req.json() + + +async def nested(): + try: + 1 / 0 + except ZeroDivisionError: + logger.error('and another error', exc_info=True) + + +@app.route(route="async_logging") +async def async_logging(req: Request): + logger.info('hello %s', 'info') + + await asyncio.sleep(0.1) + + # Create a nested task to check if invocation_id is still + # logged correctly. + await asyncio.ensure_future(nested()) + + await asyncio.sleep(0.1) + + return 'OK-async' + + +@app.route(route="async_return_str") +async def async_return_str(req: Request): + await asyncio.sleep(0.1) + return 'Hello Async World!' + + +@app.route(route="debug_logging") +def debug_logging(req: Request): + logging.critical('logging critical', exc_info=True) + logging.info('logging info', exc_info=True) + logging.warning('logging warning', exc_info=True) + logging.debug('logging debug', exc_info=True) + logging.error('logging error', exc_info=True) + return 'OK-debug' + + +@app.route(route="debug_user_logging") +def debug_user_logging(req: Request): + logger.setLevel(logging.DEBUG) + + logging.critical('logging critical', exc_info=True) + logger.info('logging info', exc_info=True) + logger.warning('logging warning', exc_info=True) + logger.debug('logging debug', exc_info=True) + logger.error('logging error', exc_info=True) + return 'OK-user-debug' + + +# Attempt to log info into system log from customer code +disguised_logger = logging.getLogger('azure_functions_worker') + + +async def parallelly_print(): + await asyncio.sleep(0.1) + print('parallelly_print') + + +async def parallelly_log_info(): + await asyncio.sleep(0.2) + logging.info('parallelly_log_info at root logger') + + +async def parallelly_log_warning(): + await asyncio.sleep(0.3) + logging.warning('parallelly_log_warning at root logger') + + +async def parallelly_log_error(): + await asyncio.sleep(0.4) + logging.error('parallelly_log_error at root logger') + + +async def parallelly_log_exception(): + await asyncio.sleep(0.5) + try: + raise Exception('custom exception') + except Exception: + logging.exception('parallelly_log_exception at root logger', + exc_info=sys.exc_info()) + + +async def parallelly_log_custom(): + await asyncio.sleep(0.6) + logger.info('parallelly_log_custom at custom_logger') + + +async def parallelly_log_system(): + await asyncio.sleep(0.7) + disguised_logger.info('parallelly_log_system at disguised_logger') + + +@app.route(route="hijack_current_event_loop") +async def hijack_current_event_loop(req: Request) -> Response: + loop = asyncio.get_event_loop() + + # Create multiple tasks and schedule it into one asyncio.wait blocker + task_print: asyncio.Task = loop.create_task(parallelly_print()) + task_info: asyncio.Task = loop.create_task(parallelly_log_info()) + task_warning: asyncio.Task = loop.create_task(parallelly_log_warning()) + task_error: asyncio.Task = loop.create_task(parallelly_log_error()) + task_exception: asyncio.Task = loop.create_task(parallelly_log_exception()) + task_custom: asyncio.Task = loop.create_task(parallelly_log_custom()) + task_disguise: asyncio.Task = loop.create_task(parallelly_log_system()) + + # Create an awaitable future and occupy the current event loop resource + future = loop.create_future() + loop.call_soon_threadsafe(future.set_result, 'callsoon_log') + + # WaitAll + await asyncio.wait([task_print, task_info, task_warning, task_error, + task_exception, task_custom, task_disguise, future]) + + # Log asyncio low-level future result + logging.info(future.result()) + + return 'OK-hijack-current-event-loop' + + +@app.route(route="print_logging") +def print_logging(req: Request): + flush_required = False + is_console_log = False + is_stderr = False + message = req.query_params.get('message', '') + + if req.query_params.get('flush') == 'true': + flush_required = True + if req.query_params.get('console') == 'true': + is_console_log = True + if req.query_params.get('is_stderr') == 'true': + is_stderr = True + + # Adding LanguageWorkerConsoleLog will make function host to treat + # this as system log and will be propagated to kusto + prefix = 'LanguageWorkerConsoleLog' if is_console_log else '' + print(f'{prefix} {message}'.strip(), + file=sys.stderr if is_stderr else sys.stdout, + flush=flush_required) + + return 'OK-print-logging' + + +@app.route(route="raw_body_bytes") +async def raw_body_bytes(req: Request) -> Response: + body = await req.body() + body_len = str(len(body)) + + headers = {'body-len': body_len} + return Response(content=body, status_code=200, headers=headers) + + +@app.route(route="remapped_context") +def remapped_context(req: Request): + return req.method + + +@app.route(route="return_bytes") +def return_bytes(req: Request): + return b"Hello World" + + +@app.route(route="return_context") +def return_context(req: Request, context: func.Context): + return { + 'method': req.method, + 'ctx_func_name': context.function_name, + 'ctx_func_dir': context.function_directory, + 'ctx_invocation_id': context.invocation_id, + 'ctx_trace_context_Traceparent': context.trace_context.Traceparent, + 'ctx_trace_context_Tracestate': context.trace_context.Tracestate, + } + + +@app.route(route="return_http") +def return_http(req: Request) -> HTMLResponse: + return HTMLResponse('

Hello World™

') + + +@app.route(route="return_http_404") +def return_http_404(req: Request): + return Response('bye', status_code=404) + + +@app.route(route="return_http_auth_admin", auth_level=func.AuthLevel.ADMIN) +def return_http_auth_admin(req: Request) -> HTMLResponse: + return HTMLResponse('

Hello World™

') + + +@app.route(route="return_http_no_body") +def return_http_no_body(req: Request): + return Response() + + +@app.route(route="return_http_redirect") +def return_http_redirect(req: Request): + return RedirectResponse(url='/api/return_http', status_code=302) + + +@app.route(route="return_request") +async def return_request(req: Request): + params = dict(req.query_params) + params.pop('code', None) # Remove 'code' parameter if present + + # Get the body content and calculate its hash + body = await req.body() + body_hash = hashlib.sha256(body).hexdigest() if body else None + + # Return a dictionary containing request information + return { + 'method': req.method, + 'url': str(req.url), + 'headers': dict(req.headers), + 'params': params, + 'body': body.decode() if body else None, + 'body_hash': body_hash, + } + + +@app.route(route="return_route_params/{param1}/{param2}") +def return_route_params(req: Request) -> str: + # log type of req + logger.info(f"req type: {type(req)}") + # log req path params + logger.info(f"req path params: {req.path_params}") + return req.path_params + + +@app.route(route="sync_logging") +def main(req: Request): + try: + 1 / 0 + except ZeroDivisionError: + logger.error('a gracefully handled error', exc_info=True) + logger.error('a gracefully handled critical error', exc_info=True) + time.sleep(0.05) + return 'OK-sync' + + +@app.route(route="unhandled_error") +def unhandled_error(req: Request): + 1 / 0 + + +@app.route(route="unhandled_urllib_error") +def unhandled_urllib_error(req: Request) -> str: + image_url = req.params.get('img') + urlopen(image_url).read() + + +class UnserializableException(Exception): + def __str__(self): + raise RuntimeError('cannot serialize me') + + +@app.route(route="unhandled_unserializable_error") +def unhandled_unserializable_error(req: Request) -> str: + raise UnserializableException('foo') + + +async def try_log(): + logger.info("try_log") + + +@app.route(route="user_event_loop") +def user_event_loop(req: Request) -> Response: + loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(loop) + + # This line should throws an asyncio RuntimeError exception + loop.run_until_complete(try_log()) + loop.close() + return 'OK-user-event-loop' + + +@app.route(route="multiple_set_cookie_resp_headers") +async def multiple_set_cookie_resp_headers(req: Request): + logging.info('Python HTTP trigger function processed a request.') + resp = Response( + "This HTTP triggered function executed successfully.") + + expires_1 = "Thu, 12 Jan 2017 13:55:08 GMT" + expires_2 = "Fri, 12 Jan 2018 13:55:08 GMT" + + resp.set_cookie( + key='foo3', + value='42', + domain='example.com', + expires=expires_1, + path='/', + max_age=10000000, + secure=True, + httponly=True, + samesite='Lax' + ) + + resp.set_cookie( + key='foo3', + value='43', + domain='example.com', + expires=expires_2, + path='/', + max_age=10000000, + secure=True, + httponly=True, + samesite='Lax' + ) + + return resp + + +@app.route(route="response_cookie_header_nullable_bool_err") +def response_cookie_header_nullable_bool_err( + req: Request) -> Response: + logging.info('Python HTTP trigger function processed a request.') + resp = Response( + "This HTTP triggered function executed successfully.") + + # Set the cookie with Secure attribute set to False + resp.set_cookie( + key='foo3', + value='42', + domain='example.com', + expires='Thu, 12-Jan-2017 13:55:08 GMT', + path='/', + max_age=10000000, + secure=False, + httponly=True + ) + + return resp + + +@app.route(route="response_cookie_header_nullable_timestamp_err") +def response_cookie_header_nullable_timestamp_err( + req: Request) -> Response: + logging.info('Python HTTP trigger function processed a request.') + resp = Response( + "This HTTP triggered function executed successfully.") + + resp.set_cookie( + key='foo3', + value='42', + domain='example.com' + ) + + return resp + + +@app.route(route="set_cookie_resp_header_default_values") +def set_cookie_resp_header_default_values( + req: Request) -> Response: + logging.info('Python HTTP trigger function processed a request.') + resp = Response( + "This HTTP triggered function executed successfully.") + + resp.set_cookie( + key='foo3', + value='42' + ) + + return resp diff --git a/tests/unittests/test_dispatcher.py b/tests/unittests/test_dispatcher.py index 37f23ea5f..dc62acb8e 100644 --- a/tests/unittests/test_dispatcher.py +++ b/tests/unittests/test_dispatcher.py @@ -21,12 +21,17 @@ from tests.utils import testutils from tests.utils.testutils import UNIT_TESTS_ROOT + SysVersionInfo = col.namedtuple("VersionInfo", ["major", "minor", "micro", "releaselevel", "serial"]) DISPATCHER_FUNCTIONS_DIR = testutils.UNIT_TESTS_FOLDER / 'dispatcher_functions' DISPATCHER_STEIN_FUNCTIONS_DIR = testutils.UNIT_TESTS_FOLDER / \ 'dispatcher_functions' / \ 'dispatcher_functions_stein' +DISPATCHER_HTTP_V2_FASTAPI_FUNCTIONS_DIR = testutils.UNIT_TESTS_FOLDER / \ + 'dispatcher_functions' / \ + 'http_v2' / \ + 'fastapi' FUNCTION_APP_DIRECTORY = UNIT_TESTS_ROOT / 'dispatcher_functions' / \ 'dispatcher_functions_stein' @@ -67,10 +72,10 @@ async def test_dispatcher_initialize_worker(self): self.assertIsInstance(r.response, protos.WorkerInitResponse) self.assertIsInstance(r.response.worker_metadata, protos.WorkerMetadata) - self.assertEquals(r.response.worker_metadata.runtime_name, - "python") - self.assertEquals(r.response.worker_metadata.worker_version, - VERSION) + self.assertEqual(r.response.worker_metadata.runtime_name, + "python") + self.assertEqual(r.response.worker_metadata.worker_version, + VERSION) async def test_dispatcher_environment_reload(self): """Test function environment reload response @@ -82,10 +87,10 @@ async def test_dispatcher_environment_reload(self): protos.FunctionEnvironmentReloadResponse) self.assertIsInstance(r.response.worker_metadata, protos.WorkerMetadata) - self.assertEquals(r.response.worker_metadata.runtime_name, - "python") - self.assertEquals(r.response.worker_metadata.worker_version, - VERSION) + self.assertEqual(r.response.worker_metadata.runtime_name, + "python") + self.assertEqual(r.response.worker_metadata.worker_version, + VERSION) async def test_dispatcher_initialize_worker_logging(self): """Test if the dispatcher's log can be flushed out during worker @@ -590,16 +595,6 @@ class TestDispatcherStein(testutils.AsyncTestCase): def setUp(self): self._ctrl = testutils.start_mockhost( script_root=DISPATCHER_STEIN_FUNCTIONS_DIR) - self._pre_env = dict(os.environ) - self.mock_version_info = patch( - 'azure_functions_worker.dispatcher.sys.version_info', - SysVersionInfo(3, 9, 0, 'final', 0)) - self.mock_version_info.start() - - def tearDown(self): - os.environ.clear() - os.environ.update(self._pre_env) - self.mock_version_info.stop() async def test_dispatcher_functions_metadata_request(self): """Test if the functions metadata response will be sent correctly @@ -683,7 +678,8 @@ async def test_dispatcher_load_azfunc_in_init(self): ) self.assertEqual( len([log for log in r.logs if log.message.startswith( - "Received WorkerMetadataRequest from _handle__worker_init_request" + "Received WorkerMetadataRequest from " + "_handle__worker_init_request" )]), 0 ) @@ -776,7 +772,6 @@ def tearDown(self): @patch.dict(os.environ, {PYTHON_ENABLE_INIT_INDEXING: 'true'}) def test_worker_init_request_with_indexing_enabled(self): - request = protos.StreamingMessage( worker_init_request=protos.WorkerInitRequest( host_version="2.3.4", @@ -845,10 +840,12 @@ def test_functions_metadata_request_with_init_indexing_enabled(self): protos.StatusResult.Success) metadata_response = self.loop.run_until_complete( - self.dispatcher._handle__functions_metadata_request(metadata_request)) + self.dispatcher._handle__functions_metadata_request( + metadata_request)) - self.assertEqual(metadata_response.function_metadata_response.result.status, - protos.StatusResult.Success) + self.assertEqual( + metadata_response.function_metadata_response.result.status, + protos.StatusResult.Success) self.assertIsNotNone(self.dispatcher._function_metadata_result) self.assertIsNone(self.dispatcher._function_metadata_exception) @@ -875,10 +872,12 @@ def test_functions_metadata_request_with_init_indexing_disabled(self): self.assertIsNone(self.dispatcher._function_metadata_exception) metadata_response = self.loop.run_until_complete( - self.dispatcher._handle__functions_metadata_request(metadata_request)) + self.dispatcher._handle__functions_metadata_request( + metadata_request)) - self.assertEqual(metadata_response.function_metadata_response.result.status, - protos.StatusResult.Success) + self.assertEqual( + metadata_response.function_metadata_response.result.status, + protos.StatusResult.Success) self.assertIsNotNone(self.dispatcher._function_metadata_result) self.assertIsNone(self.dispatcher._function_metadata_exception) @@ -887,7 +886,6 @@ def test_functions_metadata_request_with_init_indexing_disabled(self): def test_functions_metadata_request_with_indexing_exception( self, mock_index_functions): - mock_index_functions.side_effect = Exception("Mocked Exception") request = protos.StreamingMessage( diff --git a/tests/unittests/test_enable_debug_logging_functions.py b/tests/unittests/test_enable_debug_logging_functions.py index 120d54dfe..6f3739809 100644 --- a/tests/unittests/test_enable_debug_logging_functions.py +++ b/tests/unittests/test_enable_debug_logging_functions.py @@ -65,6 +65,7 @@ class TestDebugLoggingDisabledFunctions(testutils.WebHostTestCase): """ @classmethod def setUpClass(cls): + cls._pre_env = dict(os.environ) os_environ = os.environ.copy() os_environ[PYTHON_ENABLE_DEBUG_LOGGING] = '0' cls._patch_environ = patch.dict('os.environ', os_environ) @@ -73,8 +74,9 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - os.environ.pop(PYTHON_ENABLE_DEBUG_LOGGING) super().tearDownClass() + os.environ.clear() + os.environ.update(cls._pre_env) cls._patch_environ.stop() @classmethod diff --git a/tests/unittests/test_http_functions_v2.py b/tests/unittests/test_http_functions_v2.py new file mode 100644 index 000000000..45428a6b7 --- /dev/null +++ b/tests/unittests/test_http_functions_v2.py @@ -0,0 +1,465 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import filecmp +import hashlib +import os +import pathlib +import sys +import typing +import unittest +from unittest import skipIf +from unittest.mock import patch + +from azure_functions_worker.constants import PYTHON_ENABLE_INIT_INDEXING +from tests.utils import testutils + + +@unittest.skipIf(sys.version_info.minor <= 7, "Skipping tests <= Python 3.7") +class TestHttpFunctionsV2FastApi(testutils.WebHostTestCase): + @classmethod + def setUpClass(cls): + cls._pre_env = dict(os.environ) + os_environ = os.environ.copy() + # Turn on feature flag + os_environ[PYTHON_ENABLE_INIT_INDEXING] = '1' + cls._patch_environ = patch.dict('os.environ', os_environ) + cls._patch_environ.start() + + super().setUpClass() + + @classmethod + def tearDownClass(cls): + os.environ.clear() + os.environ.update(cls._pre_env) + cls._patch_environ.stop() + super().tearDownClass() + + @classmethod + def get_script_dir(cls): + return testutils.UNIT_TESTS_FOLDER / 'http_functions' / \ + 'http_v2_functions' / \ + 'fastapi' + + def test_return_bytes(self): + r = self.webhost.request('GET', 'return_bytes') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.content, b'"Hello World"') + self.assertEqual(r.headers['content-type'], 'application/json') + + def test_return_http_200(self): + r = self.webhost.request('GET', 'return_http') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '

Hello World™

') + self.assertEqual(r.headers['content-type'], 'text/html; charset=utf-8') + + def test_return_http_no_body(self): + r = self.webhost.request('GET', 'return_http_no_body') + self.assertEqual(r.text, '') + self.assertEqual(r.status_code, 200) + + def test_return_http_auth_level_admin(self): + r = self.webhost.request('GET', 'return_http_auth_admin', + params={'code': 'testMasterKey'}) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '

Hello World™

') + self.assertEqual(r.headers['content-type'], 'text/html; charset=utf-8') + + def test_return_http_404(self): + r = self.webhost.request('GET', 'return_http_404') + self.assertEqual(r.status_code, 404) + self.assertEqual(r.text, 'bye') + + def test_return_http_redirect(self): + r = self.webhost.request('GET', 'return_http_redirect') + self.assertEqual(r.text, '

Hello World™

') + self.assertEqual(r.status_code, 200) + + r = self.webhost.request('GET', 'return_http_redirect', + allow_redirects=False) + self.assertEqual(r.status_code, 302) + + def test_async_return_str(self): + r = self.webhost.request('GET', 'async_return_str') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"Hello Async World!"') + + def test_async_logging(self): + # Test that logging doesn't *break* things. + r = self.webhost.request('GET', 'async_logging') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-async"') + + def check_log_async_logging(self, host_out: typing.List[str]): + # Host out only contains user logs + self.assertIn('hello info', host_out) + self.assertIn('and another error', host_out) + + def test_debug_logging(self): + r = self.webhost.request('GET', 'debug_logging') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-debug"') + + def check_log_debug_logging(self, host_out: typing.List[str]): + self.assertIn('logging info', host_out) + self.assertIn('logging warning', host_out) + self.assertIn('logging error', host_out) + self.assertNotIn('logging debug', host_out) + + def test_debug_with_user_logging(self): + r = self.webhost.request('GET', 'debug_user_logging') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-user-debug"') + + def check_log_debug_with_user_logging(self, host_out: typing.List[str]): + self.assertIn('logging info', host_out) + self.assertIn('logging warning', host_out) + self.assertIn('logging debug', host_out) + self.assertIn('logging error', host_out) + + def test_sync_logging(self): + # Test that logging doesn't *break* things. + r = self.webhost.request('GET', 'sync_logging') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-sync"') + + def check_log_sync_logging(self, host_out: typing.List[str]): + # Host out only contains user logs + self.assertIn('a gracefully handled error', host_out) + + def test_return_context(self): + r = self.webhost.request('GET', 'return_context') + self.assertEqual(r.status_code, 200) + + data = r.json() + + self.assertEqual(data['method'], 'GET') + self.assertEqual(data['ctx_func_name'], 'return_context') + self.assertIn('ctx_invocation_id', data) + self.assertIn('ctx_trace_context_Tracestate', data) + self.assertIn('ctx_trace_context_Traceparent', data) + + def test_remapped_context(self): + r = self.webhost.request('GET', 'remapped_context') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"GET"') + + def test_return_request(self): + r = self.webhost.request( + 'GET', 'return_request', + params={'a': 1, 'b': ':%)'}, + headers={'xxx': 'zzz', 'Max-Forwards': '10'}) + + self.assertEqual(r.status_code, 200) + + req = r.json() + + self.assertEqual(req['method'], 'GET') + self.assertEqual(req['params'], {'a': '1', 'b': ':%)'}) + self.assertEqual(req['headers']['xxx'], 'zzz') + self.assertEqual(req['headers']['max-forwards'], '10') + + self.assertIn('return_request', req['url']) + + def test_post_return_request(self): + r = self.webhost.request( + 'POST', 'return_request', + params={'a': 1, 'b': ':%)'}, + headers={'xxx': 'zzz'}, + data={'key': 'value'}) + + self.assertEqual(r.status_code, 200) + + req = r.json() + + self.assertEqual(req['method'], 'POST') + self.assertEqual(req['params'], {'a': '1', 'b': ':%)'}) + self.assertEqual(req['headers']['xxx'], 'zzz') + + self.assertIn('return_request', req['url']) + + self.assertEqual(req['body'], 'key=value') + + def test_post_json_request_is_untouched(self): + body = b'{"foo": "bar", "two": 4}' + body_hash = hashlib.sha256(body).hexdigest() + r = self.webhost.request( + 'POST', 'return_request', + headers={'Content-Type': 'application/json'}, + data=body) + + self.assertEqual(r.status_code, 200) + req = r.json() + self.assertEqual(req['body_hash'], body_hash) + + def test_accept_json(self): + r = self.webhost.request( + 'GET', 'accept_json', + json={'a': 'abc', 'd': 42}) + + self.assertEqual(r.status_code, 200) + r_json = r.json() + self.assertEqual(r_json, {'a': 'abc', 'd': 42}) + self.assertEqual(r.headers['content-type'], 'application/json') + + def test_unhandled_error(self): + r = self.webhost.request('GET', 'unhandled_error') + self.assertEqual(r.status_code, 500) + # https://github.com/Azure/azure-functions-host/issues/2706 + # self.assertIn('Exception: ZeroDivisionError', r.text) + + def check_log_unhandled_error(self, + host_out: typing.List[str]): + error_substring = 'ZeroDivisionError: division by zero' + for item in host_out: + if error_substring in item: + break + else: + self.fail( + f"{error_substring}' not found in host log.") + + def test_unhandled_urllib_error(self): + r = self.webhost.request( + 'GET', 'unhandled_urllib_error', + params={'img': 'http://example.com/nonexistent.jpg'}) + self.assertEqual(r.status_code, 500) + + def test_unhandled_unserializable_error(self): + r = self.webhost.request( + 'GET', 'unhandled_unserializable_error') + self.assertEqual(r.status_code, 500) + + def test_return_route_params(self): + r = self.webhost.request('GET', 'return_route_params/foo/bar') + self.assertEqual(r.status_code, 200) + resp = r.json() + self.assertEqual(resp, {'param1': 'foo', 'param2': 'bar'}) + + def test_raw_body_bytes(self): + parent_dir = pathlib.Path(__file__).parent + image_file = parent_dir / 'resources/functions.png' + with open(image_file, 'rb') as image: + img = image.read() + img_len = len(img) + r = self.webhost.request('POST', 'raw_body_bytes', data=img) + + received_body_len = int(r.headers['body-len']) + self.assertEqual(received_body_len, img_len) + + body = r.content + try: + received_img_file = parent_dir / 'received_img.png' + with open(received_img_file, 'wb') as received_img: + received_img.write(body) + self.assertTrue(filecmp.cmp(received_img_file, image_file)) + finally: + if (os.path.exists(received_img_file)): + os.remove(received_img_file) + + def test_image_png_content_type(self): + parent_dir = pathlib.Path(__file__).parent + image_file = parent_dir / 'resources/functions.png' + with open(image_file, 'rb') as image: + img = image.read() + img_len = len(img) + r = self.webhost.request( + 'POST', 'raw_body_bytes', + headers={'Content-Type': 'image/png'}, + data=img) + + received_body_len = int(r.headers['body-len']) + self.assertEqual(received_body_len, img_len) + + body = r.content + try: + received_img_file = parent_dir / 'received_img.png' + with open(received_img_file, 'wb') as received_img: + received_img.write(body) + self.assertTrue(filecmp.cmp(received_img_file, image_file)) + finally: + if (os.path.exists(received_img_file)): + os.remove(received_img_file) + + def test_application_octet_stream_content_type(self): + parent_dir = pathlib.Path(__file__).parent + image_file = parent_dir / 'resources/functions.png' + with open(image_file, 'rb') as image: + img = image.read() + img_len = len(img) + r = self.webhost.request( + 'POST', 'raw_body_bytes', + headers={'Content-Type': 'application/octet-stream'}, + data=img) + + received_body_len = int(r.headers['body-len']) + self.assertEqual(received_body_len, img_len) + + body = r.content + try: + received_img_file = parent_dir / 'received_img.png' + with open(received_img_file, 'wb') as received_img: + received_img.write(body) + self.assertTrue(filecmp.cmp(received_img_file, image_file)) + finally: + if (os.path.exists(received_img_file)): + os.remove(received_img_file) + + def test_user_event_loop_error(self): + # User event loop is not supported in HTTP trigger + r = self.webhost.request('GET', 'user_event_loop/') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-user-event-loop"') + + def check_log_user_event_loop_error(self, host_out: typing.List[str]): + self.assertIn('try_log', host_out) + + def check_log_import_module_troubleshooting_url(self, + host_out: typing.List[str]): + passed = False + exception_message = "Exception: ModuleNotFoundError: "\ + "No module named 'does_not_exist'. "\ + "Cannot find module. "\ + "Please check the requirements.txt file for the "\ + "missing module. For more info, please refer the "\ + "troubleshooting guide: "\ + "https://aka.ms/functions-modulenotfound. "\ + "Current sys.path: " + for log in host_out: + if exception_message in log: + passed = True + self.assertTrue(passed) + + def test_print_logging_no_flush(self): + r = self.webhost.request('GET', 'print_logging?message=Secret42') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-print-logging"') + + def check_log_print_logging_no_flush(self, host_out: typing.List[str]): + self.assertIn('Secret42', host_out) + + def test_print_logging_with_flush(self): + r = self.webhost.request('GET', + 'print_logging?flush=true&message=Secret42') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-print-logging"') + + def check_log_print_logging_with_flush(self, host_out: typing.List[str]): + self.assertIn('Secret42', host_out) + + def test_print_to_console_stdout(self): + r = self.webhost.request('GET', + 'print_logging?console=true&message=Secret42') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-print-logging"') + + @skipIf(sys.version_info < (3, 8, 0), + "Skip the tests for Python 3.7 and below") + def test_multiple_cookie_header_in_response(self): + r = self.webhost.request('GET', 'multiple_set_cookie_resp_headers') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.headers.get( + 'Set-Cookie'), + "foo3=42; Domain=example.com; expires=Thu, 12 Jan 2017 13:55:08" + " GMT; HttpOnly; Max-Age=10000000; Path=/; SameSite=Lax; Secure," + " foo3=43; Domain=example.com; expires=Fri, 12 Jan 2018 13:55:08" + " GMT; HttpOnly; Max-Age=10000000; Path=/; SameSite=Lax; Secure") + + @skipIf(sys.version_info < (3, 8, 0), + "Skip the tests for Python 3.7 and below") + def test_set_cookie_header_in_response_default_value(self): + r = self.webhost.request('GET', + 'set_cookie_resp_header_default_values') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.headers.get('Set-Cookie'), + 'foo3=42; Path=/; SameSite=lax') + + @skipIf(sys.version_info < (3, 8, 0), + "Skip the tests for Python 3.7 and below") + def test_response_cookie_header_nullable_timestamp_err(self): + r = self.webhost.request( + 'GET', + 'response_cookie_header_nullable_timestamp_err') + self.assertEqual(r.status_code, 200) + + @skipIf(sys.version_info < (3, 8, 0), + "Skip the tests for Python 3.7 and below") + def test_response_cookie_header_nullable_bool_err(self): + r = self.webhost.request( + 'GET', + 'response_cookie_header_nullable_bool_err') + self.assertEqual(r.status_code, 200) + self.assertTrue("Set-Cookie" in r.headers) + + def test_print_to_console_stderr(self): + r = self.webhost.request('GET', 'print_logging?console=true' + '&message=Secret42&is_stderr=true') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-print-logging"') + + def check_log_print_to_console_stderr(self, host_out: typing.List[str], ): + # System logs stderr should not exist in host_out + self.assertNotIn('Secret42', host_out) + + def test_hijack_current_event_loop(self): + r = self.webhost.request('GET', 'hijack_current_event_loop/') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"OK-hijack-current-event-loop"') + + def check_log_hijack_current_event_loop(self, host_out: typing.List[str]): + # User logs should exist in host_out + self.assertIn('parallelly_print', host_out) + self.assertIn('parallelly_log_info at root logger', host_out) + self.assertIn('parallelly_log_warning at root logger', host_out) + self.assertIn('parallelly_log_error at root logger', host_out) + self.assertIn('parallelly_log_exception at root logger', host_out) + self.assertIn('parallelly_log_custom at custom_logger', host_out) + self.assertIn('callsoon_log', host_out) + + # System logs should not exist in host_out + self.assertNotIn('parallelly_log_system at disguised_logger', host_out) + + def test_no_type_hint(self): + r = self.webhost.request('GET', 'no_type_hint') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '"no_type_hint"') + + def test_return_int(self): + r = self.webhost.request('GET', 'return_int') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '1000') + + def test_return_float(self): + r = self.webhost.request('GET', 'return_float') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '1000.0') + + def test_return_bool(self): + r = self.webhost.request('GET', 'return_bool') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, 'true') + + def test_return_dict(self): + r = self.webhost.request('GET', 'return_dict') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.json(), {'key': 'value'}) + + def test_return_list(self): + r = self.webhost.request('GET', 'return_list') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.json(), ["value1", "value2"]) + + def test_return_pydantic_model(self): + r = self.webhost.request('GET', 'return_pydantic_model') + self.assertEqual(r.status_code, 200) + self.assertEqual(r.json(), {'description': 'description1', + 'name': 'item1'}) + + def test_return_pydantic_model_with_missing_fields(self): + r = self.webhost.request('GET', + 'return_pydantic_model_with_missing_fields') + self.assertEqual(r.status_code, 500) + + def check_return_pydantic_model_with_missing_fields(self, + host_out: + typing.List[str]): + self.assertIn("Field required [type=missing, input_value={'name': " + "'item1'}, input_type=dict]", host_out) diff --git a/tests/unittests/test_http_v2.py b/tests/unittests/test_http_v2.py new file mode 100644 index 000000000..1ab7af5e7 --- /dev/null +++ b/tests/unittests/test_http_v2.py @@ -0,0 +1,249 @@ +import asyncio +import socket +import sys +import unittest +from unittest.mock import MagicMock, patch + +from azure_functions_worker.http_v2 import http_coordinator, \ + AsyncContextReference, SingletonMeta, get_unused_tcp_port + + +class MockHttpRequest: + pass + + +class MockHttpResponse: + pass + + +@unittest.skipIf(sys.version_info <= (3, 7), "Skipping tests if <= Python 3.7") +class TestHttpCoordinator(unittest.TestCase): + def setUp(self): + self.invoc_id = "test_invocation" + self.http_request = MockHttpRequest() + self.http_response = MockHttpResponse() + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self) -> None: + http_coordinator._context_references.clear() + self.loop.close() + + def test_set_http_request_new_invocation(self): + # Test setting a new HTTP request + http_coordinator.set_http_request(self.invoc_id, self.http_request) + context_ref = http_coordinator._context_references.get(self.invoc_id) + self.assertIsNotNone(context_ref) + self.assertEqual(context_ref.http_request, self.http_request) + + def test_set_http_request_existing_invocation(self): + # Test updating an existing HTTP request + new_http_request = MagicMock() + http_coordinator.set_http_request(self.invoc_id, new_http_request) + context_ref = http_coordinator._context_references.get(self.invoc_id) + self.assertIsNotNone(context_ref) + self.assertEqual(context_ref.http_request, new_http_request) + + def test_set_http_response_context_ref_null(self): + with self.assertRaises(Exception) as cm: + http_coordinator.set_http_response(self.invoc_id, + self.http_response) + self.assertEqual(cm.exception.args[0], + "No context reference found for invocation " + f"{self.invoc_id}") + + def test_set_http_response(self): + http_coordinator.set_http_request(self.invoc_id, self.http_request) + http_coordinator.set_http_response(self.invoc_id, self.http_response) + context_ref = http_coordinator._context_references[self.invoc_id] + self.assertEqual(context_ref.http_response, self.http_response) + + def test_get_http_request_async_existing_invocation(self): + # Test retrieving an existing HTTP request + http_coordinator.set_http_request(self.invoc_id, + self.http_request) + retrieved_request = self.loop.run_until_complete( + http_coordinator.get_http_request_async(self.invoc_id)) + self.assertEqual(retrieved_request, self.http_request) + + def test_get_http_request_async_wait_forever(self): + # Test handling error when invoc_id is not found + invalid_invoc_id = "invalid_invocation" + + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete( + asyncio.wait_for( + http_coordinator.get_http_request_async( + invalid_invoc_id), + timeout=1 + ) + ) + + def test_await_http_response_async_valid_invocation(self): + invoc_id = "valid_invocation" + expected_response = self.http_response + + context_ref = AsyncContextReference(http_response=expected_response) + + # Add the mock context reference to the coordinator + http_coordinator._context_references[invoc_id] = context_ref + + http_coordinator.set_http_response(invoc_id, expected_response) + + # Call the method and verify the returned response + response = self.loop.run_until_complete( + http_coordinator.await_http_response_async(invoc_id)) + self.assertEqual(response, expected_response) + self.assertTrue( + http_coordinator._context_references.get( + invoc_id).http_response is None) + + def test_await_http_response_async_invalid_invocation(self): + # Test handling error when invoc_id is not found + invalid_invoc_id = "invalid_invocation" + with self.assertRaises(Exception) as context: + self.loop.run_until_complete( + http_coordinator.await_http_response_async(invalid_invoc_id)) + self.assertEqual(str(context.exception), + f"No context reference found for invocation " + f"{invalid_invoc_id}") + + def test_await_http_response_async_response_not_set(self): + invoc_id = "invocation_with_no_response" + # Set up a mock context reference without setting the response + context_ref = AsyncContextReference() + + # Add the mock context reference to the coordinator + http_coordinator._context_references[invoc_id] = context_ref + + http_coordinator.set_http_response(invoc_id, None) + # Call the method and verify that it raises an exception + with self.assertRaises(Exception) as context: + self.loop.run_until_complete( + http_coordinator.await_http_response_async(invoc_id)) + self.assertEqual(str(context.exception), + f"No http response found for invocation {invoc_id}") + + +@unittest.skipIf(sys.version_info <= (3, 7), "Skipping tests if <= Python 3.7") +class TestAsyncContextReference(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self) -> None: + self.loop.close() + + def test_init(self): + ref = AsyncContextReference() + self.assertIsInstance(ref, AsyncContextReference) + self.assertTrue(ref.is_async) + + def test_http_request_property(self): + ref = AsyncContextReference() + ref.http_request = object() + self.assertIsNotNone(ref.http_request) + + def test_http_response_property(self): + ref = AsyncContextReference() + ref.http_response = object() + self.assertIsNotNone(ref.http_response) + + def test_function_property(self): + ref = AsyncContextReference() + ref.function = object() + self.assertIsNotNone(ref.function) + + def test_fi_context_property(self): + ref = AsyncContextReference() + ref.fi_context = object() + self.assertIsNotNone(ref.fi_context) + + def test_http_trigger_param_name_property(self): + ref = AsyncContextReference() + ref.http_trigger_param_name = object() + self.assertIsNotNone(ref.http_trigger_param_name) + + def test_args_property(self): + ref = AsyncContextReference() + ref.args = object() + self.assertIsNotNone(ref.args) + + def test_http_request_available_event_property(self): + ref = AsyncContextReference() + self.assertIsNotNone(ref.http_request_available_event) + + def test_http_response_available_event_property(self): + ref = AsyncContextReference() + self.assertIsNotNone(ref.http_response_available_event) + + def test_full_args(self): + ref = AsyncContextReference(http_request=object(), + http_response=object(), + function=object(), + fi_context=object(), + args=object()) + self.assertIsNotNone(ref.http_request) + self.assertIsNotNone(ref.http_response) + self.assertIsNotNone(ref.function) + self.assertIsNotNone(ref.fi_context) + self.assertIsNotNone(ref.args) + + +class TestSingletonMeta(unittest.TestCase): + + def test_singleton_instance(self): + class TestClass(metaclass=SingletonMeta): + pass + + obj1 = TestClass() + obj2 = TestClass() + + self.assertIs(obj1, obj2) + + def test_singleton_with_arguments(self): + class TestClass(metaclass=SingletonMeta): + def __init__(self, arg): + self.arg = arg + + obj1 = TestClass(1) + obj2 = TestClass(2) + + self.assertEqual(obj1.arg, 1) + self.assertEqual(obj2.arg, + 1) # Should still refer to the same instance + + def test_singleton_with_kwargs(self): + class TestClass(metaclass=SingletonMeta): + def __init__(self, **kwargs): + self.kwargs = kwargs + + obj1 = TestClass(a=1) + obj2 = TestClass(b=2) + + self.assertEqual(obj1.kwargs, {'a': 1}) + self.assertEqual(obj2.kwargs, + {'a': 1}) # Should still refer to the same instance + + +class TestGetUnusedTCPPort(unittest.TestCase): + + @patch('socket.socket') + def test_get_unused_tcp_port(self, mock_socket): + # Mock the socket object and its methods + mock_socket_instance = mock_socket.return_value + mock_socket_instance.getsockname.return_value = ('localhost', 12345) + + # Call the function + port = get_unused_tcp_port() + + # Assert that socket.socket was called with the correct arguments + mock_socket.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) + + # Assert that bind and close methods were called on the socket instance + mock_socket_instance.bind.assert_called_once_with(('', 0)) + mock_socket_instance.close.assert_called_once() + + # Assert that the returned port matches the expected value + self.assertEqual(port, 12345) diff --git a/tests/utils/testutils.py b/tests/utils/testutils.py index 57946f1eb..bdc305291 100644 --- a/tests/utils/testutils.py +++ b/tests/utils/testutils.py @@ -529,11 +529,14 @@ def worker_id(self): def request_id(self): return self._request_id - async def init_worker(self, host_version: str = '4.28.0'): + async def init_worker(self, host_version: str = '4.28.0', **kwargs): + include_func_app_dir = kwargs.get('include_func_app_dir', False) r = await self.communicate( protos.StreamingMessage( worker_init_request=protos.WorkerInitRequest( - host_version=host_version + host_version=host_version, + function_app_directory=str( + self._scripts_dir) if include_func_app_dir else None, ) ), wait_for='worker_init_response' diff --git a/tests/utils/testutils_lc.py b/tests/utils/testutils_lc.py index 10c5079a9..84311a6c1 100644 --- a/tests/utils/testutils_lc.py +++ b/tests/utils/testutils_lc.py @@ -32,6 +32,8 @@ "/archive/refs/heads/dev.zip" _FUNC_FILE_NAME = "azure-functions-python-library-dev" _CUSTOM_IMAGE = "CUSTOM_IMAGE" +_EXTENSION_BASE_ZIP = 'https://github.com/Azure/azure-functions-python-' \ + 'extensions/archive/refs/heads/dev.zip' class LinuxConsumptionWebHostController: @@ -151,6 +153,15 @@ def _download_azure_functions() -> str: with ZipFile(BytesIO(zipresp.read())) as zfile: zfile.extractall(tempfile.gettempdir()) + @staticmethod + def _download_extensions() -> str: + folder = tempfile.gettempdir() + with urlopen(_EXTENSION_BASE_ZIP) as zipresp: + with ZipFile(BytesIO(zipresp.read())) as zfile: + zfile.extractall(folder) + + return folder + def spawn_container(self, image: str, env: Dict[str, str] = {}) -> int: @@ -163,11 +174,24 @@ def spawn_container(self, # TODO: Mount library in docker container # self._download_azure_functions() + # Download python extension base package + ext_folder = self._download_extensions() + container_worker_path = ( f"/azure-functions-host/workers/python/{self._py_version}/" "LINUX/X64/azure_functions_worker" ) + base_ext_container_path = ( + f"/azure-functions-host/workers/python/{self._py_version}/" + "LINUX/X64/azure/functions/extension/base" + ) + + base_ext_local_path = ( + f'{ext_folder}\\azure-functions-python' + f'-extensions-dev\\azure-functions-extension-base' + f'\\azure\\functions\\extension\\base' + ) run_cmd = [] run_cmd.extend([self._docker_cmd, "run", "-p", "0:80", "-d"]) run_cmd.extend(["--name", self._uuid, "--privileged"]) @@ -177,6 +201,8 @@ def spawn_container(self, run_cmd.extend(["-e", f"CONTAINER_ENCRYPTION_KEY={_DUMMY_CONT_KEY}"]) run_cmd.extend(["-e", "WEBSITE_PLACEHOLDER_MODE=1"]) run_cmd.extend(["-v", f'{worker_path}:{container_worker_path}']) + run_cmd.extend(["-v", + f'{base_ext_local_path}:{base_ext_container_path}']) for key, value in env.items(): run_cmd.extend(["-e", f"{key}={value}"])