Skip to content

Commit a8bd00a

Browse files
stsewdagjohnson
authored andcommitted
Refactor PublicTask into a decorator task (#4656)
* Refactor PublicTask into a decorator task * Refactor and docs
1 parent c669284 commit a8bd00a

File tree

7 files changed

+91
-78
lines changed

7 files changed

+91
-78
lines changed

readthedocs/core/utils/tasks/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from .permission_checks import user_id_matches # noqa for unused import
44
from .public import PublicTask # noqa
55
from .public import TaskNoPermission # noqa
6-
from .public import permission_check # noqa
76
from .public import get_public_task_data # noqa
87
from .retrieve import TaskNotFound # noqa
98
from .retrieve import get_task_data # noqa

readthedocs/core/utils/tasks/public.py

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
"""Celery tasks with publicly viewable status"""
22

3-
from __future__ import absolute_import
3+
from __future__ import (
4+
absolute_import,
5+
division,
6+
print_function,
7+
unicode_literals,
8+
)
9+
410
from celery import Task, states
511
from django.conf import settings
612

7-
from .retrieve import TaskNotFound
8-
from .retrieve import get_task_data
9-
13+
from .retrieve import TaskNotFound, get_task_data
1014

1115
__all__ = (
12-
'PublicTask', 'TaskNoPermission', 'permission_check',
13-
'get_public_task_data')
16+
'PublicTask', 'TaskNoPermission', 'get_public_task_data'
17+
)
1418

1519

1620
STATUS_UPDATES_ENABLED = not getattr(settings, 'CELERY_ALWAYS_EAGER', False)
@@ -19,22 +23,20 @@
1923
class PublicTask(Task):
2024

