From 7b19ebdb891bcb8d51efde94235e66e8d4be8813 Mon Sep 17 00:00:00 2001 From: Santos Gallegos Date: Fri, 21 Sep 2018 00:55:52 -0500 Subject: [PATCH 1/2] Refactor PublicTask into a decorator task --- readthedocs/core/utils/tasks/public.py | 40 +++++++++++----------- readthedocs/core/utils/tasks/retrieve.py | 12 +++++-- readthedocs/oauth/apps.py | 5 --- readthedocs/oauth/tasks.py | 35 ++++++++++--------- readthedocs/restapi/views/task_views.py | 29 ++++++++++------ readthedocs/rtd_tests/tests/test_celery.py | 12 +++---- 6 files changed, 70 insertions(+), 63 deletions(-) diff --git a/readthedocs/core/utils/tasks/public.py b/readthedocs/core/utils/tasks/public.py index 5aeecfa7548..4160a748590 100644 --- a/readthedocs/core/utils/tasks/public.py +++ b/readthedocs/core/utils/tasks/public.py @@ -1,12 +1,16 @@ """Celery tasks with publicly viewable status""" -from __future__ import absolute_import +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) + from celery import Task, states from django.conf import settings -from .retrieve import TaskNotFound -from .retrieve import get_task_data - +from .retrieve import TaskNotFound, get_task_data __all__ = ( 'PublicTask', 'TaskNoPermission', 'permission_check', @@ -24,17 +28,9 @@ class PublicTask(Task): Subclasses need to define a ``run_public`` method. """ - public_name = 'unknown' - - @classmethod - def check_permission(cls, request, state, context): - """Override this method to define who can monitor this task.""" - # pylint: disable=unused-argument - return False - def get_task_data(self): """Return tuple with state to be set next and results task.""" - state = 'STARTED' + state = states.STARTED info = { 'task_name': self.name, 'context': self.request.get('permission_context', {}), @@ -66,12 +62,12 @@ def set_public_data(self, data): self.request.update(public_data=data) self.update_progress_data() - def run(self, *args, **kwargs): + def __call__(self, *args, **kwargs): error = False exception_raised = None self.set_permission_context(kwargs) try: - result = self.run_public(*args, **kwargs) + result = self.run(*args, **kwargs) except Exception as e: # With Celery 4 we lost the ability to keep our data dictionary into # ``AsyncResult.info`` when an exception was raised inside the @@ -102,9 +98,9 @@ class MyTask(PublicTask): def run_public(self, user_id): pass """ - def decorator(cls): - cls.check_permission = staticmethod(check) - return cls + def decorator(func): + func.check_permission = check + return func return decorator @@ -139,5 +135,9 @@ def get_public_task_data(request, task_id): context = info.get('context', {}) if not task.check_permission(request, state, context): raise TaskNoPermission(task_id) - public_name = task.public_name - return public_name, state, info.get('public_data', {}), info.get('error', None) + return ( + task.name, + state, + info.get('public_data', {}), + info.get('error', None), + ) diff --git a/readthedocs/core/utils/tasks/retrieve.py b/readthedocs/core/utils/tasks/retrieve.py index 9da3c581601..c96b7823706 100644 --- a/readthedocs/core/utils/tasks/retrieve.py +++ b/readthedocs/core/utils/tasks/retrieve.py @@ -1,9 +1,15 @@ """Utilities for retrieving task data.""" -from __future__ import absolute_import +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) + +from celery import states from celery.result import AsyncResult - __all__ = ('TaskNotFound', 'get_task_data') @@ -23,7 +29,7 @@ def get_task_data(task_id): result = AsyncResult(task_id) state, info = result.state, result.info - if state == 'PENDING': + if state == states.PENDING: raise TaskNotFound(task_id) if 'task_name' not in info: raise TaskNotFound(task_id) diff --git a/readthedocs/oauth/apps.py b/readthedocs/oauth/apps.py index 57de2c82098..b8998b8f458 100644 --- a/readthedocs/oauth/apps.py +++ b/readthedocs/oauth/apps.py @@ -5,8 +5,3 @@ class OAuthConfig(AppConfig): name = 'readthedocs.oauth' - - def ready(self): - from .tasks import SyncRemoteRepositories - from readthedocs.worker import app - app.tasks.register(SyncRemoteRepositories) diff --git a/readthedocs/oauth/tasks.py b/readthedocs/oauth/tasks.py index 2ad30f8e4a6..1347ac43b5a 100644 --- a/readthedocs/oauth/tasks.py +++ b/readthedocs/oauth/tasks.py @@ -2,7 +2,11 @@ """Tasks for OAuth services.""" from __future__ import ( - absolute_import, division, print_function, unicode_literals) + absolute_import, + division, + print_function, + unicode_literals, +) import logging @@ -10,9 +14,14 @@ from django.contrib.auth.models import User from readthedocs.core.utils.tasks import ( - PublicTask, permission_check, user_id_matches) + PublicTask, + permission_check, + user_id_matches, +) from readthedocs.oauth.notifications import ( - AttachWebhookNotification, InvalidProjectWebhookNotification) + AttachWebhookNotification, + InvalidProjectWebhookNotification, +) from readthedocs.projects.models import Project from readthedocs.worker import app @@ -22,20 +31,12 @@ @permission_check(user_id_matches) -class SyncRemoteRepositories(PublicTask): - - name = __name__ + '.sync_remote_repositories' - public_name = 'sync_remote_repositories' - queue = 'web' - - def run_public(self, user_id): - user = User.objects.get(pk=user_id) - for service_cls in registry: - for service in service_cls.for_user(user): - service.sync() - - -sync_remote_repositories = SyncRemoteRepositories() +@app.task(queue='web', base=PublicTask) +def sync_remote_repositories(user_id): + user = User.objects.get(pk=user_id) + for service_cls in registry: + for service in service_cls.for_user(user): + service.sync() @app.task(queue='web') diff --git a/readthedocs/restapi/views/task_views.py b/readthedocs/restapi/views/task_views.py index f110cb2bb02..694fd7787a8 100644 --- a/readthedocs/restapi/views/task_views.py +++ b/readthedocs/restapi/views/task_views.py @@ -1,19 +1,23 @@ """Endpoints relating to task/job status, etc.""" -from __future__ import absolute_import +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) + import logging from django.core.urlresolvers import reverse +from redis import ConnectionError from rest_framework import decorators, permissions from rest_framework.renderers import JSONRenderer from rest_framework.response import Response -from redis import ConnectionError -from readthedocs.core.utils.tasks import TaskNoPermission -from readthedocs.core.utils.tasks import get_public_task_data +from readthedocs.core.utils.tasks import TaskNoPermission, get_public_task_data from readthedocs.oauth import tasks - log = logging.getLogger(__name__) @@ -43,20 +47,25 @@ def get_status_data(task_name, state, data, error=None): @decorators.renderer_classes((JSONRenderer,)) def job_status(request, task_id): try: - task_name, state, public_data, error = get_public_task_data(request, task_id) + task_name, state, public_data, error = get_public_task_data( + request, task_id + ) except (TaskNoPermission, ConnectionError): return Response( - get_status_data('unknown', 'PENDING', {})) + get_status_data('unknown', 'PENDING', {}) + ) return Response( - get_status_data(task_name, state, public_data, error)) + get_status_data(task_name, state, public_data, error) + ) @decorators.api_view(['POST']) @decorators.permission_classes((permissions.IsAuthenticated,)) @decorators.renderer_classes((JSONRenderer,)) def sync_remote_repositories(request): - result = tasks.SyncRemoteRepositories().delay( - user_id=request.user.id) + result = tasks.sync_remote_repositories.delay( + user_id=request.user.id + ) task_id = result.task_id return Response({ 'task_id': task_id, diff --git a/readthedocs/rtd_tests/tests/test_celery.py b/readthedocs/rtd_tests/tests/test_celery.py index df57d6df4f2..f806646dd7f 100644 --- a/readthedocs/rtd_tests/tests/test_celery.py +++ b/readthedocs/rtd_tests/tests/test_celery.py @@ -191,15 +191,11 @@ def test_public_task_exception(self): from readthedocs.core.utils.tasks import PublicTask from readthedocs.worker import app - class PublicTaskException(PublicTask): - name = 'public_task_exception' + @app.task(name='public_task_exception', base=PublicTask) + def public_task_exception(): + raise Exception('Something bad happened') - def run_public(self): - raise Exception('Something bad happened') - - app.tasks.register(PublicTaskException) - exception_task = PublicTaskException() - result = exception_task.delay() + result = public_task_exception.delay() # although the task risen an exception, it's success since we add the # exception into the ``info`` attributes From ebbeaa5b2c28775bb368c31ab796315729f2291c Mon Sep 17 00:00:00 2001 From: Santos Gallegos Date: Fri, 21 Sep 2018 12:28:44 -0500 Subject: [PATCH 2/2] Refactor and docs --- readthedocs/core/utils/tasks/__init__.py | 1 - readthedocs/core/utils/tasks/public.py | 43 +++++++++++++++--------- readthedocs/oauth/tasks.py | 8 ++--- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/readthedocs/core/utils/tasks/__init__.py b/readthedocs/core/utils/tasks/__init__.py index 3b6b13331b7..344215036f9 100644 --- a/readthedocs/core/utils/tasks/__init__.py +++ b/readthedocs/core/utils/tasks/__init__.py @@ -3,7 +3,6 @@ from .permission_checks import user_id_matches # noqa for unused import from .public import PublicTask # noqa from .public import TaskNoPermission # noqa -from .public import permission_check # noqa from .public import get_public_task_data # noqa from .retrieve import TaskNotFound # noqa from .retrieve import get_task_data # noqa diff --git a/readthedocs/core/utils/tasks/public.py b/readthedocs/core/utils/tasks/public.py index 4160a748590..9fb2948ef71 100644 --- a/readthedocs/core/utils/tasks/public.py +++ b/readthedocs/core/utils/tasks/public.py @@ -13,8 +13,8 @@ from .retrieve import TaskNotFound, get_task_data __all__ = ( - 'PublicTask', 'TaskNoPermission', 'permission_check', - 'get_public_task_data') + 'PublicTask', 'TaskNoPermission', 'get_public_task_data' +) STATUS_UPDATES_ENABLED = not getattr(settings, 'CELERY_ALWAYS_EAGER', False) @@ -23,9 +23,15 @@ class PublicTask(Task): """ - See oauth.tasks for usage example. + Encapsulates common behaviour to expose a task publicly. + + Tasks should use this class as ``base``. And define a ``check_permission`` + property or use the ``permission_check`` decorator. - Subclasses need to define a ``run_public`` method. + The check_permission should be a function like: + function(request, state, context), and needs to return a boolean value. + + See oauth.tasks for usage example. """ def get_task_data(self): @@ -63,6 +69,7 @@ def set_public_data(self, data): self.update_progress_data() def __call__(self, *args, **kwargs): + # We override __call__ to let tasks use the run method. error = False exception_raised = None self.set_permission_context(kwargs) @@ -86,22 +93,26 @@ def __call__(self, *args, **kwargs): return info + @staticmethod + def permission_check(check): + """ + Decorator for tasks that have PublicTask as base. -def permission_check(check): - """ - Class decorator for subclasses of PublicTask to sprinkle in re-usable + .. note:: + + The decorator should be on top of the task decorator. - permission checks:: + permission checks:: - @permission_check(user_id_matches) - class MyTask(PublicTask): - def run_public(self, user_id): + @PublicTask.permission_check(user_id_matches) + @celery.task(base=PublicTask) + def my_public_task(user_id): pass - """ - def decorator(func): - func.check_permission = check - return func - return decorator + """ + def decorator(func): + func.check_permission = check + return func + return decorator class TaskNoPermission(Exception): diff --git a/readthedocs/oauth/tasks.py b/readthedocs/oauth/tasks.py index 1347ac43b5a..45b49ceac09 100644 --- a/readthedocs/oauth/tasks.py +++ b/readthedocs/oauth/tasks.py @@ -13,11 +13,7 @@ from allauth.socialaccount.providers import registry as allauth_registry from django.contrib.auth.models import User -from readthedocs.core.utils.tasks import ( - PublicTask, - permission_check, - user_id_matches, -) +from readthedocs.core.utils.tasks import PublicTask, user_id_matches from readthedocs.oauth.notifications import ( AttachWebhookNotification, InvalidProjectWebhookNotification, @@ -30,7 +26,7 @@ log = logging.getLogger(__name__) -@permission_check(user_id_matches) +@PublicTask.permission_check(user_id_matches) @app.task(queue='web', base=PublicTask) def sync_remote_repositories(user_id): user = User.objects.get(pk=user_id)