Skip to content

Commit edcaceb

Browse files
committed
Search: refactor API to not emulate a Django queryset
The recommended way of using pagination over a custom object is to manage the pagination class ourselves. I tried to rely on most of the defaults of DRF's pagination class. Closes #5235
1 parent 0d7901e commit edcaceb

File tree

3 files changed

+171
-86
lines changed

3 files changed

+171
-86
lines changed

readthedocs/search/api.py

Lines changed: 167 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,104 @@
11
import itertools
22
import logging
33
import re
4+
from functools import namedtuple
5+
from math import ceil
46

57
from django.shortcuts import get_object_or_404
68
from django.utils import timezone
7-
from rest_framework import generics, serializers
8-
from rest_framework.exceptions import ValidationError
9+
from django.utils.translation import ugettext as _
10+
from rest_framework import serializers
11+
from rest_framework.exceptions import NotFound, ValidationError
12+
from rest_framework.generics import GenericAPIView
913
from rest_framework.pagination import PageNumberPagination
14+
from rest_framework.response import Response
15+
from rest_framework.utils.urls import remove_query_param, replace_query_param
1016

1117
from readthedocs.api.v2.permissions import IsAuthorizedToViewVersion
1218
from readthedocs.builds.models import Version
1319
from readthedocs.projects.constants import MKDOCS, SPHINX_HTMLDIR
14-
from readthedocs.projects.models import HTMLFile, Project
20+
from readthedocs.projects.models import Project
1521
from readthedocs.search import tasks, utils
1622
from readthedocs.search.faceted_search import PageSearch
1723

1824
log = logging.getLogger(__name__)
1925

2026

27+
class PaginatorPage:
    """
    Mimics the result from a paginator.

    By using this class, we avoid having to override a lot of methods
    of `PageNumberPagination` to make it work with the ES DSL object.

    Page numbers are 1-based: the first page is page ``1``.
    """

    # Minimal stand-in for a paginator: only exposes the total number of
    # pages and the total number of results. Built once for the class
    # instead of on every instantiation.
    _Paginator = namedtuple('Paginator', ['num_pages', 'count'])

    def __init__(self, page_number, total_pages, count):
        """
        Store the current page number and the paginator totals.

        :param page_number: 1-based number of the current page.
        :param total_pages: total number of pages available.
        :param count: total number of results across all pages.
        """
        self.number = page_number
        self.paginator = self._Paginator(total_pages, count)

    def has_next(self):
        """Return ``True`` if there is a page after the current one."""
        return self.number < self.paginator.num_pages

    def has_previous(self):
        """Return ``True`` if there is a page before the current one."""
        # Pages start at 1, so the first page has no previous page.
        return self.number > 1

    def next_page_number(self):
        """Return the number of the page after the current one."""
        return self.number + 1

    def previous_page_number(self):
        """Return the number of the page before the current one."""
        return self.number - 1
51+
52+
2153
class SearchPagination(PageNumberPagination):
    """Paginator for the results of PageSearch."""

    page_size = 50
    page_size_query_param = 'page_size'
    max_page_size = 100

    def _get_page_number(self, request, total_pages):
        """
        Read the requested page number from the query params as an integer.

        :param request: the current request.
        :param total_pages: number of pages, used for the "last page" aliases.
        :returns: the 1-based page number requested.
        :raises: NotFound if the value isn't a positive integer.
        """
        raw_page = request.query_params.get(self.page_query_param, 1)
        if raw_page in self.last_page_strings:
            return total_pages

        # Query params are strings, convert before doing any arithmetic.
        try:
            page_number = int(raw_page)
        except (TypeError, ValueError):
            page_number = None

        if page_number is None or page_number <= 0:
            msg = self.invalid_page_message.format(
                page_number=raw_page,
                message=_("Invalid page"),
            )
            raise NotFound(msg)
        return page_number

    def paginate_queryset(self, queryset, request, view=None):
        """
        Override to get the paginated result from the ES queryset.

        :param queryset: object to paginate. When truthy it must provide
            ``total_count()`` and support slicing; a falsy value (like an
            empty list) results in a single empty page.
        :param request: the current request.
        :param view: unused, kept for compatibility with the parent signature.
        :returns: list with the results of the requested page.
        :raises: NotFound if the requested page number is invalid.
        """
        # Needed for other methods of this class.
        self.request = request

        page_size = self.get_page_size(request)

        total_count = 0
        total_pages = 1
        if queryset:
            total_count = queryset.total_count()
            # Always report at least one page, even without results.
            hits = max(1, total_count)
            total_pages = ceil(hits / page_size)

        page_number = self._get_page_number(request, total_pages)

        if total_pages > 1 and self.template is not None:
            # The browsable API should display pagination controls.
            self.display_page_controls = True

        start = (page_number - 1) * page_size
        end = page_number * page_size
        result = list(queryset[start:end])

        # Needed for other methods of this class.
        self.page = PaginatorPage(
            page_number=page_number,
            total_pages=total_pages,
            count=total_count,
        )

        return result