2125
"""
22-
See oauth.tasks for usage example.
26+
Encapsulates common behaviour to expose a task publicly.
2327
24-
Subclasses need to define a ``run_public`` method.
25-
"""
28+
Tasks should use this class as ``base``. And define a ``check_permission``
29+
property or use the ``permission_check`` decorator.
2630
27-
public_name = 'unknown'
31+
The check_permission should be a function like:
32+
function(request, state, context), and needs to return a boolean value.
2833
29-
@classmethod
30-
def check_permission(cls, request, state, context):
31-
"""Override this method to define who can monitor this task."""
32-
# pylint: disable=unused-argument
33-
return False
34+
See oauth.tasks for usage example.
35+
"""
3436

3537
def get_task_data(self):
3638
"""Return tuple with state to be set next and results task."""
37-
state = 'STARTED'
39+
state = states.STARTED
3840
info = {
3941
'task_name': self.name,
4042
'context': self.request.get('permission_context', {}),
@@ -66,12 +68,13 @@ def set_public_data(self, data):
6668
self.request.update(public_data=data)
6769
self.update_progress_data()
6870

69-
def run(self, *args, **kwargs):
71+
def __call__(self, *args, **kwargs):
72+
# We override __call__ to let tasks use the run method.
7073
error = False
7174
exception_raised = None
7275
self.set_permission_context(kwargs)
7376
try:
74-
result = self.run_public(*args, **kwargs)
77+
result = self.run(*args, **kwargs)
7578
except Exception as e:
7679
# With Celery 4 we lost the ability to keep our data dictionary into
7780
# ``AsyncResult.info`` when an exception was raised inside the
@@ -90,22 +93,26 @@ def run(self, *args, **kwargs):
9093

9194
return info
9295

96+
@staticmethod
97+
def permission_check(check):
98+
"""
99+
Decorator for tasks that have PublicTask as base.
93100
94-
def permission_check(check):
95-
"""
96-
Class decorator for subclasses of PublicTask to sprinkle in re-usable
101+
.. note::
102+
103+
The decorator should be on top of the task decorator.
97104
98-
permission checks::
105+
permission checks::
99106
100-
@permission_check(user_id_matches)
101-
class MyTask(PublicTask):
102-
def run_public(self, user_id):
107+
@PublicTask.permission_check(user_id_matches)
108+
@celery.task(base=PublicTask)
109+
def my_public_task(user_id):
103110
pass
104-
"""
105-
def decorator(cls):
106-
cls.check_permission = staticmethod(check)
107-
return cls
108-
return decorator
111+
"""
112+
def decorator(func):
113+
func.check_permission = check
114+
return func
115+
return decorator
109116

110117

111118
class TaskNoPermission(Exception):
@@ -139,5 +146,9 @@ def get_public_task_data(request, task_id):
139146
context = info.get('context', {})
140147
if not task.check_permission(request, state, context):
141148
raise TaskNoPermission(task_id)
142-
public_name = task.public_name
143-
return public_name, state, info.get('public_data', {}), info.get('error', None)
149+
return (
150+
task.name,
151+
state,
152+
info.get('public_data', {}),
153+
info.get('error', None),
154+
)

readthedocs/core/utils/tasks/retrieve.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
"""Utilities for retrieving task data."""
22

3-
from __future__ import absolute_import
3+
from __future__ import (
4+
absolute_import,
5+
division,
6+
print_function,
7+
unicode_literals,
8+
)
9+
10+
from celery import states
411
from celery.result import AsyncResult
512

6-
713
__all__ = ('TaskNotFound', 'get_task_data')
814

915

@@ -23,7 +29,7 @@ def get_task_data(task_id):
2329

2430
result = AsyncResult(task_id)
2531
state, info = result.state, result.info
26-
if state == 'PENDING':
32+
if state == states.PENDING:
2733
raise TaskNotFound(task_id)
2834
if 'task_name' not in info:
2935
raise TaskNotFound(task_id)

readthedocs/oauth/apps.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,3 @@
55

66
class OAuthConfig(AppConfig):
77
name = 'readthedocs.oauth'
8-
9-
def ready(self):
10-
from .tasks import SyncRemoteRepositories
11-
from readthedocs.worker import app
12-
app.tasks.register(SyncRemoteRepositories)

readthedocs/oauth/tasks.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,22 @@
22
"""Tasks for OAuth services."""
33

44
from __future__ import (
5-
absolute_import, division, print_function, unicode_literals)
5+
absolute_import,
6+
division,
7+
print_function,
8+
unicode_literals,
9+
)
610

711
import logging
812

913
from allauth.socialaccount.providers import registry as allauth_registry
1014
from django.contrib.auth.models import User
1115

12-
from readthedocs.core.utils.tasks import (
13-
PublicTask, permission_check, user_id_matches)
16+
from readthedocs.core.utils.tasks import PublicTask, user_id_matches
1417
from readthedocs.oauth.notifications import (
15-
AttachWebhookNotification, InvalidProjectWebhookNotification)
18+
AttachWebhookNotification,
19+
InvalidProjectWebhookNotification,
20+
)
1621
from readthedocs.projects.models import Project
1722
from readthedocs.worker import app
1823

@@ -21,21 +26,13 @@
2126
log = logging.getLogger(__name__)
2227

2328

24-
@permission_check(user_id_matches)
25-
class SyncRemoteRepositories(PublicTask):
26-
27-
name = __name__ + '.sync_remote_repositories'
28-
public_name = 'sync_remote_repositories'
29-
queue = 'web'
30-
31-
def run_public(self, user_id):
32-
user = User.objects.get(pk=user_id)
33-
for service_cls in registry:
34-
for service in service_cls.for_user(user):
35-
service.sync()
36-
37-
38-
sync_remote_repositories = SyncRemoteRepositories()
29+
@PublicTask.permission_check(user_id_matches)
30+
@app.task(queue='web', base=PublicTask)
31+
def sync_remote_repositories(user_id):
32+
user = User.objects.get(pk=user_id)
33+
for service_cls in registry:
34+
for service in service_cls.for_user(user):
35+
service.sync()
3936

4037

4138
@app.task(queue='web')

readthedocs/restapi/views/task_views.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
"""Endpoints relating to task/job status, etc."""
22

3-
from __future__ import absolute_import
3+
from __future__ import (
4+
absolute_import,
5+
division,
6+
print_function,
7+
unicode_literals,
8+
)
9+
410
import logging
511

612
from django.core.urlresolvers import reverse
13+
from redis import ConnectionError
714
from rest_framework import decorators, permissions
815
from rest_framework.renderers import JSONRenderer
916
from rest_framework.response import Response
10-
from redis import ConnectionError
1117

12-
from readthedocs.core.utils.tasks import TaskNoPermission
13-
from readthedocs.core.utils.tasks import get_public_task_data
18+
from readthedocs.core.utils.tasks import TaskNoPermission, get_public_task_data
1419
from readthedocs.oauth import tasks
1520

16-
1721
log = logging.getLogger(__name__)
1822

1923

@@ -43,20 +47,25 @@ def get_status_data(task_name, state, data, error=None):
4347
@decorators.renderer_classes((JSONRenderer,))
4448
def job_status(request, task_id):
4549
try:
46-
task_name, state, public_data, error = get_public_task_data(request, task_id)
50+
task_name, state, public_data, error = get_public_task_data(
51+
request, task_id
52+
)
4753
except (TaskNoPermission, ConnectionError):
4854
return Response(
49-
get_status_data('unknown', 'PENDING', {}))
55+
get_status_data('unknown', 'PENDING', {})
56+
)
5057
return Response(
51-
get_status_data(task_name, state, public_data, error))
58+
get_status_data(task_name, state, public_data, error)
59+
)
5260

5361

5462
@decorators.api_view(['POST'])
5563
@decorators.permission_classes((permissions.IsAuthenticated,))
5664
@decorators.renderer_classes((JSONRenderer,))
5765
def sync_remote_repositories(request):
58-
result = tasks.SyncRemoteRepositories().delay(
59-
user_id=request.user.id)
66+
result = tasks.sync_remote_repositories.delay(
67+
user_id=request.user.id
68+
)
6069
task_id = result.task_id
6170
return Response({
6271
'task_id': task_id,

readthedocs/rtd_tests/tests/test_celery.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,11 @@ def test_public_task_exception(self):
191191
from readthedocs.core.utils.tasks import PublicTask
192192
from readthedocs.worker import app
193193

194-
class PublicTaskException(PublicTask):
195-
name = 'public_task_exception'
194+
@app.task(name='public_task_exception', base=PublicTask)
195+
def public_task_exception():
196+
raise Exception('Something bad happened')
196197

197-
def run_public(self):
198-
raise Exception('Something bad happened')
199-
200-
app.tasks.register(PublicTaskException)
201-
exception_task = PublicTaskException()
202-
result = exception_task.delay()
198+
result = public_task_exception.delay()
203199

204200
# although the task risen an exception, it's success since we add the
205201
# exception into the ``info`` attributes

0 commit comments

Comments
 (0)