Skip to content

Use importlib #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,6 @@ dmypy.json

# Cython debug symbols
cython_debug/

# Test files generated
tmp*.py
34 changes: 9 additions & 25 deletions awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,19 @@
Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
"""

import importlib
import json
import logging
import os
import sys
import time
import traceback
import warnings

from .lambda_context import LambdaContext
from .lambda_runtime_client import LambdaRuntimeClient
from .lambda_runtime_exception import FaultException
from .lambda_runtime_marshaller import to_json

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
import imp

ERROR_LOG_LINE_TERMINATE = "\r"
ERROR_LOG_IDENT = "\u00a0" # NO-BREAK SPACE U+00A0

Expand All @@ -33,23 +29,14 @@ def _get_handler(handler):
)
return make_fault_handler(fault)

file_handle, pathname, desc = None, None, None
try:
# Recursively loading handler in nested directories
for segment in modname.split("."):
if pathname is not None:
pathname = [pathname]
file_handle, pathname, desc = imp.find_module(segment, pathname)
if file_handle is None:
module_type = desc[2]
if module_type == imp.C_BUILTIN:
fault = FaultException(
FaultException.BUILT_IN_MODULE_CONFLICT,
"Cannot use built-in module {} as a handler module".format(modname),
)
request_handler = make_fault_handler(fault)
return request_handler
m = imp.load_module(modname, file_handle, pathname, desc)
if modname.split(".")[0] in sys.builtin_module_names:
fault = FaultException(
FaultException.BUILT_IN_MODULE_CONFLICT,
"Cannot use built-in module {} as a handler module".format(modname),
)
return make_fault_handler(fault)
m = importlib.import_module(modname)
except ImportError as e:
fault = FaultException(
FaultException.IMPORT_MODULE_ERROR,
Expand All @@ -66,9 +53,6 @@ def _get_handler(handler):
)
request_handler = make_fault_handler(fault)
return request_handler
finally:
if file_handle is not None:
file_handle.close()

try:
request_handler = getattr(m, fname)
Expand Down Expand Up @@ -402,7 +386,7 @@ def run(app_root, handler, lambda_runtime_api_addr):
global _GLOBAL_AWS_REQUEST_ID

request_handler = _get_handler(handler)
except Exception as e:
except Exception:
error_result = build_fault_result(sys.exc_info(), None)

log_error(error_result, log_sink)
Expand Down
2 changes: 1 addition & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ coverage>=4.4.0
flake8>=3.3.0
tox>=2.2.1
pytest-cov>=2.4.0
pylint>=1.7.2,<2.0
pylint>=1.7.2
black>=20.8b0
bandit>=1.6.2

Expand Down
160 changes: 84 additions & 76 deletions tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
"""

import importlib
import json
import os
import re
import tempfile
import traceback
import unittest
from imp import C_BUILTIN
from io import StringIO
from tempfile import NamedTemporaryFile
from unittest.mock import patch, Mock, MagicMock
from unittest.mock import MagicMock, Mock, patch

import awslambdaric.bootstrap as bootstrap
from awslambdaric.lambda_runtime_exception import FaultException
Expand Down Expand Up @@ -350,7 +350,7 @@ def __init__(self, message):

def test_handle_event_request_no_module(self):
def unable_to_import_module(json_input, lambda_context):
import invalid_module
import invalid_module # noqa: F401