101+
26102

27103
class PageSearchSerializer(serializers.Serializer):
28104
project = serializers.CharField()
@@ -75,12 +151,12 @@ def get_inner_hits(self, obj):
75151
return sorted_results
76152

77153

78-
class PageSearchAPIView(generics.ListAPIView):
79-
154+
class PageSearchAPIView(GenericAPIView):
80155
"""
81156
Main entry point to perform a search using Elasticsearch.
82157
83158
Required query params:
159+
84160
- q (search term)
85161
- project
86162
- version
@@ -91,6 +167,7 @@ class PageSearchAPIView(generics.ListAPIView):
91167
are called many times, so a basic cache is implemented.
92168
"""
93169

170+
http_method_names = ['get']
94171
permission_classes = [IsAuthorizedToViewVersion]
95172
pagination_class = SearchPagination
96173
serializer_class = PageSearchSerializer
@@ -121,39 +198,7 @@ def _get_version(self):
121198

122199
return version
123200

124-
def get_queryset(self):
125-
"""
126-
Return Elasticsearch DSL Search object instead of Django Queryset.
127-
128-
Django Queryset and elasticsearch-dsl ``Search`` object is similar pattern.
129-
So for searching, its possible to return ``Search`` object instead of queryset.
130-
The ``filter_backends`` and ``pagination_class`` is compatible with ``Search``
131-
"""
132-
# Validate all the required params are there
133-
self.validate_query_params()
134-
query = self.request.query_params.get('q', '')
135-
filters = {}
136-
filters['project'] = [p.slug for p in self.get_all_projects()]
137-
filters['version'] = self._get_version().slug
138-
139-
# Check to avoid searching all projects in case these filters are empty.
140-
if not filters['project']:
141-
log.info("Unable to find a project to search")
142-
return HTMLFile.objects.none()
143-
if not filters['version']:
144-
log.info("Unable to find a version to search")
145-
return HTMLFile.objects.none()
146-
147-
queryset = PageSearch(
148-
query=query,
149-
filters=filters,
150-
user=self.request.user,
151-
# We use a permission class to control authorization
152-
filter_by_user=False,
153-
)
154-
return queryset
155-
156-
def validate_query_params(self):
201+
def _validate_query_params(self):
157202
"""
158203
Validate all required query params are passed on the request.
159204
@@ -163,47 +208,16 @@ def validate_query_params(self):
163208
164209
:raises: ValidationError if one of them is missing.
165210
"""
166-
required_query_params = {'q', 'project', 'version'} # python `set` literal is `{}`
211+
errors = {}
212+
required_query_params = {'q', 'project', 'version'}
167213
request_params = set(self.request.query_params.keys())
168214
missing_params = required_query_params - request_params
169-
if missing_params:
170-
errors = {}
171-
for param in missing_params:
172-
errors[param] = ["This query param is required"]
173-
215+
for param in missing_params:
216+
errors[param] = [_("This query param is required")]
217+
if errors:
174218
raise ValidationError(errors)
175219

176-
def get_serializer_context(self):
177-
context = super().get_serializer_context()
178-
context['projects_data'] = self.get_all_projects_data()
179-
return context
180-
181-
def get_all_projects(self):
182-
"""
183-
Return a list of the project itself and all its subprojects the user has permissions over.
184-
185-
:rtype: list
186-
"""
187-
main_version = self._get_version()
188-
main_project = self._get_project()
189-
190-
all_projects = [main_project]
191-
192-
subprojects = Project.objects.filter(
193-
superprojects__parent_id=main_project.id,
194-
)
195-
for project in subprojects:
196-
version = (
197-
Version.internal
198-
.public(user=self.request.user, project=project, include_hidden=False)
199-
.filter(slug=main_version.slug)
200-
.first()
201-
)
202-
if version:
203-
all_projects.append(version.project)
204-
return all_projects
205-
206-
def get_all_projects_data(self):
220+
def _get_all_projects_data(self):
207221
"""
208222
Return a dict containing the project slug and its version URL and version's doctype.
209223
@@ -224,7 +238,7 @@ def get_all_projects_data(self):
224238
225239
:rtype: dict
226240
"""
227-
all_projects = self.get_all_projects()
241+
all_projects = self._get_all_projects()
228242
version_slug = self._get_version().slug
229243
project_urls = {}
230244
for project in all_projects:
@@ -242,20 +256,41 @@ def get_all_projects_data(self):
242256
}
243257
return projects_data
244258

