Skip to content

Adding more Typing information as generated using monkeytype #709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 30, 2020
4 changes: 2 additions & 2 deletions azure_functions_worker/_thirdparty/aio_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Backport of asyncio.run() function from Python 3.7.

Source: https://github.com/python/cpython/blob/
bd093355a6aaf2f4ca3ed153e195da57870a55eb/Lib/asyncio/runners.py
bd093355a6aaf2f4ca3ed153e195da57870a55eb/Lib/asyncio/runners.py
"""


Expand All @@ -13,7 +13,7 @@ def get_running_loop():

This function is thread-specific.
"""
loop = asyncio._get_running_loop()
loop = asyncio.events.get_running_loop()
if loop is None:
raise RuntimeError('no running event loop')
return loop
Expand Down
5 changes: 3 additions & 2 deletions azure_functions_worker/bindings/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing

from . import datumdef
from typing import Any, Optional


class GenericBinding:
Expand All @@ -20,8 +21,8 @@ def check_output_type_annotation(cls, pytype: type) -> bool:
return issubclass(pytype, (str, bytes, bytearray))

@classmethod
def encode(cls, obj: typing.Any, *,
expected_type: typing.Optional[type]) -> datumdef.Datum:
def encode(cls, obj: Any, *,
expected_type: Optional[type]) -> datumdef.Datum:
if isinstance(obj, str):
return datumdef.Datum(type='string', value=obj)

Expand Down
4 changes: 2 additions & 2 deletions azure_functions_worker/bindings/out.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

class Out:

def __init__(self):
def __init__(self) -> None:
self.__value = None

def set(self, val):
self.__value = val

def get(self):
def get(self) -> str:
return self.__value
29 changes: 17 additions & 12 deletions azure_functions_worker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from .logging import enable_console_logging, disable_console_logging
from .utils.tracing import marshall_exception_trace
from .utils.wrappers import disable_feature_by
from asyncio.unix_events import _UnixSelectorEventLoop
from logging import LogRecord
from typing import Any, Dict, Optional


class DispatcherMeta(type):
Expand All @@ -46,8 +49,10 @@ class Dispatcher(metaclass=DispatcherMeta):

_GRPC_STOP_RESPONSE = object()

def __init__(self, loop, host, port: int, worker_id: str, request_id: str,
grpc_connect_timeout: float, grpc_max_msg_len: int = -1):
def __init__(self, loop: _UnixSelectorEventLoop, host: str, port: int,
worker_id: str, request_id: str,
grpc_connect_timeout: float,
grpc_max_msg_len: int = -1) -> None:
self._loop = loop
self._host = host
self._port = port
Expand Down Expand Up @@ -78,7 +83,7 @@ def __init__(self, loop, host, port: int, worker_id: str, request_id: str,
name='grpc-thread', target=self.__poll_grpc)

@staticmethod
def load_bindings():
def load_bindings() -> Dict[Any, Any]:
"""Load out-of-tree binding implementations."""
services = {}

Expand All @@ -89,8 +94,8 @@ def load_bindings():
return services

@classmethod
async def connect(cls, host, port, worker_id, request_id,
connect_timeout):
async def connect(cls, host: str, port: int, worker_id: str,
request_id: str, connect_timeout: float):
loop = asyncio.events.get_event_loop()
disp = cls(loop, host, port, worker_id, request_id, connect_timeout)
disp._grpc_thread.start()
Expand Down Expand Up @@ -144,7 +149,7 @@ async def dispatch_forever(self):
self._loop.set_task_factory(self._old_task_factory)
self.stop()

def stop(self):
def stop(self) -> None:
if self._grpc_thread is not None:
self._grpc_resp_queue.put_nowait(self._GRPC_STOP_RESPONSE)
self._grpc_thread.join()
Expand All @@ -154,7 +159,7 @@ def stop(self):
self._sync_call_tp.shutdown()
self._sync_call_tp = None

def on_logging(self, record: logging.LogRecord, formatted_msg: str):
def on_logging(self, record: logging.LogRecord, formatted_msg: str) -> None:
if record.levelno >= logging.CRITICAL:
log_level = protos.RpcLog.Critical
elif record.levelno >= logging.ERROR:
Expand Down Expand Up @@ -196,11 +201,11 @@ def on_logging(self, record: logging.LogRecord, formatted_msg: str):
rpc_log=protos.RpcLog(**log)))

