Skip to content

Commit 0fb9606

Browse files
divaltorVlad Vladovantonpirkerszokeasaurusrexsentrivana
authored
feat(celery): Add wrapper for Celery().send_task to support behavior as Task.apply_async (#2377)
--------- Co-authored-by: Vlad Vladov <[email protected]> Co-authored-by: Anton Pirker <[email protected]> Co-authored-by: Daniel Szoke <[email protected]> Co-authored-by: Ivana Kellyer <[email protected]>
1 parent 16d05f4 commit 0fb9606

File tree

3 files changed

+69
-8
lines changed

3 files changed

+69
-8
lines changed

Diff for: sentry_sdk/integrations/celery/__init__.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
try:
4343
from celery import VERSION as CELERY_VERSION # type: ignore
44+
from celery.app.task import Task # type: ignore
4445
from celery.app.trace import task_has_custom
4546
from celery.exceptions import ( # type: ignore
4647
Ignore,
@@ -83,6 +84,7 @@ def setup_once():
8384

8485
_patch_build_tracer()
8586
_patch_task_apply_async()
87+
_patch_celery_send_task()
8688
_patch_worker_exit()
8789
_patch_producer_publish()
8890

@@ -243,7 +245,7 @@ def __exit__(self, exc_type, exc_value, traceback):
243245
return None
244246

245247

246-
def _wrap_apply_async(f):
248+
def _wrap_task_run(f):
247249
# type: (F) -> F
248250
@wraps(f)
249251
@ensure_integration_enabled(CeleryIntegration, f)
@@ -260,14 +262,19 @@ def apply_async(*args, **kwargs):
260262
if not propagate_traces:
261263
return f(*args, **kwargs)
262264

263-
task = args[0]
265+
if isinstance(args[0], Task):
266+
task_name = args[0].name # type: str
267+
elif len(args) > 1 and isinstance(args[1], str):
268+
task_name = args[1]
269+
else:
270+
task_name = "<unknown Celery task>"
264271

265272
task_started_from_beat = sentry_sdk.get_isolation_scope()._name == "celery-beat"
266273

267274
span_mgr = (
268275
sentry_sdk.start_span(
269276
op=OP.QUEUE_SUBMIT_CELERY,
270-
description=task.name,
277+
description=task_name,
271278
origin=CeleryIntegration.origin,
272279
)
273280
if not task_started_from_beat
@@ -437,9 +444,14 @@ def sentry_build_tracer(name, task, *args, **kwargs):
437444

438445
def _patch_task_apply_async():
439446
# type: () -> None
440-
from celery.app.task import Task # type: ignore
447+
Task.apply_async = _wrap_task_run(Task.apply_async)
448+
449+
450+
def _patch_celery_send_task():
451+
# type: () -> None
452+
from celery import Celery
441453

442-
Task.apply_async = _wrap_apply_async(Task.apply_async)
454+
Celery.send_task = _wrap_task_run(Celery.send_task)
443455

444456

445457
def _patch_worker_exit():

Diff for: tests/integrations/celery/test_celery.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sentry_sdk import start_transaction, get_current_span
1111
from sentry_sdk.integrations.celery import (
1212
CeleryIntegration,
13-
_wrap_apply_async,
13+
_wrap_task_run,
1414
)
1515
from sentry_sdk.integrations.celery.beat import _get_headers
1616
from tests.conftest import ApproxDict
@@ -568,7 +568,7 @@ def dummy_function(*args, **kwargs):
568568
assert "sentry-trace" in headers
569569
assert "baggage" in headers
570570

571-
wrapped = _wrap_apply_async(dummy_function)
571+
wrapped = _wrap_task_run(dummy_function)
572572
wrapped(mock.MagicMock(), (), headers={})
573573

574574

@@ -783,3 +783,51 @@ def task(): ...
783783
assert span["origin"] == "auto.queue.celery"
784784

785785
monkeypatch.setattr(kombu.messaging.Producer, "_publish", old_publish)
786+
787+
788+
@pytest.mark.forked
789+
@mock.patch("celery.Celery.send_task")
790+
def test_send_task_wrapped(
791+
patched_send_task,
792+
sentry_init,
793+
capture_events,
794+
reset_integrations,
795+
):
796+
sentry_init(integrations=[CeleryIntegration()], enable_tracing=True)
797+
celery = Celery(__name__, broker="redis://example.com") # noqa: E231
798+
799+
events = capture_events()
800+
801+
with sentry_sdk.start_transaction(name="custom_transaction"):
802+
celery.send_task("very_creative_task_name", args=(1, 2), kwargs={"foo": "bar"})
803+
804+
(call,) = patched_send_task.call_args_list # We should have exactly one call
805+
(args, kwargs) = call
806+
807+
assert args == (celery, "very_creative_task_name")
808+
assert kwargs["args"] == (1, 2)
809+
assert kwargs["kwargs"] == {"foo": "bar"}
810+
assert set(kwargs["headers"].keys()) == {
811+
"sentry-task-enqueued-time",
812+
"sentry-trace",
813+
"baggage",
814+
"headers",
815+
}
816+
assert set(kwargs["headers"]["headers"].keys()) == {
817+
"sentry-trace",
818+
"baggage",
819+
"sentry-task-enqueued-time",
820+
}
821+
assert (
822+
kwargs["headers"]["sentry-trace"]
823+
== kwargs["headers"]["headers"]["sentry-trace"]
824+
)
825+
826+
(event,) = events # We should have exactly one event (the transaction)
827+
assert event["type"] == "transaction"
828+
assert event["transaction"] == "custom_transaction"
829+
830+
(span,) = event["spans"] # We should have exactly one span
831+
assert span["description"] == "very_creative_task_name"
832+
assert span["op"] == "queue.submit.celery"
833+
assert span["trace_id"] == kwargs["headers"]["sentry-trace"].split("-")[0]

Diff for: tox.ini

+2-1
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,9 @@ deps =
371371
celery-v5.4: Celery~=5.4.0
372372
celery-latest: Celery
373373

374-
{py3.7}-celery: importlib-metadata<5.0
375374
{py3.6,py3.7,py3.8,py3.9,py3.10,py3.11,py3.12}-celery: newrelic
375+
celery: pytest<7
376+
{py3.7}-celery: importlib-metadata<5.0
376377

377378
# Chalice
378379
chalice-v1.16: chalice~=1.16.0

0 commit comments

Comments
 (0)