Skip to content

Commit e4973ea

Browse files
authored
Merge pull request #7114 from readthedocs/refactor-search-apiview
Search: refactor API to not emulate a Django queryset
2 parents 4c6fe4b + cc7cd52 commit e4973ea

File tree

3 files changed

+180
-85
lines changed

3 files changed

+180
-85
lines changed

readthedocs/search/api.py

Lines changed: 176 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,113 @@
11
import itertools
22
import logging
33
import re
4+
from functools import namedtuple
5+
from math import ceil
46

57
from django.shortcuts import get_object_or_404
68
from django.utils import timezone
7-
from rest_framework import generics, serializers
8-
from rest_framework.exceptions import ValidationError
9+
from django.utils.translation import ugettext as _
10+
from rest_framework import serializers
11+
from rest_framework.exceptions import NotFound, ValidationError
12+
from rest_framework.generics import GenericAPIView
913
from rest_framework.pagination import PageNumberPagination
14+
from rest_framework.response import Response
15+
from rest_framework.utils.urls import remove_query_param, replace_query_param
1016

1117
from readthedocs.api.v2.permissions import IsAuthorizedToViewVersion
1218
from readthedocs.builds.models import Version
1319
from readthedocs.projects.constants import MKDOCS, SPHINX_HTMLDIR
14-
from readthedocs.projects.models import HTMLFile, Project
20+
from readthedocs.projects.models import Project
1521
from readthedocs.search import tasks, utils
1622
from readthedocs.search.faceted_search import PageSearch
1723

1824
log = logging.getLogger(__name__)
1925

2026

27+
class PaginatorPage:

    """
    Mimics the result from a paginator.

    By using this class, we avoid having to override a lot of methods
    of `PageNumberPagination` to make it work with the ES DSL object.

    Page numbers are 1-based: the first page is ``1`` and the last one
    is ``paginator.num_pages``.

    :param page_number: number of the current page.
    :param total_pages: total number of pages available.
    :param count: total number of results across all pages.
    """

    # Built once for the class instead of creating a new type per instance.
    _Paginator = namedtuple('Paginator', ['num_pages', 'count'])

    def __init__(self, page_number, total_pages, count):
        self.number = page_number
        self.paginator = self._Paginator(total_pages, count)

    def has_next(self):
        """Return ``True`` if there is a page after the current one."""
        return self.number < self.paginator.num_pages

    def has_previous(self):
        """Return ``True`` if there is a page before the current one."""
        # The first page is number 1, so it has no previous page;
        # comparing against 0 would advertise a bogus ``page=0`` link.
        return self.number > 1

    def next_page_number(self):
        """Return the number of the page after the current one."""
        return self.number + 1

    def previous_page_number(self):
        """Return the number of the page before the current one."""
        return self.number - 1
52+
53+
2154
class SearchPagination(PageNumberPagination):

    """Paginator for the results of PageSearch."""

    page_size = 50
    page_size_query_param = 'page_size'
    max_page_size = 100

    def paginate_queryset(self, queryset, request, view=None):
        """
        Override to get the paginated result from the ES queryset.

        This makes use of our custom paginator and slicing support from the ES DSL object,
        instead of the one used by django's ORM.

        Mostly inspired by https://github.com/encode/django-rest-framework/blob/acbd9d8222e763c7f9c7dc2de23c430c702e06d4/rest_framework/pagination.py#L191  # noqa

        :param queryset: object to paginate; when truthy it must provide
            ``total_count()``, and it is always sliced with ``[start:end]``.
        :param request: current request, the page number and page size
            are read from its query params.
        :param view: unused, kept for compatibility with the parent signature.
        :returns: list with the results of the requested page.
        :raises NotFound: if the requested page is not a positive integer.
        """
        # Needed for other methods of this class.
        self.request = request

        page_size = self.get_page_size(request)

        total_count = 0
        total_pages = 1
        if queryset:
            total_count = queryset.total_count()
            # Always report at least one (possibly empty) page.
            hits = max(1, total_count)
            total_pages = ceil(hits / page_size)

        page_number = request.query_params.get(self.page_query_param, 1)
        if page_number in self.last_page_strings:
            page_number = total_pages

        # Query params are strings, they need to be converted before
        # being compared or used to compute the slice.
        try:
            page_number = int(page_number)
        except (TypeError, ValueError):
            is_valid_page = False
        else:
            is_valid_page = page_number > 0

        if not is_valid_page:
            msg = self.invalid_page_message.format(
                page_number=page_number,
                message=_("Invalid page"),
            )
            raise NotFound(msg)

        if total_pages > 1 and self.template is not None:
            # The browsable API should display pagination controls.
            self.display_page_controls = True

        start = (page_number - 1) * page_size
        end = page_number * page_size
        result = list(queryset[start:end])

        # Needed for other methods of this class.
        self.page = PaginatorPage(
            page_number=page_number,
            total_pages=total_pages,
            count=total_count,
        )

        return result
