diff --git a/readthedocs/core/utils/tasks/__init__.py b/readthedocs/core/utils/tasks/__init__.py index 3b6b13331b7..344215036f9 100644 --- a/readthedocs/core/utils/tasks/__init__.py +++ b/readthedocs/core/utils/tasks/__init__.py @@ -3,7 +3,6 @@ from .permission_checks import user_id_matches # noqa for unused import from .public import PublicTask # noqa from .public import TaskNoPermission # noqa -from .public import permission_check # noqa from .public import get_public_task_data # noqa from .retrieve import TaskNotFound # noqa from .retrieve import get_task_data # noqa diff --git a/readthedocs/core/utils/tasks/public.py b/readthedocs/core/utils/tasks/public.py index 5aeecfa7548..9fb2948ef71 100644 --- a/readthedocs/core/utils/tasks/public.py +++ b/readthedocs/core/utils/tasks/public.py @@ -1,16 +1,20 @@ """Celery tasks with publicly viewable status""" -from __future__ import absolute_import +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) + from celery import Task, states from django.conf import settings -from .retrieve import TaskNotFound -from .retrieve import get_task_data - +from .retrieve import TaskNotFound, get_task_data __all__ = ( - 'PublicTask', 'TaskNoPermission', 'permission_check', - 'get_public_task_data') + 'PublicTask', 'TaskNoPermission', 'get_public_task_data' +) STATUS_UPDATES_ENABLED = not getattr(settings, 'CELERY_ALWAYS_EAGER', False) @@ -19,22 +23,20 @@ class PublicTask(Task): """ - See oauth.tasks for usage example. + Encapsulates common behaviour to expose a task publicly. - Subclasses need to define a ``run_public`` method. - """ + Tasks should use this class as ``base``. And define a ``check_permission`` + property or use the ``permission_check`` decorator. - public_name = 'unknown' + The check_permission should be a function like: + function(request, state, context), and needs to return a boolean value. - @classmethod - def check_permission(cls, request, state, context): - """Override this method to define who can monitor this task.""" - # pylint: disable=unused-argument - return False + See oauth.tasks for usage example. + """ def get_task_data(self): """Return tuple with state to be set next and results task.""" - state = 'STARTED' + state = states.STARTED info = { 'task_name': self.name, 'context': self.request.get('permission_context', {}), @@ -66,12 +68,13 @@ def set_public_data(self, data): self.request.update(public_data=data) self.update_progress_data() - def run(self, *args, **kwargs): + def __call__(self, *args, **kwargs): + # We override __call__ to let tasks use the run method. error = False exception_raised = None self.set_permission_context(kwargs) try: - result = self.run_public(*args, **kwargs) + result = self.run(*args, **kwargs) except Exception as e: # With Celery 4 we lost the ability to keep our data dictionary into # ``AsyncResult.info`` when an exception was raised inside the @@ -90,22 +93,26 @@ def run(self, *args, **kwargs): return info + @staticmethod + def permission_check(check): + """ + Decorator for tasks that have PublicTask as base. -def permission_check(check): - """ - Class decorator for subclasses of PublicTask to sprinkle in re-usable + .. note:: + + The decorator should be on top of the task decorator. - permission checks:: + permission checks:: - @permission_check(user_id_matches) - class MyTask(PublicTask): - def run_public(self, user_id): + @PublicTask.permission_check(user_id_matches) + @celery.task(base=PublicTask) + def my_public_task(user_id): pass - """ - def decorator(cls): - cls.check_permission = staticmethod(check) - return cls - return decorator + """ + def decorator(func): + func.check_permission = check + return func + return decorator class TaskNoPermission(Exception): @@ -139,5 +146,9 @@ def get_public_task_data(request, task_id): context = info.get('context', {}) if not task.check_permission(request, state, context): raise TaskNoPermission(task_id) - public_name = task.public_name - return public_name, state, info.get('public_data', {}), info.get('error', None) + return ( + task.name, + state, + info.get('public_data', {}), + info.get('error', None), + ) diff --git a/readthedocs/core/utils/tasks/retrieve.py b/readthedocs/core/utils/tasks/retrieve.py index 9da3c581601..c96b7823706 100644 --- a/readthedocs/core/utils/tasks/retrieve.py +++ b/readthedocs/core/utils/tasks/retrieve.py @@ -1,9 +1,15 @@ """Utilities for retrieving task data.""" -from __future__ import absolute_import +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) + +from celery import states from celery.result import AsyncResult - __all__ = ('TaskNotFound', 'get_task_data') @@ -23,7 +29,7 @@ def get_task_data(task_id): result = AsyncResult(task_id) state, info = result.state, result.info - if state == 'PENDING': + if state == states.PENDING: raise TaskNotFound(task_id) if 'task_name' not in info: raise TaskNotFound(task_id) diff --git a/readthedocs/oauth/apps.py b/readthedocs/oauth/apps.py index 57de2c82098..b8998b8f458 100644 --- a/readthedocs/oauth/apps.py +++ b/readthedocs/oauth/apps.py @@ -5,8 +5,3 @@ class OAuthConfig(AppConfig): name = 'readthedocs.oauth' - - def ready(self): - from .tasks import SyncRemoteRepositories - from readthedocs.worker import app - app.tasks.register(SyncRemoteRepositories) diff --git a/readthedocs/oauth/tasks.py b/readthedocs/oauth/tasks.py index 2ad30f8e4a6..45b49ceac09 100644 --- a/readthedocs/oauth/tasks.py +++ b/readthedocs/oauth/tasks.py @@ -2,17 +2,22 @@ """Tasks for OAuth services.""" from __future__ import ( - absolute_import, division, print_function, unicode_literals) + absolute_import, + division, + print_function, + unicode_literals, +) import logging from allauth.socialaccount.providers import registry as allauth_registry from django.contrib.auth.models import User -from readthedocs.core.utils.tasks import ( - PublicTask, permission_check, user_id_matches) +from readthedocs.core.utils.tasks import PublicTask, user_id_matches from readthedocs.oauth.notifications import ( - AttachWebhookNotification, InvalidProjectWebhookNotification) + AttachWebhookNotification, + InvalidProjectWebhookNotification, +) from readthedocs.projects.models import Project from readthedocs.worker import app @@ -21,21 +26,13 @@ log = logging.getLogger(__name__) -@permission_check(user_id_matches) -class SyncRemoteRepositories(PublicTask): - - name = __name__ + '.sync_remote_repositories' - public_name = 'sync_remote_repositories' - queue = 'web' - - def run_public(self, user_id): - user = User.objects.get(pk=user_id) - for service_cls in registry: - for service in service_cls.for_user(user): - service.sync() - - -sync_remote_repositories = SyncRemoteRepositories() +@PublicTask.permission_check(user_id_matches) +@app.task(queue='web', base=PublicTask) +def sync_remote_repositories(user_id): + user = User.objects.get(pk=user_id) + for service_cls in registry: + for service in service_cls.for_user(user): + service.sync() @app.task(queue='web') diff --git a/readthedocs/restapi/views/task_views.py b/readthedocs/restapi/views/task_views.py index f110cb2bb02..694fd7787a8 100644 --- a/readthedocs/restapi/views/task_views.py +++ b/readthedocs/restapi/views/task_views.py @@ -1,19 +1,23 @@ """Endpoints relating to task/job status, etc.""" -from __future__ import absolute_import +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) + import logging from django.core.urlresolvers import reverse +from redis import ConnectionError from rest_framework import decorators, permissions from rest_framework.renderers import JSONRenderer from rest_framework.response import Response -from redis import ConnectionError -from readthedocs.core.utils.tasks import TaskNoPermission -from readthedocs.core.utils.tasks import get_public_task_data +from readthedocs.core.utils.tasks import TaskNoPermission, get_public_task_data from readthedocs.oauth import tasks - log = logging.getLogger(__name__) @@ -43,20 +47,25 @@ def get_status_data(task_name, state, data, error=None): @decorators.renderer_classes((JSONRenderer,)) def job_status(request, task_id): try: - task_name, state, public_data, error = get_public_task_data(request, task_id) + task_name, state, public_data, error = get_public_task_data( + request, task_id + ) except (TaskNoPermission, ConnectionError): return Response( - get_status_data('unknown', 'PENDING', {})) + get_status_data('unknown', 'PENDING', {}) + ) return Response( - get_status_data(task_name, state, public_data, error)) + get_status_data(task_name, state, public_data, error) + ) @decorators.api_view(['POST']) @decorators.permission_classes((permissions.IsAuthenticated,)) @decorators.renderer_classes((JSONRenderer,)) def sync_remote_repositories(request): - result = tasks.SyncRemoteRepositories().delay( - user_id=request.user.id) + result = tasks.sync_remote_repositories.delay( + user_id=request.user.id + ) task_id = result.task_id return Response({ 'task_id': task_id, diff --git a/readthedocs/rtd_tests/tests/test_celery.py b/readthedocs/rtd_tests/tests/test_celery.py index df57d6df4f2..f806646dd7f 100644 --- a/readthedocs/rtd_tests/tests/test_celery.py +++ b/readthedocs/rtd_tests/tests/test_celery.py @@ -191,15 +191,11 @@ def test_public_task_exception(self): from readthedocs.core.utils.tasks import PublicTask from readthedocs.worker import app - class PublicTaskException(PublicTask): - name = 'public_task_exception' + @app.task(name='public_task_exception', base=PublicTask) + def public_task_exception(): + raise Exception('Something bad happened') - def run_public(self): - raise Exception('Something bad happened') - - app.tasks.register(PublicTaskException) - exception_task = PublicTaskException() - result = exception_task.delay() + result = public_task_exception.delay() # although the task risen an exception, it's success since we add the # exception into the ``info`` attributes