Skip to content

Commit 5fb42ed

Browse files
author
Alin RADU
committed
Fix @metric_scope for generator and async generator functions
1 parent 4c36304 commit 5fb42ed

File tree

2 files changed

+82
-10
lines changed

2 files changed

+82
-10
lines changed

aws_embedded_metrics/metric_scope/__init__.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,35 +18,70 @@
1818

1919

2020
def metric_scope(fn): # type: ignore
21+
if inspect.isasyncgenfunction(fn):
22+
@wraps(fn)
23+
async def async_gen_wrapper(*args, **kwargs): # type: ignore
24+
logger = create_metrics_logger()
25+
if "metrics" in inspect.signature(fn).parameters:
26+
kwargs["metrics"] = logger
27+
28+
try:
29+
fn_gen = fn(*args, **kwargs)
30+
while True:
31+
result = await fn_gen.__anext__()
32+
await logger.flush()
33+
yield result
34+
except Exception as ex:
35+
await logger.flush()
36+
if not isinstance(ex, StopIteration):
37+
raise
38+
39+
return async_gen_wrapper
40+
41+
elif inspect.isgeneratorfunction(fn):
42+
@wraps(fn)
43+
def gen_wrapper(*args, **kwargs): # type: ignore
44+
logger = create_metrics_logger()
45+
if "metrics" in inspect.signature(fn).parameters:
46+
kwargs["metrics"] = logger
47+
48+
try:
49+
fn_gen = fn(*args, **kwargs)
50+
while True:
51+
result = next(fn_gen)
52+
asyncio.run(logger.flush())
53+
yield result
54+
except Exception as ex:
55+
asyncio.run(logger.flush())
56+
if not isinstance(ex, StopIteration):
57+
raise
2158

22-
if asyncio.iscoroutinefunction(fn):
59+
return gen_wrapper
2360

61+
elif asyncio.iscoroutinefunction(fn):
2462
@wraps(fn)
25-
async def wrapper(*args, **kwargs): # type: ignore
63+
async def async_wrapper(*args, **kwargs): # type: ignore
2664
logger = create_metrics_logger()
2765
if "metrics" in inspect.signature(fn).parameters:
2866
kwargs["metrics"] = logger
67+
2968
try:
3069
return await fn(*args, **kwargs)
31-
except Exception as e:
32-
raise e
3370
finally:
3471
await logger.flush()
3572

36-
return wrapper
37-
else:
73+
return async_wrapper
3874

75+
else:
3976
@wraps(fn)
4077
def wrapper(*args, **kwargs): # type: ignore
4178
logger = create_metrics_logger()
4279
if "metrics" in inspect.signature(fn).parameters:
4380
kwargs["metrics"] = logger
81+
4482
try:
4583
return fn(*args, **kwargs)
46-
except Exception as e:
47-
raise e
4884
finally:
49-
loop = asyncio.get_event_loop()
50-
loop.run_until_complete(logger.flush())
85+
asyncio.run(logger.flush())
5186

5287
return wrapper

tests/metric_scope/test_metric_scope.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,43 @@ def my_handler(metrics):
168168
actual_timestamp_second = int(round(logger.context.meta["Timestamp"] / 1000))
169169
assert expected_timestamp_second == actual_timestamp_second
170170

171+
172+
def test_sync_scope_iterates_generator(mock_logger):
173+
expected_results = [1, 2]
174+
175+
@metric_scope
176+
def my_handler():
177+
yield from expected_results
178+
raise Exception("test exception")
179+
180+
actual_results = []
181+
with pytest.raises(Exception, match="test exception"):
182+
for result in my_handler():
183+
actual_results.append(result)
184+
185+
assert actual_results == expected_results
186+
assert InvocationTracker.invocations == 3
187+
188+
189+
@pytest.mark.asyncio
190+
async def test_async_scope_iterates_async_generator(mock_logger):
191+
expected_results = [1, 2]
192+
193+
@metric_scope
194+
async def my_handler():
195+
for item in expected_results:
196+
yield item
197+
await asyncio.sleep(1)
198+
raise Exception("test exception")
199+
200+
actual_results = []
201+
with pytest.raises(Exception, match="test exception"):
202+
async for result in my_handler():
203+
actual_results.append(result)
204+
205+
assert actual_results == expected_results
206+
assert InvocationTracker.invocations == 3
207+
171208
# Test helpers
172209

173210

0 commit comments

Comments
 (0)