110+
26111

27112
class PageSearchSerializer(serializers.Serializer):
28113
project = serializers.CharField()
@@ -75,12 +160,13 @@ def get_inner_hits(self, obj):
75160
return sorted_results
76161

77162

78-
class PageSearchAPIView(generics.ListAPIView):
163+
class PageSearchAPIView(GenericAPIView):
79164

80165
"""
81166
Main entry point to perform a search using Elasticsearch.
82167
83168
Required query params:
169+
84170
- q (search term)
85171
- project
86172
- version
@@ -91,6 +177,7 @@ class PageSearchAPIView(generics.ListAPIView):
91177
are called many times, so a basic cache is implemented.
92178
"""
93179

180+
http_method_names = ['get']
94181
permission_classes = [IsAuthorizedToViewVersion]
95182
pagination_class = SearchPagination
96183
serializer_class = PageSearchSerializer
@@ -121,39 +208,7 @@ def _get_version(self):
121208

122209
return version
123210

124-
def get_queryset(self):
125-
"""
126-
Return Elasticsearch DSL Search object instead of Django Queryset.
127-
128-
Django Queryset and elasticsearch-dsl ``Search`` object is similar pattern.
129-
So for searching, its possible to return ``Search`` object instead of queryset.
130-
The ``filter_backends`` and ``pagination_class`` is compatible with ``Search``
131-
"""
132-
# Validate all the required params are there
133-
self.validate_query_params()
134-
query = self.request.query_params.get('q', '')
135-
filters = {}
136-
filters['project'] = [p.slug for p in self.get_all_projects()]
137-
filters['version'] = self._get_version().slug
138-
139-
# Check to avoid searching all projects in case these filters are empty.
140-
if not filters['project']:
141-
log.info("Unable to find a project to search")
142-
return HTMLFile.objects.none()
143-
if not filters['version']:
144-
log.info("Unable to find a version to search")
145-
return HTMLFile.objects.none()
146-
147-
queryset = PageSearch(
148-
query=query,
149-
filters=filters,
150-
user=self.request.user,
151-
# We use a permission class to control authorization
152-
filter_by_user=False,
153-
)
154-
return queryset
155-
156-
def validate_query_params(self):
211+
def _validate_query_params(self):
157212
"""
158213
Validate all required query params are passed on the request.
159214
@@ -163,47 +218,16 @@ def validate_query_params(self):
163218
164219
:raises: ValidationError if one of them is missing.
165220
"""
166-
required_query_params = {'q', 'project', 'version'} # python `set` literal is `{}`
221+
errors = {}
222+
required_query_params = {'q', 'project', 'version'}
167223
request_params = set(self.request.query_params.keys())
168224
missing_params = required_query_params - request_params
169-
if missing_params:
170-
errors = {}
171-
for param in missing_params:
172-
errors[param] = ["This query param is required"]
173-
225+
for param in missing_params:
226+
errors[param] = [_("This query param is required")]
227+
if errors:
174228
raise ValidationError(errors)
175229

176-
def get_serializer_context(self):
177-
context = super().get_serializer_context()
178-
context['projects_data'] = self.get_all_projects_data()
179-
return context
180-
181-
def get_all_projects(self):
182-
"""
183-
Return a list of the project itself and all its subprojects the user has permissions over.
184-
185-
:rtype: list
186-
"""
187-
main_version = self._get_version()
188-
main_project = self._get_project()
189-
190-
all_projects = [main_project]
191-
192-
subprojects = Project.objects.filter(
193-
superprojects__parent_id=main_project.id,
194-
)
195-
for project in subprojects:
196-
version = (
197-
Version.internal
198-
.public(user=self.request.user, project=project, include_hidden=False)
199-
.filter(slug=main_version.slug)
200-
.first()
201-
)
202-
if version:
203-
all_projects.append(version.project)
204-
return all_projects
205-
206-
def get_all_projects_data(self):
230+
def _get_all_projects_data(self):
207231
"""
208232
Return a dict containing the project slug and its version URL and version's doctype.
209233
@@ -224,7 +248,7 @@ def get_all_projects_data(self):
224248
225249
:rtype: dict
226250
"""
227-
all_projects = self.get_all_projects()
251+
all_projects = self._get_all_projects()
228252
version_slug = self._get_version().slug
229253
project_urls = {}
230254
for project in all_projects:
@@ -242,20 +266,41 @@ def get_all_projects_data(self):
242266
}
243267
return projects_data
244268

