Skip to content

Commit 838368c

Browse files
Add missing context kwarg to _sentry_task_factory (#2267)
* Add missing context kwargs to _sentry_task_factory * Forward context to Task * Update _sentry_task_factory type comment * Added type annotations and unit tests * Suppress linter error * Fix import error in old Python versions * Fix again linter error * Fixed all mypy errors for real * Fix tests for Python 3.7 * Add pytest.mark.forked to prevent threading test failure --------- Co-authored-by: Daniel Szoke <[email protected]> Co-authored-by: Daniel Szoke <[email protected]>
1 parent 6f49e75 commit 838368c

File tree

2 files changed

+205
-6
lines changed

2 files changed

+205
-6
lines changed

sentry_sdk/integrations/asyncio.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
if TYPE_CHECKING:
1919
from typing import Any
20+
from collections.abc import Coroutine
2021

2122
from sentry_sdk._types import ExcInfo
2223

@@ -37,8 +38,8 @@ def patch_asyncio():
3738
loop = asyncio.get_running_loop()
3839
orig_task_factory = loop.get_task_factory()
3940

40-
def _sentry_task_factory(loop, coro):
41-
# type: (Any, Any) -> Any
41+
def _sentry_task_factory(loop, coro, **kwargs):
42+
# type: (asyncio.AbstractEventLoop, Coroutine[Any, Any, Any], Any) -> asyncio.Future[Any]
4243

4344
async def _coro_creating_hub_and_span():
4445
# type: () -> Any
@@ -56,7 +57,7 @@ async def _coro_creating_hub_and_span():
5657

5758
# Trying to use user set task factory (if there is one)
5859
if orig_task_factory:
59-
return orig_task_factory(loop, _coro_creating_hub_and_span())
60+
return orig_task_factory(loop, _coro_creating_hub_and_span(), **kwargs)
6061

6162
# The default task factory in `asyncio` does not have its own function
6263
# but is just a couple of lines in `asyncio.base_events.create_task()`
@@ -65,13 +66,13 @@ async def _coro_creating_hub_and_span():
6566
# WARNING:
6667
# If the default behavior of the task creation in asyncio changes,
6768
# this will break!
68-
task = Task(_coro_creating_hub_and_span(), loop=loop)
69+
task = Task(_coro_creating_hub_and_span(), loop=loop, **kwargs)
6970
if task._source_traceback: # type: ignore
7071
del task._source_traceback[-1] # type: ignore
7172

7273
return task
7374

74-
loop.set_task_factory(_sentry_task_factory)
75+
loop.set_task_factory(_sentry_task_factory) # type: ignore
7576
except RuntimeError:
7677
# When there is no running loop, we have nothing to patch.
7778
pass

tests/integrations/asyncio/test_asyncio_py3.py

+199-1
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,35 @@
11
import asyncio
2+
import inspect
23
import sys
34

45
import pytest
56

67
import sentry_sdk
78
from sentry_sdk.consts import OP
8-
from sentry_sdk.integrations.asyncio import AsyncioIntegration
9+
from sentry_sdk.integrations.asyncio import AsyncioIntegration, patch_asyncio
10+
11+
try:
12+
from unittest.mock import MagicMock, patch
13+
except ImportError:
14+
from mock import MagicMock, patch
15+
16+
try:
17+
from contextvars import Context, ContextVar
18+
except ImportError:
19+
pass # All tests will be skipped with incompatible versions
920

1021

1122
minimum_python_37 = pytest.mark.skipif(
1223
sys.version_info < (3, 7), reason="Asyncio tests need Python >= 3.7"
1324
)
1425

1526

27+
minimum_python_311 = pytest.mark.skipif(
28+
sys.version_info < (3, 11),
29+
reason="Asyncio task context parameter was introduced in Python 3.11",
30+
)
31+
32+
1633
async def foo():
1734
await asyncio.sleep(0.01)
1835

@@ -33,6 +50,17 @@ def event_loop(request):
3350
loop.close()
3451

3552

53+
def get_sentry_task_factory(mock_get_running_loop):
54+
"""
55+
Patches (mocked) asyncio and gets the sentry_task_factory.
56+
"""
57+
mock_loop = mock_get_running_loop.return_value
58+
patch_asyncio()
59+
patched_factory = mock_loop.set_task_factory.call_args[0][0]
60+
61+
return patched_factory
62+
63+
3664
@minimum_python_37
3765
@pytest.mark.asyncio
3866
async def test_create_task(
@@ -170,3 +198,173 @@ async def add(a, b):
170198

171199
result = await asyncio.create_task(add(1, 2))
172200
assert result == 3, result
201+
202+
203+
@minimum_python_311
204+
@pytest.mark.asyncio
205+
async def test_task_with_context(sentry_init):
206+
"""
207+
Integration test to ensure working context parameter in Python 3.11+
208+
"""
209+
sentry_init(
210+
integrations=[
211+
AsyncioIntegration(),
212+
],
213+
)
214+
215+
var = ContextVar("var")
216+
var.set("original value")
217+
218+
async def change_value():
219+
var.set("changed value")
220+
221+
async def retrieve_value():
222+
return var.get()
223+
224+
# Create a context and run both tasks within the context
225+
ctx = Context()
226+
async with asyncio.TaskGroup() as tg:
227+
tg.create_task(change_value(), context=ctx)
228+
retrieve_task = tg.create_task(retrieve_value(), context=ctx)
229+
230+
assert retrieve_task.result() == "changed value"
231+
232+
233+
@minimum_python_37
234+
@patch("asyncio.get_running_loop")
235+
def test_patch_asyncio(mock_get_running_loop):
236+
"""
237+
Test that the patch_asyncio function will patch the task factory.
238+
"""
239+
mock_loop = mock_get_running_loop.return_value
240+
241+
patch_asyncio()
242+
243+
assert mock_loop.set_task_factory.called
244+
245+
set_task_factory_args, _ = mock_loop.set_task_factory.call_args
246+
assert len(set_task_factory_args) == 1
247+
248+
sentry_task_factory, *_ = set_task_factory_args
249+
assert callable(sentry_task_factory)
250+
251+
252+
@minimum_python_37
253+
@pytest.mark.forked
254+
@patch("asyncio.get_running_loop")
255+
@patch("sentry_sdk.integrations.asyncio.Task")
256+
def test_sentry_task_factory_no_factory(MockTask, mock_get_running_loop): # noqa: N803
257+
mock_loop = mock_get_running_loop.return_value
258+
mock_coro = MagicMock()
259+
260+
# Set the original task factory to None
261+
mock_loop.get_task_factory.return_value = None
262+
263+
# Retieve sentry task factory (since it is an inner function within patch_asyncio)
264+
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)
265+
266+
# The call we are testing
267+
ret_val = sentry_task_factory(mock_loop, mock_coro)
268+
269+
assert MockTask.called
270+
assert ret_val == MockTask.return_value
271+
272+
task_args, task_kwargs = MockTask.call_args
273+
assert len(task_args) == 1
274+
275+
coro_param, *_ = task_args
276+
assert inspect.iscoroutine(coro_param)
277+
278+
assert "loop" in task_kwargs
279+
assert task_kwargs["loop"] == mock_loop
280+
281+
282+
@minimum_python_37
283+
@pytest.mark.forked
284+
@patch("asyncio.get_running_loop")
285+
def test_sentry_task_factory_with_factory(mock_get_running_loop):
286+
mock_loop = mock_get_running_loop.return_value
287+
mock_coro = MagicMock()
288+
289+
# The original task factory will be mocked out here, let's retrieve the value for later
290+
orig_task_factory = mock_loop.get_task_factory.return_value
291+
292+
# Retieve sentry task factory (since it is an inner function within patch_asyncio)
293+
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)
294+
295+
# The call we are testing
296+
ret_val = sentry_task_factory(mock_loop, mock_coro)
297+
298+
assert orig_task_factory.called
299+
assert ret_val == orig_task_factory.return_value
300+
301+
task_factory_args, _ = orig_task_factory.call_args
302+
assert len(task_factory_args) == 2
303+
304+
loop_arg, coro_arg = task_factory_args
305+
assert loop_arg == mock_loop
306+
assert inspect.iscoroutine(coro_arg)
307+
308+
309+
@minimum_python_311
310+
@patch("asyncio.get_running_loop")
311+
@patch("sentry_sdk.integrations.asyncio.Task")
312+
def test_sentry_task_factory_context_no_factory(
313+
MockTask, mock_get_running_loop # noqa: N803
314+
):
315+
mock_loop = mock_get_running_loop.return_value
316+
mock_coro = MagicMock()
317+
mock_context = MagicMock()
318+
319+
# Set the original task factory to None
320+
mock_loop.get_task_factory.return_value = None
321+
322+
# Retieve sentry task factory (since it is an inner function within patch_asyncio)
323+
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)
324+
325+
# The call we are testing
326+
ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context)
327+
328+
assert MockTask.called
329+
assert ret_val == MockTask.return_value
330+
331+
task_args, task_kwargs = MockTask.call_args
332+
assert len(task_args) == 1
333+
334+
coro_param, *_ = task_args
335+
assert inspect.iscoroutine(coro_param)
336+
337+
assert "loop" in task_kwargs
338+
assert task_kwargs["loop"] == mock_loop
339+
assert "context" in task_kwargs
340+
assert task_kwargs["context"] == mock_context
341+
342+
343+
@minimum_python_311
344+
@patch("asyncio.get_running_loop")
345+
def test_sentry_task_factory_context_with_factory(mock_get_running_loop):
346+
mock_loop = mock_get_running_loop.return_value
347+
mock_coro = MagicMock()
348+
mock_context = MagicMock()
349+
350+
# The original task factory will be mocked out here, let's retrieve the value for later
351+
orig_task_factory = mock_loop.get_task_factory.return_value
352+
353+
# Retieve sentry task factory (since it is an inner function within patch_asyncio)
354+
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)
355+
356+
# The call we are testing
357+
ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context)
358+
359+
assert orig_task_factory.called
360+
assert ret_val == orig_task_factory.return_value
361+
362+
task_factory_args, task_factory_kwargs = orig_task_factory.call_args
363+
assert len(task_factory_args) == 2
364+
365+
loop_arg, coro_arg = task_factory_args
366+
assert loop_arg == mock_loop
367+
assert inspect.iscoroutine(coro_arg)
368+
369+
assert "context" in task_factory_kwargs
370+
assert task_factory_kwargs["context"] == mock_context

0 commit comments

Comments
 (0)