@property
def request_id(self):
def request_id(self) -> str:
return self._request_id

@property
def worker_id(self):
def worker_id(self) -> str:
return self._worker_id

# noinspection PyBroadException
Expand Down Expand Up @@ -524,7 +529,7 @@ def gen(resp_queue):

class AsyncLoggingHandler(logging.Handler):

def emit(self, record):
def emit(self, record: LogRecord) -> None:
# Since we disable console log after gRPC channel is initiated
# We should redirect all the messages into dispatcher
msg = self.format(record)
Expand All @@ -545,11 +550,11 @@ def __init__(self, coro, loop):
if invocation_id is not None:
self.set_azure_invocation_id(invocation_id)

def set_azure_invocation_id(self, invocation_id):
def set_azure_invocation_id(self, invocation_id: str) -> None:
setattr(self, self._AZURE_INVOCATION_ID, invocation_id)


def get_current_invocation_id():
def get_current_invocation_id() -> Optional[str]:
loop = asyncio._get_running_loop()
if loop is not None:
current_task = asyncio.Task.current_task(loop)
Expand Down
6 changes: 3 additions & 3 deletions azure_functions_worker/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class FunctionInfo(typing.NamedTuple):

class FunctionLoadError(RuntimeError):

def __init__(self, function_name, msg):
def __init__(self, function_name: str, msg: str) -> None:
super().__init__(
f'cannot load the {function_name} function: {msg}')

Expand All @@ -41,10 +41,10 @@ class Registry:

_functions: typing.MutableMapping[str, FunctionInfo]

def __init__(self):
def __init__(self) -> None:
self._functions = {}

def get_function(self, function_id: str):
def get_function(self, function_id: str) -> FunctionInfo:
try:
return self._functions[function_id]
except KeyError:
Expand Down
9 changes: 5 additions & 4 deletions azure_functions_worker/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@

from .constants import MODULE_NOT_FOUND_TS_URL
from .utils.wrappers import attach_message_to_exception
from os import PathLike, fspath


_AZURE_NAMESPACE = '__app__'

_submodule_dirs = []


def register_function_dir(path: os.PathLike):
_submodule_dirs.append(os.fspath(path))
def register_function_dir(path: PathLike) -> None:
_submodule_dirs.append(fspath(path))


def install():
def install() -> None:
if _AZURE_NAMESPACE not in sys.modules:
# Create and register the __app__ namespace package.
ns_spec = importlib.machinery.ModuleSpec(_AZURE_NAMESPACE, None)
Expand All @@ -34,7 +35,7 @@ def install():
sys.modules[_AZURE_NAMESPACE] = ns_pkg


def uninstall():
def uninstall() -> None:
pass


Expand Down
6 changes: 3 additions & 3 deletions azure_functions_worker/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,23 @@ def setup(log_level, log_destination):
error_logger.setLevel(getattr(logging, log_level))


def disable_console_logging():
def disable_console_logging() -> None:
if logger and handler:
logger.removeHandler(handler)

if error_logger and error_handler:
error_logger.removeHandler(error_handler)


def enable_console_logging():
def enable_console_logging() -> None:
if logger and handler:
logger.addHandler(handler)

if error_logger and error_handler:
error_logger.addHandler(error_handler)


