From edcaceb4a682fde5e733cffc64b214e76fed54a6 Mon Sep 17 00:00:00 2001 From: Santos Gallegos Date: Thu, 21 May 2020 15:48:04 -0500 Subject: [PATCH 1/3] Search: refactor API to not emulate a Django queryset The recommended way of using pagination over a custom object is to manage the class ourselves. I tried rely on most of the defaults of the DRF's pagination class. Closes https://github.com/readthedocs/readthedocs.org/issues/5235 --- readthedocs/search/api.py | 249 ++++++++++++++++++--------- readthedocs/search/faceted_search.py | 4 +- readthedocs/search/tests/test_api.py | 4 +- 3 files changed, 171 insertions(+), 86 deletions(-) diff --git a/readthedocs/search/api.py b/readthedocs/search/api.py index d36a1ab8e5d..8cc4e03d3a8 100644 --- a/readthedocs/search/api.py +++ b/readthedocs/search/api.py @@ -1,28 +1,104 @@ import itertools import logging import re +from functools import namedtuple +from math import ceil from django.shortcuts import get_object_or_404 from django.utils import timezone -from rest_framework import generics, serializers -from rest_framework.exceptions import ValidationError +from django.utils.translation import ugettext as _ +from rest_framework import serializers +from rest_framework.exceptions import NotFound, ValidationError +from rest_framework.generics import GenericAPIView from rest_framework.pagination import PageNumberPagination +from rest_framework.response import Response +from rest_framework.utils.urls import remove_query_param, replace_query_param from readthedocs.api.v2.permissions import IsAuthorizedToViewVersion from readthedocs.builds.models import Version from readthedocs.projects.constants import MKDOCS, SPHINX_HTMLDIR -from readthedocs.projects.models import HTMLFile, Project +from readthedocs.projects.models import Project from readthedocs.search import tasks, utils from readthedocs.search.faceted_search import PageSearch log = logging.getLogger(__name__) +class PaginatorPage: + """ + Mimics the result from a paginator. + + By using this class, we avoid having to override a lot of methods + of `PageNumberPagination` to make it work with the ES DSL object. + """ + + def __init__(self, page_number, total_pages, count): + self.number = page_number + Paginator = namedtuple('Paginator', ['num_pages', 'count']) + self.paginator = Paginator(total_pages, count) + + def has_next(self): + return self.number < self.paginator.num_pages + + def has_previous(self): + return self.number > 0 + + def next_page_number(self): + return self.number + 1 + + def previous_page_number(self): + return self.number - 1 + + class SearchPagination(PageNumberPagination): + """Paginator for the results of PageSearch.""" + page_size = 50 page_size_query_param = 'page_size' max_page_size = 100 + def paginate_queryset(self, queryset, request, view=None): + """Override to get the paginated result from the ES queryset.""" + # Needed for other methods of this class. + self.request = request + + page_size = self.get_page_size(request) + + total_count = 0 + total_pages = 1 + if queryset: + total_count = queryset.total_count() + hits = max(1, total_count) + total_pages = ceil(hits / page_size) + + page_number = request.query_params.get(self.page_query_param, 1) + if page_number in self.last_page_strings: + page_number = total_pages + + if page_number <= 0: + msg = self.invalid_page_message.format( + page_number=page_number, + message=_("Invalid page"), + ) + raise NotFound(msg) + + if total_pages > 1 and self.template is not None: + # The browsable API should display pagination controls. + self.display_page_controls = True + + start = (page_number - 1) * page_size + end = page_number * page_size + result = list(queryset[start:end]) + + # Needed for other methods of this class. + self.page = PaginatorPage( + page_number=page_number, + total_pages=total_pages, + count=total_count, + ) + + return result + class PageSearchSerializer(serializers.Serializer): project = serializers.CharField() @@ -75,12 +151,12 @@ def get_inner_hits(self, obj): return sorted_results -class PageSearchAPIView(generics.ListAPIView): - +class PageSearchAPIView(GenericAPIView): """ Main entry point to perform a search using Elasticsearch. Required query params: + - q (search term) - project - version @@ -91,6 +167,7 @@ class PageSearchAPIView(generics.ListAPIView): are called many times, so a basic cache is implemented. """ + http_method_names = ['get'] permission_classes = [IsAuthorizedToViewVersion] pagination_class = SearchPagination serializer_class = PageSearchSerializer @@ -121,39 +198,7 @@ def _get_version(self): return version - def get_queryset(self): - """ - Return Elasticsearch DSL Search object instead of Django Queryset. - - Django Queryset and elasticsearch-dsl ``Search`` object is similar pattern. - So for searching, its possible to return ``Search`` object instead of queryset. - The ``filter_backends`` and ``pagination_class`` is compatible with ``Search`` - """ - # Validate all the required params are there - self.validate_query_params() - query = self.request.query_params.get('q', '') - filters = {} - filters['project'] = [p.slug for p in self.get_all_projects()] - filters['version'] = self._get_version().slug - - # Check to avoid searching all projects in case these filters are empty. - if not filters['project']: - log.info("Unable to find a project to search") - return HTMLFile.objects.none() - if not filters['version']: - log.info("Unable to find a version to search") - return HTMLFile.objects.none() - - queryset = PageSearch( - query=query, - filters=filters, - user=self.request.user, - # We use a permission class to control authorization - filter_by_user=False, - ) - return queryset - - def validate_query_params(self): + def _validate_query_params(self): """ Validate all required query params are passed on the request. @@ -163,47 +208,16 @@ def validate_query_params(self): :raises: ValidationError if one of them is missing. """ - required_query_params = {'q', 'project', 'version'} # python `set` literal is `{}` + errors = {} + required_query_params = {'q', 'project', 'version'} request_params = set(self.request.query_params.keys()) missing_params = required_query_params - request_params - if missing_params: - errors = {} - for param in missing_params: - errors[param] = ["This query param is required"] - + for param in missing_params: + errors[param] = [_("This query param is required")] + if errors: raise ValidationError(errors) - def get_serializer_context(self): - context = super().get_serializer_context() - context['projects_data'] = self.get_all_projects_data() - return context - - def get_all_projects(self): - """ - Return a list of the project itself and all its subprojects the user has permissions over. - - :rtype: list - """ - main_version = self._get_version() - main_project = self._get_project() - - all_projects = [main_project] - - subprojects = Project.objects.filter( - superprojects__parent_id=main_project.id, - ) - for project in subprojects: - version = ( - Version.internal - .public(user=self.request.user, project=project, include_hidden=False) - .filter(slug=main_version.slug) - .first() - ) - if version: - all_projects.append(version.project) - return all_projects - - def get_all_projects_data(self): + def _get_all_projects_data(self): """ Return a dict containing the project slug and its version URL and version's doctype. @@ -224,7 +238,7 @@ def get_all_projects_data(self): :rtype: dict """ - all_projects = self.get_all_projects() + all_projects = self._get_all_projects() version_slug = self._get_version().slug project_urls = {} for project in all_projects: @@ -242,20 +256,41 @@ def get_all_projects_data(self): } return projects_data - def list(self, request, *args, **kwargs): - """Overriding ``list`` method to record query in database.""" + def _get_all_projects(self): + """ + Returns a list of the project itself and all its subprojects the user has permissions over. + + :rtype: list + """ + main_version = self._get_version() + main_project = self._get_project() - response = super().list(request, *args, **kwargs) + all_projects = [main_project] + subprojects = Project.objects.filter( + superprojects__parent_id=main_project.id, + ) + for project in subprojects: + version = ( + Version.internal + .public(user=self.request.user, project=project, include_hidden=False) + .filter(slug=main_version.slug) + .first() + ) + if version: + all_projects.append(version.project) + return all_projects + + def _record_query(self, response): project_slug = self._get_project().slug version_slug = self._get_version().slug total_results = response.data.get('count', 0) time = timezone.now() - query = self.request.query_params.get('q', '') + query = self.request.query_params['q'] query = query.lower().strip() - # record the search query with a celery task + # Record the query with a celery task tasks.record_search_query.delay( project_slug, version_slug, @@ -264,4 +299,54 @@ def list(self, request, *args, **kwargs): time.isoformat(), ) - return response + def get_queryset(self): + """ + Returns an Elasticsearch DSL search object or an iterator. + + .. note:: + + Calling ``list(search)`` over an DSL search object is the same as + calling ``search.execute().hits``. This is why an DSL search object + is compatible with DRF's paginator. + """ + filters = {} + filters['project'] = [p.slug for p in self._get_all_projects()] + filters['version'] = self._get_version().slug + + # Check to avoid searching all projects in case these filters are empty. + if not filters['project']: + log.info('Unable to find a project to search') + return [] + if not filters['version']: + log.info('Unable to find a version to search') + return [] + + query = self.request.query_params['q'] + queryset = PageSearch( + query=query, + filters=filters, + user=self.request.user, + # We use a permission class to control authorization + filter_by_user=False, + ) + return queryset + + def get_serializer_context(self): + context = super().get_serializer_context() + context['projects_data'] = self._get_all_projects_data() + return context + + def get(self, request, *args, **kwargs): + self._validate_query_params() + result = self.list() + self._record_query(result) + return result + + def list(self): + """List the results using pagination.""" + queryset = self.get_queryset() + page = self.paginator.paginate_queryset( + queryset, self.request, view=self, + ) + serializer = self.get_serializer(page, many=True) + return self.paginator.get_paginated_response(serializer.data) diff --git a/readthedocs/search/faceted_search.py b/readthedocs/search/faceted_search.py index e0aea1b6b68..89e0b16f882 100644 --- a/readthedocs/search/faceted_search.py +++ b/readthedocs/search/faceted_search.py @@ -116,8 +116,8 @@ class PageSearchBase(RTDFacetedSearch): # the score of and should be higher as it satisfies both or and and operators = ['and', 'or'] - def count(self): - """Overriding ``count`` method to return the count of the results after post_filter.""" + def total_count(self): + """Returns the total count of results of the current query.""" s = self.build_search() # setting size=0 so that no results are returned, diff --git a/readthedocs/search/tests/test_api.py b/readthedocs/search/tests/test_api.py index 73025856f37..9517dc50882 100644 --- a/readthedocs/search/tests/test_api.py +++ b/readthedocs/search/tests/test_api.py @@ -273,9 +273,9 @@ def test_doc_search_unexisting_version(self, api_client, project): resp = self.get_search(api_client, search_params) assert resp.status_code == 404 - @mock.patch.object(PageSearchAPIView, 'get_all_projects', list) + @mock.patch.object(PageSearchAPIView, '_get_all_projects', list) def test_get_all_projects_returns_empty_results(self, api_client, project): - """If there is a case where `get_all_projects` returns empty, we could be querying all projects.""" + """If there is a case where `_get_all_projects` returns empty, we could be querying all projects.""" # `documentation` word is present both in `kuma` and `docs` files # and not in `pipeline`, so search with this phrase but filter through project From 029857c9a8fab61fcc58d2696b83462d7e4c8790 Mon Sep 17 00:00:00 2001 From: Santos Gallegos Date: Thu, 21 May 2020 16:14:52 -0500 Subject: [PATCH 2/3] Linter --- readthedocs/search/api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/readthedocs/search/api.py b/readthedocs/search/api.py index 8cc4e03d3a8..96c916743bc 100644 --- a/readthedocs/search/api.py +++ b/readthedocs/search/api.py @@ -25,6 +25,7 @@ class PaginatorPage: + """ Mimics the result from a paginator. @@ -51,6 +52,7 @@ def previous_page_number(self): class SearchPagination(PageNumberPagination): + """Paginator for the results of PageSearch.""" page_size = 50 @@ -152,6 +154,7 @@ def get_inner_hits(self, obj): class PageSearchAPIView(GenericAPIView): + """ Main entry point to perform a search using Elasticsearch. From cc7cd52975d05e522004e50c9eb0de7fee6e3417 Mon Sep 17 00:00:00 2001 From: Santos Gallegos Date: Wed, 27 May 2020 18:59:14 -0500 Subject: [PATCH 3/3] Include more info in docstring --- readthedocs/search/api.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/readthedocs/search/api.py b/readthedocs/search/api.py index 96c916743bc..6325d1411fc 100644 --- a/readthedocs/search/api.py +++ b/readthedocs/search/api.py @@ -60,7 +60,14 @@ class SearchPagination(PageNumberPagination): max_page_size = 100 def paginate_queryset(self, queryset, request, view=None): - """Override to get the paginated result from the ES queryset.""" + """ + Override to get the paginated result from the ES queryset. + + This makes use of our custom paginator and slicing support from the ES DSL object, + instead of the one used by django's ORM. + + Mostly inspired by https://github.com/encode/django-rest-framework/blob/acbd9d8222e763c7f9c7dc2de23c430c702e06d4/rest_framework/pagination.py#L191 # noqa + """ # Needed for other methods of this class. self.request = request