245-
def list(self, request, *args, **kwargs):
246-
"""Overriding ``list`` method to record query in database."""
269+
def _get_all_projects(self):
270+
"""
271+
Returns a list of the project itself and all its subprojects the user has permissions over.
272+
273+
:rtype: list
274+
"""
275+
main_version = self._get_version()
276+
main_project = self._get_project()
247277

248-
response = super().list(request, *args, **kwargs)
278+
all_projects = [main_project]
279+
280+
subprojects = Project.objects.filter(
281+
superprojects__parent_id=main_project.id,
282+
)
283+
for project in subprojects:
284+
version = (
285+
Version.internal
286+
.public(user=self.request.user, project=project, include_hidden=False)
287+
.filter(slug=main_version.slug)
288+
.first()
289+
)
290+
if version:
291+
all_projects.append(version.project)
292+
return all_projects
249293

294+
def _record_query(self, response):
250295
project_slug = self._get_project().slug
251296
version_slug = self._get_version().slug
252297
total_results = response.data.get('count', 0)
253298
time = timezone.now()
254299

255-
query = self.request.query_params.get('q', '')
300+
query = self.request.query_params['q']
256301
query = query.lower().strip()
257302

258-
# record the search query with a celery task
303+
# Record the query with a celery task
259304
tasks.record_search_query.delay(
260305
project_slug,
261306
version_slug,
@@ -264,4 +309,54 @@ def list(self, request, *args, **kwargs):
264309
time.isoformat(),
265310
)
266311

267-
return response
312+
def get_queryset(self):
    """
    Build the Elasticsearch DSL search object for the current request.

    An empty list is returned instead when there is no project or
    no version to filter by, so we never search over all projects.

    .. note::

       Calling ``list(search)`` over a DSL search object is the same as
       calling ``search.execute().hits``. This is why a DSL search object
       is compatible with DRF's paginator.
    """
    project_slugs = [project.slug for project in self._get_all_projects()]
    version_slug = self._get_version().slug

    # Check to avoid searching all projects in case these filters are empty.
    if not project_slugs:
        log.info('Unable to find a project to search')
        return []
    if not version_slug:
        log.info('Unable to find a version to search')
        return []

    return PageSearch(
        query=self.request.query_params['q'],
        filters={'project': project_slugs, 'version': version_slug},
        user=self.request.user,
        # We use a permission class to control authorization
        filter_by_user=False,
    )
343+
344+
def get_serializer_context(self):
    """Extend the default serializer context with the ``projects_data`` entry."""
    serializer_context = super().get_serializer_context()
    projects_data = self._get_all_projects_data()
    serializer_context['projects_data'] = projects_data
    return serializer_context
348+
349+
def get(self, request, *args, **kwargs):
    """Validate the query params, list the results and record the query."""
    self._validate_query_params()
    response = self.list()
    # The query is recorded only after the results were listed,
    # since the response is passed along to the recorder.
    self._record_query(response)
    return response
354+
355+
def list(self):
    """Return the paginated response with the serialized results of the current page."""
    search = self.get_queryset()
    paginator = self.paginator
    current_page = paginator.paginate_queryset(
        search,
        self.request,
        view=self,
    )
    data = self.get_serializer(current_page, many=True).data
    return paginator.get_paginated_response(data)

readthedocs/search/faceted_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ class PageSearchBase(RTDFacetedSearch):
116116
# the score of and should be higher as it satisfies both or and and
117117
operators = ['and', 'or']
118118

119-
def count(self):
120-
"""Overriding ``count`` method to return the count of the results after post_filter."""
119+
def total_count(self):
120+
"""Returns the total count of results of the current query."""
121121
s = self.build_search()
122122

123123
# setting size=0 so that no results are returned,

readthedocs/search/tests/test_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,9 @@ def test_doc_search_unexisting_version(self, api_client, project):
273273
resp = self.get_search(api_client, search_params)
274274
assert resp.status_code == 404
275275

276-
@mock.patch.object(PageSearchAPIView, 'get_all_projects', list)
276+
@mock.patch.object(PageSearchAPIView, '_get_all_projects', list)
277277
def test_get_all_projects_returns_empty_results(self, api_client, project):
278-
"""If there is a case where `get_all_projects` returns empty, we could be querying all projects."""
278+
"""If there is a case where `_get_all_projects` returns empty, we could be querying all projects."""
279279

280280
# `documentation` word is present both in `kuma` and `docs` files
281281
# and not in `pipeline`, so search with this phrase but filter through project

0 commit comments

Comments
 (0)