1
1
import itertools
2
2
import logging
3
3
import re
4
+ from functools import namedtuple
5
+ from math import ceil
4
6
5
7
from django .shortcuts import get_object_or_404
6
8
from django .utils import timezone
7
- from rest_framework import generics , serializers
8
- from rest_framework .exceptions import ValidationError
9
+ from django .utils .translation import ugettext as _
10
+ from rest_framework import serializers
11
+ from rest_framework .exceptions import NotFound , ValidationError
12
+ from rest_framework .generics import GenericAPIView
9
13
from rest_framework .pagination import PageNumberPagination
14
+ from rest_framework .response import Response
15
+ from rest_framework .utils .urls import remove_query_param , replace_query_param
10
16
11
17
from readthedocs .api .v2 .permissions import IsAuthorizedToViewVersion
12
18
from readthedocs .builds .models import Version
13
19
from readthedocs .projects .constants import MKDOCS , SPHINX_HTMLDIR
14
- from readthedocs .projects .models import HTMLFile , Project
20
+ from readthedocs .projects .models import Project
15
21
from readthedocs .search import tasks , utils
16
22
from readthedocs .search .faceted_search import PageSearch
17
23
18
24
log = logging .getLogger (__name__ )
19
25
20
26
27
class PaginatorPage:

    """
    Mimics the result from a paginator.

    By using this class, we avoid having to override a lot of methods
    of `PageNumberPagination` to make it work with the ES DSL object.

    Page numbers are 1-based: the first page is number 1.
    """

    # Minimal stand-in for a paginator object: only the attributes
    # read through ``page.paginator`` are provided.
    _Paginator = namedtuple('Paginator', ['num_pages', 'count'])

    def __init__(self, page_number, total_pages, count):
        """
        :param page_number: 1-based number of the current page.
        :param total_pages: total number of pages available.
        :param count: total number of results across all pages.
        """
        self.number = page_number
        self.paginator = self._Paginator(total_pages, count)

    def has_next(self):
        """Return True when there is a page after the current one."""
        return self.number < self.paginator.num_pages

    def has_previous(self):
        """Return True when there is a page before the current one."""
        # Pages start at 1, so the first page has no previous page.
        # Comparing against 0 would advertise a link to the invalid page 0.
        return self.number > 1

    def next_page_number(self):
        """Return the number of the page after the current one."""
        return self.number + 1

    def previous_page_number(self):
        """Return the number of the page before the current one."""
        return self.number - 1
52
+
53
+
21
54
class SearchPagination(PageNumberPagination):

    """Paginator for the results of PageSearch."""

    page_size = 50
    page_size_query_param = 'page_size'
    max_page_size = 100

    def paginate_queryset(self, queryset, request, view=None):
        """
        Override to get the paginated result from the ES queryset.

        This makes use of our custom paginator and slicing support from the ES DSL object,
        instead of the one used by django's ORM.

        Mostly inspired by https://github.com/encode/django-rest-framework/blob/acbd9d8222e763c7f9c7dc2de23c430c702e06d4/rest_framework/pagination.py#L191  # noqa

        :param queryset: object supporting slicing and, when truthy, ``total_count()``.
        :param request: current request; the page number is read from its query params.
        :param view: unused, kept for compatibility with the base class signature.
        :returns: list with the results of the requested page.
        :raises: NotFound if the page number is not a positive integer
                 (or one of ``last_page_strings``).
        """
        # Needed for other methods of this class.
        self.request = request

        page_size = self.get_page_size(request)

        total_count = 0
        total_pages = 1
        if queryset:
            total_count = queryset.total_count()
            # Always report at least one (possibly empty) page.
            hits = max(1, total_count)
            total_pages = ceil(hits / page_size)

        requested_page = request.query_params.get(self.page_query_param, 1)
        if requested_page in self.last_page_strings:
            page_number = total_pages
        else:
            # Query params arrive as strings; convert before doing any
            # comparison or arithmetic with the page number.
            try:
                page_number = int(requested_page)
            except (TypeError, ValueError):
                page_number = None

        if page_number is None or page_number <= 0:
            msg = self.invalid_page_message.format(
                page_number=requested_page,
                message=_("Invalid page"),
            )
            raise NotFound(msg)

        if total_pages > 1 and self.template is not None:
            # The browsable API should display pagination controls.
            self.display_page_controls = True

        start = (page_number - 1) * page_size
        end = page_number * page_size
        result = list(queryset[start:end])

        # Needed for other methods of this class.
        self.page = PaginatorPage(
            page_number=page_number,
            total_pages=total_pages,
            count=total_count,
        )

        return result
