Skip to content

Search: refactor API to not emulate a Django queryset #7114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 176 additions & 81 deletions readthedocs/search/api.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,113 @@
import itertools
import logging
import re
from functools import namedtuple
from math import ceil

from django.shortcuts import get_object_or_404
from django.utils import timezone
from rest_framework import generics, serializers
from rest_framework.exceptions import ValidationError
from django.utils.translation import ugettext as _
from rest_framework import serializers
from rest_framework.exceptions import NotFound, ValidationError
from rest_framework.generics import GenericAPIView
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
from rest_framework.utils.urls import remove_query_param, replace_query_param

from readthedocs.api.v2.permissions import IsAuthorizedToViewVersion
from readthedocs.builds.models import Version
from readthedocs.projects.constants import MKDOCS, SPHINX_HTMLDIR
from readthedocs.projects.models import HTMLFile, Project
from readthedocs.projects.models import Project
from readthedocs.search import tasks, utils
from readthedocs.search.faceted_search import PageSearch

log = logging.getLogger(__name__)


class PaginatorPage:

"""
Mimics the result from a paginator.

By using this class, we avoid having to override a lot of methods
of `PageNumberPagination` to make it work with the ES DSL object.
"""

def __init__(self, page_number, total_pages, count):
self.number = page_number
Paginator = namedtuple('Paginator', ['num_pages', 'count'])
self.paginator = Paginator(total_pages, count)

def has_next(self):
return self.number < self.paginator.num_pages

def has_previous(self):
return self.number > 0

def next_page_number(self):
return self.number + 1

def previous_page_number(self):
return self.number - 1


class SearchPagination(PageNumberPagination):

"""Paginator for the results of PageSearch."""

page_size = 50
page_size_query_param = 'page_size'
max_page_size = 100

def paginate_queryset(self, queryset, request, view=None):
"""
Override to get the paginated result from the ES queryset.

This makes use of our custom paginator and slicing support from the ES DSL object,
instead of the one used by django's ORM.

Mostly inspired by https://github.com/encode/django-rest-framework/blob/acbd9d8222e763c7f9c7dc2de23c430c702e06d4/rest_framework/pagination.py#L191 # noqa
"""
# Needed for other methods of this class.
self.request = request

page_size = self.get_page_size(request)

total_count = 0
total_pages = 1
if queryset:
total_count = queryset.total_count()
hits = max(1, total_count)
total_pages = ceil(hits / page_size)

page_number = request.query_params.get(self.page_query_param, 1)
if page_number in self.last_page_strings:
page_number = total_pages

if page_number <= 0:
msg = self.invalid_page_message.format(
page_number=page_number,
message=_("Invalid page"),
)
raise NotFound(msg)

if total_pages > 1 and self.template is not None:
# The browsable API should display pagination controls.
self.display_page_controls = True

start = (page_number - 1) * page_size
end = page_number * page_size
result = list(queryset[start:end])

# Needed for other methods of this class.
self.page = PaginatorPage(
page_number=page_number,
total_pages=total_pages,
count=total_count,
)

return result


class PageSearchSerializer(serializers.Serializer):
project = serializers.CharField()
Expand Down Expand Up @@ -75,12 +160,13 @@ def get_inner_hits(self, obj):
return sorted_results


class PageSearchAPIView(generics.ListAPIView):
class PageSearchAPIView(GenericAPIView):

"""
Main entry point to perform a search using Elasticsearch.

Required query params:

- q (search term)
- project
- version
Expand All @@ -91,6 +177,7 @@ class PageSearchAPIView(generics.ListAPIView):
are called many times, so a basic cache is implemented.
"""

http_method_names = ['get']
permission_classes = [IsAuthorizedToViewVersion]
pagination_class = SearchPagination
serializer_class = PageSearchSerializer
Expand Down Expand Up @@ -121,39 +208,7 @@ def _get_version(self):

return version

def get_queryset(self):
"""
Return Elasticsearch DSL Search object instead of Django Queryset.

Django Queryset and elasticsearch-dsl ``Search`` object is similar pattern.
So for searching, its possible to return ``Search`` object instead of queryset.
The ``filter_backends`` and ``pagination_class`` is compatible with ``Search``
"""
# Validate all the required params are there
self.validate_query_params()
query = self.request.query_params.get('q', '')
filters = {}
filters['project'] = [p.slug for p in self.get_all_projects()]
filters['version'] = self._get_version().slug

# Check to avoid searching all projects in case these filters are empty.
if not filters['project']:
log.info("Unable to find a project to search")
return HTMLFile.objects.none()
if not filters['version']:
log.info("Unable to find a version to search")
return HTMLFile.objects.none()

queryset = PageSearch(
query=query,
filters=filters,
user=self.request.user,
# We use a permission class to control authorization
filter_by_user=False,
)
return queryset

def validate_query_params(self):
def _validate_query_params(self):
"""
Validate all required query params are passed on the request.

Expand All @@ -163,47 +218,16 @@ def validate_query_params(self):

:raises: ValidationError if one of them is missing.
"""
required_query_params = {'q', 'project', 'version'} # python `set` literal is `{}`
errors = {}
required_query_params = {'q', 'project', 'version'}
request_params = set(self.request.query_params.keys())
missing_params = required_query_params - request_params
if missing_params:
errors = {}
for param in missing_params:
errors[param] = ["This query param is required"]

