diff --git a/azure_functions_worker/bindings/generic.py b/azure_functions_worker/bindings/generic.py index 9d0cca8af..bc886dee0 100644 --- a/azure_functions_worker/bindings/generic.py +++ b/azure_functions_worker/bindings/generic.py @@ -51,5 +51,7 @@ def decode(cls, data: datumdef.Datum, *, trigger_metadata) -> typing.Any: return result @classmethod - def has_implicit_output(cls) -> bool: - return False + def has_implicit_output(cls, bind_name: Optional[str]) -> bool: + if bind_name == 'durableClient': + return False + return True diff --git a/azure_functions_worker/bindings/meta.py b/azure_functions_worker/bindings/meta.py index 3f52f8d0f..f7a810145 100644 --- a/azure_functions_worker/bindings/meta.py +++ b/azure_functions_worker/bindings/meta.py @@ -55,9 +55,15 @@ def check_output_type_annotation(bind_name: str, pytype: type) -> bool: def has_implicit_output(bind_name: str) -> bool: binding = get_binding(bind_name) - # If the binding does not have metaclass of meta.InConverter - # The implicit_output does not exist - return getattr(binding, 'has_implicit_output', lambda: False)() + # Need to pass in bind_name to exempt Durable Functions + if binding is generic.GenericBinding: + return (getattr(binding, 'has_implicit_output', lambda: False) + (bind_name)) + + else: + # If the binding does not have metaclass of meta.InConverter + # The implicit_output does not exist + return getattr(binding, 'has_implicit_output', lambda: False)() def from_incoming_proto( diff --git a/azure_functions_worker/functions.py b/azure_functions_worker/functions.py index 4ad774fff..f0926230c 100644 --- a/azure_functions_worker/functions.py +++ b/azure_functions_worker/functions.py @@ -71,14 +71,20 @@ def get_explicit_and_implicit_return(binding_name: str, @staticmethod def get_return_binding(binding_name: str, binding_type: str, - return_binding_name: str) -> str: + return_binding_name: str, + explicit_return_val_set: bool) \ + -> typing.Tuple[str, bool]: + # prioritize explicit return value + if explicit_return_val_set: + return return_binding_name, explicit_return_val_set if binding_name == "$return": return_binding_name = binding_type assert return_binding_name is not None + explicit_return_val_set = True elif bindings_utils.has_implicit_output(binding_type): return_binding_name = binding_type - return return_binding_name + return return_binding_name, explicit_return_val_set @staticmethod def validate_binding_direction(binding_name: str, @@ -314,6 +320,7 @@ def add_function(self, function_id: str, params = dict(sig.parameters) annotations = typing.get_type_hints(func) return_binding_name: typing.Optional[str] = None + explicit_return_val_set = False has_explicit_return = False has_implicit_return = False @@ -327,9 +334,11 @@ def add_function(self, function_id: str, binding_name, binding_info, has_explicit_return, has_implicit_return, bound_params) - return_binding_name = self.get_return_binding(binding_name, - binding_info.type, - return_binding_name) + return_binding_name, explicit_return_val_set = \ + self.get_return_binding(binding_name, + binding_info.type, + return_binding_name, + explicit_return_val_set) requires_context = self.is_context_required(params, bound_params, annotations, @@ -362,6 +371,7 @@ def add_indexed_function(self, function): function_id = str(uuid.uuid5(namespace=uuid.NAMESPACE_OID, name=func_name)) return_binding_name: typing.Optional[str] = None + explicit_return_val_set = False has_explicit_return = False has_implicit_return = False @@ -381,9 +391,11 @@ def add_indexed_function(self, function): binding.name, binding, has_explicit_return, has_implicit_return, bound_params) - return_binding_name = self.get_return_binding(binding.name, - binding.type, - return_binding_name) + return_binding_name, explicit_return_val_set = \ + self.get_return_binding(binding.name, + binding.type, + return_binding_name, + explicit_return_val_set) requires_context = self.is_context_required(params, bound_params, annotations, diff --git a/tests/endtoend/generic_functions/generic_functions_stein/function_app.py b/tests/endtoend/generic_functions/generic_functions_stein/function_app.py new file mode 100644 index 000000000..47c74f862 --- /dev/null +++ b/tests/endtoend/generic_functions/generic_functions_stein/function_app.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import azure.functions as func + +app = func.FunctionApp(http_auth_level=func.AuthLevel.ANONYMOUS) + + +@app.function_name(name="return_processed_last") +@app.generic_trigger(arg_name="req", type="httpTrigger", + route="return_processed_last") +@app.generic_output_binding(arg_name="$return", type="http") +@app.generic_input_binding( + arg_name="testEntity", + type="table", + connection="AzureWebJobsStorage", + table_name="EventHubBatchTest") +def return_processed_last(req: func.HttpRequest, testEntity): + return func.HttpResponse(status_code=200) + + +@app.function_name(name="return_not_processed_last") +@app.generic_trigger(arg_name="req", type="httpTrigger", + route="return_not_processed_last") +@app.generic_output_binding(arg_name="$return", type="http") +@app.generic_input_binding( + arg_name="testEntities", + type="table", + connection="AzureWebJobsStorage", + table_name="EventHubBatchTest") +def return_not_processed_last(req: func.HttpRequest, testEntities): + return func.HttpResponse(status_code=200) diff --git a/tests/endtoend/generic_functions/return_not_processed_last/__init__.py b/tests/endtoend/generic_functions/return_not_processed_last/__init__.py new file mode 100644 index 000000000..300fae398 --- /dev/null +++ b/tests/endtoend/generic_functions/return_not_processed_last/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import azure.functions as func + + +# There are 3 bindings defined in function.json: +# 1. req: HTTP trigger +# 2. testEntities: table input (generic) +# 3. $return: HTTP response +# The bindings will be processed by the worker in this order: +# req -> $return -> testEntities +def main(req: func.HttpRequest, testEntities): + return func.HttpResponse(status_code=200) diff --git a/tests/endtoend/generic_functions/return_not_processed_last/function.json b/tests/endtoend/generic_functions/return_not_processed_last/function.json new file mode 100644 index 000000000..66d1e80e1 --- /dev/null +++ b/tests/endtoend/generic_functions/return_not_processed_last/function.json @@ -0,0 +1,26 @@ +{ + "scriptFile": "__init__.py", + "bindings": [ + { + "type": "httpTrigger", + "direction": "in", + "authLevel": "anonymous", + "methods": [ + "get" + ], + "name": "req" + }, + { + "direction": "in", + "type": "table", + "name": "testEntities", + "tableName": "EventHubBatchTest", + "connection": "AzureWebJobsStorage" + }, + { + "type": "http", + "direction": "out", + "name": "$return" + } + ] +} \ No newline at end of file diff --git a/tests/endtoend/generic_functions/return_processed_last/__init__.py b/tests/endtoend/generic_functions/return_processed_last/__init__.py new file mode 100644 index 000000000..3d8f56122 --- /dev/null +++ b/tests/endtoend/generic_functions/return_processed_last/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import azure.functions as func + + +# There are 3 bindings defined in function.json: +# 1. req: HTTP trigger +# 2. testEntity: table input (generic) +# 3. $return: HTTP response +# The bindings will be processed by the worker in this order: +# req -> testEntity -> $return +def main(req: func.HttpRequest, testEntity): + return func.HttpResponse(status_code=200) diff --git a/tests/endtoend/generic_functions/return_processed_last/function.json b/tests/endtoend/generic_functions/return_processed_last/function.json new file mode 100644 index 000000000..82ac266a6 --- /dev/null +++ b/tests/endtoend/generic_functions/return_processed_last/function.json @@ -0,0 +1,26 @@ +{ + "scriptFile": "__init__.py", + "bindings": [ + { + "type": "httpTrigger", + "direction": "in", + "authLevel": "anonymous", + "methods": [ + "get" + ], + "name": "req" + }, + { + "direction": "in", + "type": "table", + "name": "testEntity", + "tableName": "EventHubBatchTest", + "connection": "AzureWebJobsStorage" + }, + { + "type": "http", + "direction": "out", + "name": "$return" + } + ] +} \ No newline at end of file diff --git a/tests/endtoend/test_generic_functions.py b/tests/endtoend/test_generic_functions.py new file mode 100644 index 000000000..8be60f669 --- /dev/null +++ b/tests/endtoend/test_generic_functions.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from unittest import skipIf + +from azure_functions_worker.utils.common import is_envvar_true +from tests.utils import testutils +from tests.utils.constants import DEDICATED_DOCKER_TEST, CONSUMPTION_DOCKER_TEST + + +@skipIf(is_envvar_true(DEDICATED_DOCKER_TEST) + or is_envvar_true(CONSUMPTION_DOCKER_TEST), + "Table functions which are used in the bindings in these tests" + " has a bug with the table extension 1.0.0. " + "https://github.com/Azure/azure-sdk-for-net/issues/33902.") +class TestGenericFunctions(testutils.WebHostTestCase): + """Test Generic Functions with implicit output enabled + + With implicit output enabled for generic types, these tests cover + scenarios where a function has both explicit and implicit output + set to true. We prioritize explicit output. These tests check + that no matter the ordering, the return type is still correctly set. + """ + + @classmethod + def get_script_dir(cls): + return testutils.E2E_TESTS_FOLDER / 'generic_functions' + + def test_return_processed_last(self): + # Tests the case where implicit and explicit return are true + # in the same function and $return is processed before + # the generic binding is + + r = self.webhost.request('GET', 'return_processed_last') + self.assertEqual(r.status_code, 200) + + def test_return_not_processed_last(self): + # Tests the case where implicit and explicit return are true + # in the same function and the generic binding is processed + # before $return + + r = self.webhost.request('GET', 'return_not_processed_last') + self.assertEqual(r.status_code, 200) + + +@skipIf(is_envvar_true(DEDICATED_DOCKER_TEST) + or is_envvar_true(CONSUMPTION_DOCKER_TEST), + "Table functions has a bug with the table extension 1.0.0." + "https://github.com/Azure/azure-sdk-for-net/issues/33902.") +class TestGenericFunctionsStein(TestGenericFunctions): + + @classmethod + def get_script_dir(cls): + return testutils.E2E_TESTS_FOLDER / 'generic_functions' / \ + 'generic_functions_stein' diff --git a/tests/unittests/generic_functions/foobar_implicit_output_exemption/function.json b/tests/unittests/generic_functions/foobar_implicit_output_exemption/function.json new file mode 100644 index 000000000..82a015bbb --- /dev/null +++ b/tests/unittests/generic_functions/foobar_implicit_output_exemption/function.json @@ -0,0 +1,12 @@ +{ + "scriptFile": "main.py", + "bindings": [ + { + "type": "durableClient", + "name": "input", + "direction": "in", + "dataType": "string" + } + ] + } + \ No newline at end of file diff --git a/tests/unittests/generic_functions/foobar_implicit_output_exemption/main.py b/tests/unittests/generic_functions/foobar_implicit_output_exemption/main.py new file mode 100644 index 000000000..53124993e --- /dev/null +++ b/tests/unittests/generic_functions/foobar_implicit_output_exemption/main.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# Input as string, without annotation + + +def main(input: str): + return input diff --git a/tests/unittests/test_code_quality.py b/tests/unittests/test_code_quality.py index 4ac19bd52..40302ea3e 100644 --- a/tests/unittests/test_code_quality.py +++ b/tests/unittests/test_code_quality.py @@ -36,7 +36,7 @@ def test_flake8(self): try: import flake8 # NoQA except ImportError as e: - raise unittest.SkipTest('flake8 moudule is missing') from e + raise unittest.SkipTest('flake8 module is missing') from e config_path = ROOT_PATH / '.flake8' if not config_path.exists(): diff --git a/tests/unittests/test_mock_generic_functions.py b/tests/unittests/test_mock_generic_functions.py index 32004850f..238837d89 100644 --- a/tests/unittests/test_mock_generic_functions.py +++ b/tests/unittests/test_mock_generic_functions.py @@ -119,7 +119,7 @@ async def test_mock_generic_as_bytes_no_anno(self): protos.TypedData(bytes=b'\x00\x01') ) - async def test_mock_generic_should_not_support_implicit_output(self): + async def test_mock_generic_should_support_implicit_output(self): async with testutils.start_mockhost( script_root=self.generic_funcs_dir) as host: @@ -131,7 +131,7 @@ async def test_mock_generic_should_not_support_implicit_output(self): protos.StatusResult.Success) _, r = await host.invoke_function( - 'foobar_as_bytes_no_anno', [ + 'foobar_implicit_output', [ protos.ParameterBinding( name='input', data=protos.TypedData( @@ -140,10 +140,10 @@ async def test_mock_generic_should_not_support_implicit_output(self): ) ] ) - # It should fail here, since generic binding requires - # $return statement in function.json to pass output + # It passes now as we are enabling generic binding to return output + # implicitly self.assertEqual(r.response.result.status, - protos.StatusResult.Failure) + protos.StatusResult.Success) async def test_mock_generic_should_support_without_datatype(self): async with testutils.start_mockhost( @@ -166,7 +166,32 @@ async def test_mock_generic_should_support_without_datatype(self): ) ] ) - # It should fail here, since the generic binding requires datatype - # to be defined in function.json + # It passes now as we are enabling generic binding to return output + # implicitly + self.assertEqual(r.response.result.status, + protos.StatusResult.Success) + + async def test_mock_generic_implicit_output_exemption(self): + async with testutils.start_mockhost( + script_root=self.generic_funcs_dir) as host: + await host.init_worker("4.17.1") + func_id, r = await host.load_function( + 'foobar_implicit_output_exemption') + self.assertEqual(r.response.function_id, func_id) + self.assertEqual(r.response.result.status, + protos.StatusResult.Success) + + _, r = await host.invoke_function( + 'foobar_implicit_output_exemption', [ + protos.ParameterBinding( + name='input', + data=protos.TypedData( + bytes=b'\x00\x01' + ) + ) + ] + ) + # It should fail here, since implicit output is False + # For the Durable Functions durableClient case self.assertEqual(r.response.result.status, protos.StatusResult.Failure)