def is_system_log_category(ctg: str):
def is_system_log_category(ctg: str) -> bool:
return any(
[ctg.lower().startswith(c) for c in (
'azure_functions_worker',
Expand Down
33 changes: 16 additions & 17 deletions azure_functions_worker/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
import json
import logging
import os
import queue
import pathlib
import platform
import queue
import re
import shutil
import socket
import subprocess
Expand All @@ -27,17 +28,15 @@
import typing
import unittest
import uuid
import re

import grpc
import requests

from azure_functions_worker._thirdparty import aio_compat
from . import dispatcher
from . import protos
from .utils.common import is_envvar_true
from .constants import PYAZURE_WEBHOST_DEBUG

from .utils.common import is_envvar_true

PROJECT_ROOT = pathlib.Path(__file__).parent.parent
TESTS_ROOT = PROJECT_ROOT / 'tests'
Expand Down Expand Up @@ -410,10 +409,10 @@ async def communicate(self, message, *, wait_for):
self._in_queue.put_nowait((message, wait_for))
return await self._out_aqueue.get()

async def _start(self):
async def start(self):
self._server.start()

async def _close(self):
async def close(self):
self._in_queue.put_nowait((_MockWebHostServicer._STOP, None))
self._server.stop(1)

Expand Down Expand Up @@ -454,23 +453,23 @@ def __init__(self, scripts_dir):
self._worker = None

async def __aenter__(self):
loop = aio_compat.get_running_loop()
loop = asyncio._get_running_loop()
self._host = _MockWebHost(loop, self._scripts_dir)

await self._host._start()
await self._host.start()

self._worker = await dispatcher.Dispatcher.connect(
'127.0.0.1', self._host._port,
self._host.worker_id, self._host.request_id,
connect_timeout=5.0)
self._worker = await dispatcher. \
Dispatcher.connect('127.0.0.1', self._host._port,
self._host.worker_id,
self._host.request_id, connect_timeout=5.0)

self._worker.load_bindings()

self._worker_task = loop.create_task(self._worker.dispatch_forever())

done, pending = await asyncio.wait(
[self._host._connected_fut, self._worker_task],
return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio. \
wait([self._host._connected_fut, self._worker_task],
return_when=asyncio.FIRST_COMPLETED)

try:
if self._worker_task in done:
Expand All @@ -480,7 +479,7 @@ async def __aenter__(self):
raise RuntimeError('could not start a worker thread')
except Exception:
try:
self._host._close()
await self._host.close()
self._worker.stop()
finally:
raise
Expand All @@ -498,7 +497,7 @@ async def __aexit__(self, *exc):
self._worker_task = None
self._worker = None

await self._host._close()
await self._host.close()
self._host = None


Expand Down
4 changes: 2 additions & 2 deletions azure_functions_worker/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import os


def is_true_like(setting: str):
def is_true_like(setting: str) -> bool:
if setting is None:
return False

return setting.lower().strip() in ['1', 'true', 't', 'yes', 'y']


def is_envvar_true(env_key: str):
def is_envvar_true(env_key: str) -> bool:
if os.getenv(env_key) is None:
return False

Expand Down
22 changes: 9 additions & 13 deletions azure_functions_worker/utils/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.
from typing import List
import traceback
from traceback import StackSummary, extract_tb


def extend_exception_message(exc: Exception, msg: str) -> Exception:
Expand All @@ -16,28 +17,23 @@ def extend_exception_message(exc: Exception, msg: str) -> Exception:


def marshall_exception_trace(exc: Exception) -> str:
stack_summary: traceback.StackSummary = traceback.extract_tb(
exc.__traceback__)
stack_summary: StackSummary = extract_tb(exc.__traceback__)
if isinstance(exc, ModuleNotFoundError):
stack_summary = _marshall_module_not_found_error(stack_summary)
return ''.join(stack_summary.format())


def _marshall_module_not_found_error(
tbss: traceback.StackSummary
) -> traceback.StackSummary:
def _marshall_module_not_found_error(tbss: StackSummary) -> StackSummary:
tbss = _remove_frame_from_stack(tbss, '<frozen importlib._bootstrap>')
tbss = _remove_frame_from_stack(
tbss, '<frozen importlib._bootstrap_external>')
return tbss


def _remove_frame_from_stack(
tbss: traceback.StackSummary,
framename: str
) -> traceback.StackSummary:
filtered_stack_list: List[traceback.FrameSummary] = list(
filter(lambda frame: getattr(frame, 'filename') != framename, tbss))
filtered_stack: traceback.StackSummary = traceback.StackSummary.from_list(
filtered_stack_list)
def _remove_frame_from_stack(tbss: StackSummary,
framename: str) -> StackSummary:
filtered_stack_list: List[traceback.FrameSummary] = \
list(filter(lambda frame: getattr(frame,
'filename') != framename, tbss))
filtered_stack: StackSummary = StackSummary.from_list(filtered_stack_list)
return filtered_stack
Loading