245-
def list(self, request, *args, **kwargs):
246-
"""Overriding ``list`` method to record query in database."""
259+
def _get_all_projects(self):
260+
"""
261+
Returns a list of the project itself and all its subprojects the user has permissions over.
262+
263+
:rtype: list
264+
"""
265+
main_version = self._get_version()
266+
main_project = self._get_project()
247267

248-
response = super().list(request, *args, **kwargs)
268+
all_projects = [main_project]
249269

270+
subprojects = Project.objects.filter(
271+
superprojects__parent_id=main_project.id,
272+
)
273+
for project in subprojects:
274+
version = (
275+
Version.internal
276+
.public(user=self.request.user, project=project, include_hidden=False)
277+
.filter(slug=main_version.slug)
278+
.first()
279+
)
280+
if version:
281+
all_projects.append(version.project)
282+
return all_projects
283+
284+
def _record_query(self, response):
250285
project_slug = self._get_project().slug
251286
version_slug = self._get_version().slug
252287
total_results = response.data.get('count', 0)
253288
time = timezone.now()
254289

255-
query = self.request.query_params.get('q', '')
290+
query = self.request.query_params['q']
256291
query = query.lower().strip()
257292

258-
# record the search query with a celery task
293+
# Record the query with a celery task
259294
tasks.record_search_query.delay(
260295
project_slug,
261296
version_slug,
@@ -264,4 +299,54 @@ def list(self, request, *args, **kwargs):
264299
time.isoformat(),
265300
)
266301

267-
return response
302+
def get_queryset(self):
    """
    Return the Elasticsearch DSL search to run, or an empty list.

    An empty list is returned when there isn't a project or a version
    to filter by.

    .. note::

       Calling ``list(search)`` over a DSL search object is the same as
       calling ``search.execute().hits``. This is why a DSL search object
       is compatible with DRF's paginator.
    """
    project_slugs = [project.slug for project in self._get_all_projects()]
    version_slug = self._get_version().slug

    # Guard against searching over all projects when a filter is empty.
    if not project_slugs:
        log.info('Unable to find a project to search')
        return []
    if not version_slug:
        log.info('Unable to find a version to search')
        return []

    return PageSearch(
        query=self.request.query_params['q'],
        filters={'project': project_slugs, 'version': version_slug},
        user=self.request.user,
        # We use a permission class to control authorization
        filter_by_user=False,
    )
333+
334+
def get_serializer_context(self):
    """Extend the default serializer context with the ``projects_data`` key."""
    serializer_context = super().get_serializer_context()
    serializer_context.update(projects_data=self._get_all_projects_data())
    return serializer_context
338+
339+
def get(self, request, *args, **kwargs):
    """
    Handle GET requests.

    Validates the query params, builds the paginated response,
    and records the query before returning the response.
    """
    self._validate_query_params()
    response = self.list()
    self._record_query(response)
    return response
344+
345+
def list(self):
    """Return the paginated response with the serialized search results."""
    search = self.get_queryset()
    paginator = self.paginator
    results = paginator.paginate_queryset(search, self.request, view=self)
    serialized_results = self.get_serializer(results, many=True).data
    return paginator.get_paginated_response(serialized_results)

readthedocs/search/faceted_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ class PageSearchBase(RTDFacetedSearch):
116116
# the score of and should be higher as it satisfies both or and and
117117
operators = ['and', 'or']
118118

119-
def count(self):
120-
"""Overriding ``count`` method to return the count of the results after post_filter."""
119+
def total_count(self):
120+
"""Returns the total count of results of the current query."""
121121
s = self.build_search()
122122

123123
# setting size=0 so that no results are returned,

readthedocs/search/tests/test_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,9 @@ def test_doc_search_unexisting_version(self, api_client, project):
273273
resp = self.get_search(api_client, search_params)
274274
assert resp.status_code == 404
275275

276-
@mock.patch.object(PageSearchAPIView, 'get_all_projects', list)
276+
@mock.patch.object(PageSearchAPIView, '_get_all_projects', list)
277277
def test_get_all_projects_returns_empty_results(self, api_client, project):
278-
"""If there is a case where `get_all_projects` returns empty, we could be querying all projects."""
278+
"""If there is a case where `_get_all_projects` returns empty, we could be querying all projects."""
279279

280280
# `documentation` word is present both in `kuma` and `docs` files
281281
# and not in `pipeline`, so search with this phrase but filter through project

0 commit comments

Comments
 (0)