Skip to content

GGV2 client: add exception logging so that errors are always visible,… #277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion awsiot/greengrasscoreipc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,7 @@ class GreengrassCoreIPCClient(rpc.Client):
"""
Client for the GreengrassCoreIPC service.
There is a new V2 client available for testing in developer preview.
See the GreengrassCoreIPCClientV2 class.
See the GreengrassCoreIPCClientV2 class in the clientv2 subpackage.

Args:
connection: Connection that this client will use.
Expand Down
20 changes: 19 additions & 1 deletion awsiot/greengrasscoreipc/clientv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, client: typing.Optional[GreengrassCoreIPCClient] = None,
import awsiot.greengrasscoreipc
client = awsiot.greengrasscoreipc.connect()
self.client = client
if executor == True:
if executor is True:
executor = concurrent.futures.ThreadPoolExecutor()
self.executor = executor

Expand Down Expand Up @@ -67,20 +67,38 @@ def callback(*args, **kwargs):
future1.add_done_callback(callback)
return future2

@staticmethod
def __handle_error():
import sys
import traceback
traceback.print_exc(file=sys.stderr)

def __wrap_error(self, func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
self.__handle_error()
raise e
return wrapper

def __create_stream_handler(real_self, operation, on_stream_event, on_stream_error, on_stream_closed):
stream_handler_type = type(operation + 'Handler', (getattr(client, operation + "StreamHandler"),), {})
if on_stream_event is not None:
on_stream_event = real_self.__wrap_error(on_stream_event)
def handler(self, event):
if real_self.executor is not None:
real_self.executor.submit(on_stream_event, event)
else:
on_stream_event(event)
setattr(stream_handler_type, "on_stream_event", handler)
if on_stream_error is not None:
on_stream_error = real_self.__wrap_error(on_stream_error)
def handler(self, error):
return on_stream_error(error)
setattr(stream_handler_type, "on_stream_error", handler)
if on_stream_closed is not None:
on_stream_closed = real_self.__wrap_error(on_stream_closed)
def handler(self):
if real_self.executor is not None:
real_self.executor.submit(on_stream_closed)
Expand Down
14 changes: 14 additions & 0 deletions test/test_ggv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import threading
from unittest import TestCase
from unittest.mock import patch
import io
import contextlib

from awsiot.greengrasscoreipc.client import SubscribeToTopicStreamHandler
from awsiot.greengrasscoreipc.model import CreateLocalDeploymentResponse, SubscribeToTopicResponse, \
Expand Down Expand Up @@ -83,3 +85,15 @@ def on_stream_event(self, event):

self.assertEqual("xyz".encode("utf-8"), subscription_fut.result(TIMEOUT).binary_message.message)
self.assertEqual(threading.get_ident(), thread_id_fut.result(TIMEOUT))

# Verify we nicely print errors in user-provided handler methods
def on_stream_event(r):
raise ValueError("Broken!")

c.subscribe_to_topic(topic="abc", on_stream_event=on_stream_event)
sub_handler = mock_client.new_subscribe_to_topic.call_args[0][0]
f = io.StringIO()
with contextlib.redirect_stderr(f):
self.assertRaises(ValueError, lambda: sub_handler.on_stream_event(
SubscriptionResponseMessage(binary_message=BinaryMessage(message="xyz"))))
self.assertIn("ValueError: Broken!", f.getvalue())