expected_response = {
"errorType": "ModuleNotFoundError",
Expand Down Expand Up @@ -381,8 +381,8 @@ def unable_to_import_module(json_input, lambda_context):
def test_handle_event_request_fault_exception(self):
def raise_exception_handler(json_input, lambda_context):
try:
import invalid_module
except ImportError as e:
import invalid_module # noqa: F401
except ImportError:
raise FaultException(
"FaultExceptionType",
"Fault exception msg",
Expand Down Expand Up @@ -429,8 +429,8 @@ def raise_exception_handler(json_input, lambda_context):
def test_handle_event_request_fault_exception_logging(self, mock_stdout):
def raise_exception_handler(json_input, lambda_context):
try:
import invalid_module
except ImportError as e:
import invalid_module # noqa: F401
except ImportError:
raise bootstrap.FaultException(
"FaultExceptionType",
"Fault exception msg",
Expand Down Expand Up @@ -469,8 +469,8 @@ def raise_exception_handler(json_input, lambda_context):
def test_handle_event_request_fault_exception_logging_notrace(self, mock_stdout):
def raise_exception_handler(json_input, lambda_context):
try:
import invalid_module
except ImportError as e:
import invalid_module # noqa: F401
except ImportError:
raise bootstrap.FaultException(
"FaultExceptionType", "Fault exception msg", None
)
Expand All @@ -497,8 +497,8 @@ def test_handle_event_request_fault_exception_logging_nomessage_notrace(
):
def raise_exception_handler(json_input, lambda_context):
try:
import invalid_module
except ImportError as e:
import invalid_module # noqa: F401
except ImportError:
raise bootstrap.FaultException("FaultExceptionType", None, None)

bootstrap.handle_event_request(
Expand All @@ -523,8 +523,8 @@ def test_handle_event_request_fault_exception_logging_notype_notrace(
):
def raise_exception_handler(json_input, lambda_context):
try:
import invalid_module
except ImportError as e:
import invalid_module # noqa: F401
except ImportError:
raise bootstrap.FaultException(None, "Fault exception msg", None)

bootstrap.handle_event_request(
Expand All @@ -549,8 +549,8 @@ def test_handle_event_request_fault_exception_logging_notype_nomessage(
):
def raise_exception_handler(json_input, lambda_context):
try:
import invalid_module
except ImportError as e:
import invalid_module # noqa: F401
except ImportError:
raise bootstrap.FaultException(
None,
None,
Expand Down Expand Up @@ -585,19 +585,16 @@ def raise_exception_handler(json_input, lambda_context):
self.assertEqual(mock_stdout.getvalue(), error_logs)

@patch("sys.stdout", new_callable=StringIO)
@patch("imp.find_module")
@patch("imp.load_module")
@patch("importlib.import_module")
def test_handle_event_request_fault_exception_logging_syntax_error(
self, mock_load_module, mock_find_module, mock_stdout
self, mock_import_module, mock_stdout
):

try:
eval("-")
except SyntaxError as e:
syntax_error = e

mock_find_module.return_value = (None, None, ("", "", None))
mock_load_module.side_effect = syntax_error
mock_import_module.side_effect = syntax_error

response_handler = bootstrap._get_handler("a.b")

Expand All @@ -618,7 +615,10 @@ def test_handle_event_request_fault_exception_logging_syntax_error(

sys.stderr.write(mock_stdout.getvalue())

error_logs = "[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': unexpected EOF while parsing (<string>, line 1)\r"
error_logs = (
"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': "
"unexpected EOF while parsing (<string>, line 1)\r"
)
error_logs += "Traceback (most recent call last):\r"
error_logs += '  File "<string>" Line 1\r'
error_logs += "    -\n"
Expand Down Expand Up @@ -730,56 +730,57 @@ def test_get_event_handler_import_error(self):
)

def test_get_event_handler_syntax_error(self):
tmp_file = tempfile.NamedTemporaryFile(suffix=".py", dir=".", delete=False)
tmp_file.write(
b"def syntax_error()\n\tprint('syntax error, no colon after function')"
)
tmp_file.close()
filename_w_ext = os.path.basename(tmp_file.name)
filename, _ = os.path.splitext(filename_w_ext)
handler_name = "{}.syntax_error".format(filename)
response_handler = bootstrap._get_handler(handler_name)

with self.assertRaises(FaultException) as cm:
response_handler()
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
"Syntax error in",
"Runtime.UserCodeSyntaxError",
".*File.*\\.py.*Line 1.*",
),
returned_exception,
)
if os.path.exists(tmp_file.name):
os.remove(tmp_file.name)
importlib.invalidate_caches()
with tempfile.NamedTemporaryFile(
suffix=".py", dir=".", delete=False
) as tmp_file:
tmp_file.write(
b"def syntax_error()\n\tprint('syntax error, no colon after function')"
)
tmp_file.flush()

filename_w_ext = os.path.basename(tmp_file.name)
filename, _ = os.path.splitext(filename_w_ext)
handler_name = "{}.syntax_error".format(filename)
response_handler = bootstrap._get_handler(handler_name)

with self.assertRaises(FaultException) as cm:
response_handler()
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
"Syntax error in",
"Runtime.UserCodeSyntaxError",
".*File.*\\.py.*Line 1.*",
),
returned_exception,
)

def test_get_event_handler_missing_error(self):
tmp_file = tempfile.NamedTemporaryFile(suffix=".py", dir=".", delete=False)
tmp_file.write(b"def wrong_handler_name():\n\tprint('hello')")
tmp_file.close()
filename_w_ext = os.path.basename(tmp_file.name)
filename, _ = os.path.splitext(filename_w_ext)
handler_name = "{}.my_handler".format(filename)
response_handler = bootstrap._get_handler(handler_name)
with self.assertRaises(FaultException) as cm:
response_handler()
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
"Handler 'my_handler' missing on module '{}'".format(filename),
"Runtime.HandlerNotFound",
),
returned_exception,
)
if os.path.exists(tmp_file.name):
os.remove(tmp_file.name)
importlib.invalidate_caches()
with tempfile.NamedTemporaryFile(
suffix=".py", dir=".", delete=False
) as tmp_file:
tmp_file.write(b"def wrong_handler_name():\n\tprint('hello')")
tmp_file.flush()