110
+
26
111
27
112
class PageSearchSerializer (serializers .Serializer ):
28
113
project = serializers .CharField ()
@@ -75,12 +160,13 @@ def get_inner_hits(self, obj):
75
160
return sorted_results
76
161
77
162
78
- class PageSearchAPIView (generics . ListAPIView ):
163
+ class PageSearchAPIView (GenericAPIView ):
79
164
80
165
"""
81
166
Main entry point to perform a search using Elasticsearch.
82
167
83
168
Required query params:
169
+
84
170
- q (search term)
85
171
- project
86
172
- version
@@ -91,6 +177,7 @@ class PageSearchAPIView(generics.ListAPIView):
91
177
are called many times, so a basic cache is implemented.
92
178
"""
93
179
180
+ http_method_names = ['get' ]
94
181
permission_classes = [IsAuthorizedToViewVersion ]
95
182
pagination_class = SearchPagination
96
183
serializer_class = PageSearchSerializer
@@ -121,39 +208,7 @@ def _get_version(self):
121
208
122
209
return version
123
210
124
- def get_queryset (self ):
125
- """
126
- Return Elasticsearch DSL Search object instead of Django Queryset.
127
-
128
- Django Queryset and elasticsearch-dsl ``Search`` object is similar pattern.
129
- So for searching, its possible to return ``Search`` object instead of queryset.
130
- The ``filter_backends`` and ``pagination_class`` is compatible with ``Search``
131
- """
132
- # Validate all the required params are there
133
- self .validate_query_params ()
134
- query = self .request .query_params .get ('q' , '' )
135
- filters = {}
136
- filters ['project' ] = [p .slug for p in self .get_all_projects ()]
137
- filters ['version' ] = self ._get_version ().slug
138
-
139
- # Check to avoid searching all projects in case these filters are empty.
140
- if not filters ['project' ]:
141
- log .info ("Unable to find a project to search" )
142
- return HTMLFile .objects .none ()
143
- if not filters ['version' ]:
144
- log .info ("Unable to find a version to search" )
145
- return HTMLFile .objects .none ()
146
-
147
- queryset = PageSearch (
148
- query = query ,
149
- filters = filters ,
150
- user = self .request .user ,
151
- # We use a permission class to control authorization
152
- filter_by_user = False ,
153
- )
154
- return queryset
155
-
156
- def validate_query_params (self ):
211
+ def _validate_query_params (self ):
157
212
"""
158
213
Validate all required query params are passed on the request.
159
214
@@ -163,47 +218,16 @@ def validate_query_params(self):
163
218
164
219
:raises: ValidationError if one of them is missing.
165
220
"""
166
- required_query_params = {'q' , 'project' , 'version' } # python `set` literal is `{}`
221
+ errors = {}
222
+ required_query_params = {'q' , 'project' , 'version' }
167
223
request_params = set (self .request .query_params .keys ())
168
224
missing_params = required_query_params - request_params
169
- if missing_params :
170
- errors = {}
171
- for param in missing_params :
172
- errors [param ] = ["This query param is required" ]
173
-
225
+ for param in missing_params :
226
+ errors [param ] = [_ ("This query param is required" )]
227
+ if errors :
174
228
raise ValidationError (errors )
175
229
176
- def get_serializer_context (self ):
177
- context = super ().get_serializer_context ()
178
- context ['projects_data' ] = self .get_all_projects_data ()
179
- return context
180
-
181
- def get_all_projects (self ):
182
- """
183
- Return a list of the project itself and all its subprojects the user has permissions over.
184
-
185
- :rtype: list
186
- """
187
- main_version = self ._get_version ()
188
- main_project = self ._get_project ()
189
-
190
- all_projects = [main_project ]
191
-
192
- subprojects = Project .objects .filter (
193
- superprojects__parent_id = main_project .id ,
194
- )
195
- for project in subprojects :
196
- version = (
197
- Version .internal
198
- .public (user = self .request .user , project = project , include_hidden = False )
199
- .filter (slug = main_version .slug )
200
- .first ()
201
- )
202
- if version :
203
- all_projects .append (version .project )
204
- return all_projects
205
-
206
- def get_all_projects_data (self ):
230
+ def _get_all_projects_data (self ):
207
231
"""
208
232
Return a dict containing the project slug and its version URL and version's doctype.
209
233
@@ -224,7 +248,7 @@ def get_all_projects_data(self):
224
248
225
249
:rtype: dict
226
250
"""
227
- all_projects = self .get_all_projects ()
251
+ all_projects = self ._get_all_projects ()
228
252
version_slug = self ._get_version ().slug
229
253
project_urls = {}
230
254
for project in all_projects :
@@ -242,20 +266,41 @@ def get_all_projects_data(self):
242
266
}
243
267
return projects_data
244
268
245
- def list (self , request , * args , ** kwargs ):
246
- """Overriding ``list`` method to record query in database."""
269
def _get_all_projects(self):
    """
    Build the list of projects to search: the main project plus its subprojects.

    A subproject is included only when ``Version.internal.public`` (for the
    requesting user, hidden versions excluded) yields a version with the
    same slug as the main version.

    :rtype: list
    """
    main_version = self._get_version()
    main_project = self._get_project()

    children = Project.objects.filter(superprojects__parent_id=main_project.id)

    # The main project always comes first.
    found = [main_project]
    for child in children:
        visible_versions = Version.internal.public(
            user=self.request.user,
            project=child,
            include_hidden=False,
        )
        match = visible_versions.filter(slug=main_version.slug).first()
        if match:
            found.append(match.project)
    return found