for param in missing_params:
errors[param] = [_("This query param is required")]
if errors:
raise ValidationError(errors)

def get_serializer_context(self):
context = super().get_serializer_context()
context['projects_data'] = self.get_all_projects_data()
return context

def get_all_projects(self):
"""
Return a list of the project itself and all its subprojects the user has permissions over.

:rtype: list
"""
main_version = self._get_version()
main_project = self._get_project()

all_projects = [main_project]

subprojects = Project.objects.filter(
superprojects__parent_id=main_project.id,
)
for project in subprojects:
version = (
Version.internal
.public(user=self.request.user, project=project, include_hidden=False)
.filter(slug=main_version.slug)
.first()
)
if version:
all_projects.append(version.project)
return all_projects

def get_all_projects_data(self):
def _get_all_projects_data(self):
"""
Return a dict containing the project slug and its version URL and version's doctype.

Expand All @@ -224,7 +248,7 @@ def get_all_projects_data(self):

:rtype: dict
"""
all_projects = self.get_all_projects()
all_projects = self._get_all_projects()
version_slug = self._get_version().slug
project_urls = {}
for project in all_projects:
Expand All @@ -242,20 +266,41 @@ def get_all_projects_data(self):
}
return projects_data

def list(self, request, *args, **kwargs):
"""Overriding ``list`` method to record query in database."""
def _get_all_projects(self):
"""
Returns a list of the project itself and all its subprojects the user has permissions over.

:rtype: list
"""
main_version = self._get_version()
main_project = self._get_project()

response = super().list(request, *args, **kwargs)
all_projects = [main_project]

subprojects = Project.objects.filter(
superprojects__parent_id=main_project.id,
)
for project in subprojects:
version = (
Version.internal
.public(user=self.request.user, project=project, include_hidden=False)
.filter(slug=main_version.slug)
.first()
)
if version:
all_projects.append(version.project)
return all_projects

def _record_query(self, response):
project_slug = self._get_project().slug
version_slug = self._get_version().slug
total_results = response.data.get('count', 0)
time = timezone.now()

query = self.request.query_params.get('q', '')
query = self.request.query_params['q']
query = query.lower().strip()

# record the search query with a celery task
# Record the query with a celery task
tasks.record_search_query.delay(
project_slug,
version_slug,
Expand All @@ -264,4 +309,54 @@ def list(self, request, *args, **kwargs):
time.isoformat(),
)

return response
def get_queryset(self):
"""
Returns an Elasticsearch DSL search object or an iterator.

.. note::

Calling ``list(search)`` over an DSL search object is the same as
calling ``search.execute().hits``. This is why an DSL search object
is compatible with DRF's paginator.
"""
filters = {}
filters['project'] = [p.slug for p in self._get_all_projects()]
filters['version'] = self._get_version().slug

# Check to avoid searching all projects in case these filters are empty.
if not filters['project']:
log.info('Unable to find a project to search')
return []
if not filters['version']:
log.info('Unable to find a version to search')
return []

query = self.request.query_params['q']
queryset = PageSearch(
query=query,
filters=filters,
user=self.request.user,
# We use a permission class to control authorization
filter_by_user=False,
)
return queryset

def get_serializer_context(self):
context = super().get_serializer_context()
context['projects_data'] = self._get_all_projects_data()
return context

def get(self, request, *args, **kwargs):
self._validate_query_params()
result = self.list()
self._record_query(result)
return result

def list(self):
"""List the results using pagination."""
queryset = self.get_queryset()
page = self.paginator.paginate_queryset(
queryset, self.request, view=self,
)
serializer = self.get_serializer(page, many=True)
return self.paginator.get_paginated_response(serializer.data)
4 changes: 2 additions & 2 deletions readthedocs/search/faceted_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ class PageSearchBase(RTDFacetedSearch):
# the score of and should be higher as it satisfies both or and and
operators = ['and', 'or']

def count(self):
"""Overriding ``count`` method to return the count of the results after post_filter."""
def total_count(self):
"""Returns the total count of results of the current query."""
s = self.build_search()

# setting size=0 so that no results are returned,
Expand Down
4 changes: 2 additions & 2 deletions readthedocs/search/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def test_doc_search_unexisting_version(self, api_client, project):
resp = self.get_search(api_client, search_params)
assert resp.status_code == 404

@mock.patch.object(PageSearchAPIView, 'get_all_projects', list)
@mock.patch.object(PageSearchAPIView, '_get_all_projects', list)
def test_get_all_projects_returns_empty_results(self, api_client, project):
"""If there is a case where `get_all_projects` returns empty, we could be querying all projects."""
"""If there is a case where `_get_all_projects` returns empty, we could be querying all projects."""

# `documentation` word is present both in `kuma` and `docs` files
# and not in `pipeline`, so search with this phrase but filter through project
Expand Down