filename_w_ext = os.path.basename(tmp_file.name)
filename, _ = os.path.splitext(filename_w_ext)
handler_name = "{}.my_handler".format(filename)
response_handler = bootstrap._get_handler(handler_name)
with self.assertRaises(FaultException) as cm:
response_handler()
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
"Handler 'my_handler' missing on module '{}'".format(filename),
"Runtime.HandlerNotFound",
),
returned_exception,
)

@patch("imp.find_module")
def test_get_event_handler_build_in_conflict(self, mock_find_module):
handler_name = "sys.hello"
mock_find_module.return_value = (None, None, ("", "", C_BUILTIN))
response_handler = bootstrap._get_handler(handler_name)
def test_get_event_handler_build_in_conflict(self):
response_handler = bootstrap._get_handler("sys.hello")
with self.assertRaises(FaultException) as cm:
response_handler()
returned_exception = cm.exception
Expand Down Expand Up @@ -921,7 +922,10 @@ def test_log_error_indentation_standard_log_sink(self, mock_stdout):
)
bootstrap.log_error(err_to_log, bootstrap.StandardLogSink())

expected_logged_error = "[ERROR] ErrorType: Error message\rTraceback (most recent call last):\r\xa0\xa0line1 \r\xa0\xa0line2 \r\xa0\xa0\n"
expected_logged_error = (
"[ERROR] ErrorType: Error message\rTraceback (most recent call last):"
"\r\xa0\xa0line1 \r\xa0\xa0line2 \r\xa0\xa0\n"
)
self.assertEqual(mock_stdout.getvalue(), expected_logged_error)

def test_log_error_indentation_framed_log_sink(self):
Expand All @@ -932,7 +936,10 @@ def test_log_error_indentation_framed_log_sink(self):
)
bootstrap.log_error(err_to_log, log_sink)

expected_logged_error = "[ERROR] ErrorType: Error message\nTraceback (most recent call last):\n\xa0\xa0line1 \n\xa0\xa0line2 \n\xa0\xa0"
expected_logged_error = (
"[ERROR] ErrorType: Error message\nTraceback (most recent call last):"
"\n\xa0\xa0line1 \n\xa0\xa0line2 \n\xa0\xa0"
)

with open(temp_file.name, "rb") as f:
content = f.read()
Expand Down Expand Up @@ -964,7 +971,10 @@ def test_log_error_empty_stacktrace_line_framed_log_sink(self):
)
bootstrap.log_error(err_to_log, log_sink)

expected_logged_error = "[ERROR] ErrorType: Error message\nTraceback (most recent call last):\nline1\n\nline2"
expected_logged_error = (
"[ERROR] ErrorType: Error message\nTraceback "
"(most recent call last):\nline1\n\nline2"
)

with open(temp_file.name, "rb") as f:
content = f.read()
Expand Down Expand Up @@ -1082,11 +1092,10 @@ def test_run(self, mock_runtime_client, mock_handle_event_request):
MagicMock(),
]

with self.assertRaises(TypeError) as cm:
with self.assertRaises(TypeError):
bootstrap.run(
expected_app_root, expected_handler, expected_lambda_runtime_api_addr
)
returned_exception = cm.exception

mock_handle_event_request.assert_called_once()

Expand All @@ -1108,11 +1117,10 @@ class TestException(Exception):

mock_sys.exit.side_effect = TestException("Boom!")

with self.assertRaises(TestException) as cm:
with self.assertRaises(TestException):
bootstrap.run(
expected_app_root, expected_handler, expected_lambda_runtime_api_addr
)
returned_exception = cm.exception

mock_sys.exit.assert_called_once_with(1)

Expand Down