249
293
294
+ def _record_query (self , response ):
250
295
project_slug = self ._get_project ().slug
251
296
version_slug = self ._get_version ().slug
252
297
total_results = response .data .get ('count' , 0 )
253
298
time = timezone .now ()
254
299
255
- query = self .request .query_params . get ( 'q' , '' )
300
+ query = self .request .query_params [ 'q' ]
256
301
query = query .lower ().strip ()
257
302
258
- # record the search query with a celery task
303
+ # Record the query with a celery task
259
304
tasks .record_search_query .delay (
260
305
project_slug ,
261
306
version_slug ,
@@ -264,4 +309,54 @@ def list(self, request, *args, **kwargs):
264
309
time .isoformat (),
265
310
)
266
311
267
- return response
312
def get_queryset(self):
    """
    Build the search object used as the "queryset" of this view.

    Returns a ``PageSearch`` instance filtered by the slugs of all the
    searchable projects and by the version slug, or an empty list when
    either of those filters would be empty.

    .. note::

       Calling ``list(search)`` over an DSL search object is the same as
       calling ``search.execute().hits``. This is why an DSL search object
       is compatible with DRF's paginator.
    """
    project_slugs = [project.slug for project in self._get_all_projects()]
    version_slug = self._get_version().slug

    # An empty filter would match everything, so bail out early
    # instead of searching across all projects.
    if not project_slugs:
        log.info('Unable to find a project to search')
        return []
    if not version_slug:
        log.info('Unable to find a version to search')
        return []

    return PageSearch(
        query=self.request.query_params['q'],
        filters={'project': project_slugs, 'version': version_slug},
        user=self.request.user,
        # We use a permission class to control authorization
        filter_by_user=False,
    )
343
+
344
def get_serializer_context(self):
    """Return the parent serializer context extended with ``projects_data``."""
    serializer_context = super().get_serializer_context()
    serializer_context.update(projects_data=self._get_all_projects_data())
    return serializer_context
348
+
349
def get(self, request, *args, **kwargs):
    """
    Handle GET requests.

    Validates the query params, builds the paginated response,
    records the query, and returns that response.
    """
    # Validation runs first so a bad request is rejected before searching.
    self._validate_query_params()
    response = self.list()
    self._record_query(response)
    return response
354
+
355
def list(self):
    """
    List the results using pagination.

    Paginates the object from ``get_queryset``, serializes the resulting
    page, and wraps the data in the paginator's response.
    """
    search = self.get_queryset()
    paginator = self.paginator
    current_page = paginator.paginate_queryset(search, self.request, view=self)
    data = self.get_serializer(current_page, many=True).data
    return paginator.get_paginated_response(data)
0 commit comments