1
1
import itertools
2
2
import logging
3
3
import re
4
+ from functools import namedtuple
5
+ from math import ceil
4
6
5
7
from django .shortcuts import get_object_or_404
6
8
from django .utils import timezone
7
- from rest_framework import generics , serializers
8
- from rest_framework .exceptions import ValidationError
9
+ from django .utils .translation import ugettext as _
10
+ from rest_framework import serializers
11
+ from rest_framework .exceptions import NotFound , ValidationError
12
+ from rest_framework .generics import GenericAPIView
9
13
from rest_framework .pagination import PageNumberPagination
14
+ from rest_framework .response import Response
15
+ from rest_framework .utils .urls import remove_query_param , replace_query_param
10
16
11
17
from readthedocs .api .v2 .permissions import IsAuthorizedToViewVersion
12
18
from readthedocs .builds .models import Version
13
19
from readthedocs .projects .constants import MKDOCS , SPHINX_HTMLDIR
14
- from readthedocs .projects .models import HTMLFile , Project
20
+ from readthedocs .projects .models import Project
15
21
from readthedocs .search import tasks , utils
16
22
from readthedocs .search .faceted_search import PageSearch
17
23
18
24
log = logging .getLogger (__name__ )
19
25
20
26
27
+ class PaginatorPage :
28
+ """
29
+ Mimics the result from a paginator.
30
+
31
+ By using this class, we avoid having to override a lot of methods
32
+ of `PageNumberPagination` to make it work with the ES DSL object.
33
+ """
34
+
35
+ def __init__ (self , page_number , total_pages , count ):
36
+ self .number = page_number
37
+ Paginator = namedtuple ('Paginator' , ['num_pages' , 'count' ])
38
+ self .paginator = Paginator (total_pages , count )
39
+
40
+ def has_next (self ):
41
+ return self .number < self .paginator .num_pages
42
+
43
+ def has_previous (self ):
44
+ return self .number > 0
45
+
46
+ def next_page_number (self ):
47
+ return self .number + 1
48
+
49
+ def previous_page_number (self ):
50
+ return self .number - 1
51
+
52
+
21
53
class SearchPagination (PageNumberPagination ):
54
+ """Paginator for the results of PageSearch."""
55
+
22
56
page_size = 50
23
57
page_size_query_param = 'page_size'
24
58
max_page_size = 100
25
59
60
+ def paginate_queryset (self , queryset , request , view = None ):
61
+ """Override to get the paginated result from the ES queryset."""
62
+ # Needed for other methods of this class.
63
+ self .request = request
64
+
65
+ page_size = self .get_page_size (request )
66
+
67
+ total_count = 0
68
+ total_pages = 1
69
+ if queryset :
70
+ total_count = queryset .total_count ()
71
+ hits = max (1 , total_count )
72
+ total_pages = ceil (hits / page_size )
73
+
74
+ page_number = request .query_params .get (self .page_query_param , 1 )
75
+ if page_number in self .last_page_strings :
76
+ page_number = total_pages
77
+
78
+ if page_number <= 0 :
79
+ msg = self .invalid_page_message .format (
80
+ page_number = page_number ,
81
+ message = _ ("Invalid page" ),
82
+ )
83
+ raise NotFound (msg )
84
+
85
+ if total_pages > 1 and self .template is not None :
86
+ # The browsable API should display pagination controls.
87
+ self .display_page_controls = True
88
+
89
+ start = (page_number - 1 ) * page_size
90
+ end = page_number * page_size
91
+ result = list (queryset [start :end ])
92
+
93
+ # Needed for other methods of this class.
94
+ self .page = PaginatorPage (
95
+ page_number = page_number ,
96
+ total_pages = total_pages ,
97
+ count = total_count ,
98
+ )
99
+
100
+ return result
101
+
26
102
27
103
class PageSearchSerializer (serializers .Serializer ):
28
104
project = serializers .CharField ()
@@ -75,12 +151,12 @@ def get_inner_hits(self, obj):
75
151
return sorted_results
76
152
77
153
78
- class PageSearchAPIView (generics .ListAPIView ):
79
-
154
+ class PageSearchAPIView (GenericAPIView ):
80
155
"""
81
156
Main entry point to perform a search using Elasticsearch.
82
157
83
158
Required query params:
159
+
84
160
- q (search term)
85
161
- project
86
162
- version
@@ -91,6 +167,7 @@ class PageSearchAPIView(generics.ListAPIView):
91
167
are called many times, so a basic cache is implemented.
92
168
"""
93
169
170
+ http_method_names = ['get' ]
94
171
permission_classes = [IsAuthorizedToViewVersion ]
95
172
pagination_class = SearchPagination
96
173
serializer_class = PageSearchSerializer
@@ -121,39 +198,7 @@ def _get_version(self):
121
198
122
199
return version
123
200
124
- def get_queryset (self ):
125
- """
126
- Return Elasticsearch DSL Search object instead of Django Queryset.
127
-
128
- Django Queryset and elasticsearch-dsl ``Search`` object is similar pattern.
129
- So for searching, its possible to return ``Search`` object instead of queryset.
130
- The ``filter_backends`` and ``pagination_class`` is compatible with ``Search``
131
- """
132
- # Validate all the required params are there
133
- self .validate_query_params ()
134
- query = self .request .query_params .get ('q' , '' )
135
- filters = {}
136
- filters ['project' ] = [p .slug for p in self .get_all_projects ()]
137
- filters ['version' ] = self ._get_version ().slug
138
-
139
- # Check to avoid searching all projects in case these filters are empty.
140
- if not filters ['project' ]:
141
- log .info ("Unable to find a project to search" )
142
- return HTMLFile .objects .none ()
143
- if not filters ['version' ]:
144
- log .info ("Unable to find a version to search" )
145
- return HTMLFile .objects .none ()
146
-
147
- queryset = PageSearch (
148
- query = query ,
149
- filters = filters ,
150
- user = self .request .user ,
151
- # We use a permission class to control authorization
152
- filter_by_user = False ,
153
- )
154
- return queryset
155
-
156
- def validate_query_params (self ):
201
+ def _validate_query_params (self ):
157
202
"""
158
203
Validate all required query params are passed on the request.
159
204
@@ -163,47 +208,16 @@ def validate_query_params(self):
163
208
164
209
:raises: ValidationError if one of them is missing.
165
210
"""
166
- required_query_params = {'q' , 'project' , 'version' } # python `set` literal is `{}`
211
+ errors = {}
212
+ required_query_params = {'q' , 'project' , 'version' }
167
213
request_params = set (self .request .query_params .keys ())
168
214
missing_params = required_query_params - request_params
169
- if missing_params :
170
- errors = {}
171
- for param in missing_params :
172
- errors [param ] = ["This query param is required" ]
173
-
215
+ for param in missing_params :
216
+ errors [param ] = [_ ("This query param is required" )]
217
+ if errors :
174
218
raise ValidationError (errors )
175
219
176
- def get_serializer_context (self ):
177
- context = super ().get_serializer_context ()
178
- context ['projects_data' ] = self .get_all_projects_data ()
179
- return context
180
-
181
- def get_all_projects (self ):
182
- """
183
- Return a list of the project itself and all its subprojects the user has permissions over.
184
-
185
- :rtype: list
186
- """
187
- main_version = self ._get_version ()
188
- main_project = self ._get_project ()
189
-
190
- all_projects = [main_project ]
191
-
192
- subprojects = Project .objects .filter (
193
- superprojects__parent_id = main_project .id ,
194
- )
195
- for project in subprojects :
196
- version = (
197
- Version .internal
198
- .public (user = self .request .user , project = project , include_hidden = False )
199
- .filter (slug = main_version .slug )
200
- .first ()
201
- )
202
- if version :
203
- all_projects .append (version .project )
204
- return all_projects
205
-
206
- def get_all_projects_data (self ):
220
+ def _get_all_projects_data (self ):
207
221
"""
208
222
Return a dict containing the project slug and its version URL and version's doctype.
209
223
@@ -224,7 +238,7 @@ def get_all_projects_data(self):
224
238
225
239
:rtype: dict
226
240
"""
227
- all_projects = self .get_all_projects ()
241
+ all_projects = self ._get_all_projects ()
228
242
version_slug = self ._get_version ().slug
229
243
project_urls = {}
230
244
for project in all_projects :
@@ -242,20 +256,41 @@ def get_all_projects_data(self):
242
256
}
243
257
return projects_data
244
258
245
- def list (self , request , * args , ** kwargs ):
246
- """Overriding ``list`` method to record query in database."""
259
+ def _get_all_projects (self ):
260
+ """
261
+ Returns a list of the project itself and all its subprojects the user has permissions over.
262
+
263
+ :rtype: list
264
+ """
265
+ main_version = self ._get_version ()
266
+ main_project = self ._get_project ()
247
267
248
- response = super (). list ( request , * args , ** kwargs )
268
+ all_projects = [ main_project ]
249
269
270
+ subprojects = Project .objects .filter (
271
+ superprojects__parent_id = main_project .id ,
272
+ )
273
+ for project in subprojects :
274
+ version = (
275
+ Version .internal
276
+ .public (user = self .request .user , project = project , include_hidden = False )
277
+ .filter (slug = main_version .slug )
278
+ .first ()
279
+ )
280
+ if version :
281
+ all_projects .append (version .project )
282
+ return all_projects
283
+
284
+ def _record_query (self , response ):
250
285
project_slug = self ._get_project ().slug
251
286
version_slug = self ._get_version ().slug
252
287
total_results = response .data .get ('count' , 0 )
253
288
time = timezone .now ()
254
289
255
- query = self .request .query_params . get ( 'q' , '' )
290
+ query = self .request .query_params [ 'q' ]
256
291
query = query .lower ().strip ()
257
292
258
- # record the search query with a celery task
293
+ # Record the query with a celery task
259
294
tasks .record_search_query .delay (
260
295
project_slug ,
261
296
version_slug ,
@@ -264,4 +299,54 @@ def list(self, request, *args, **kwargs):
264
299
time .isoformat (),
265
300
)
266
301
267
- return response
302
+ def get_queryset (self ):
303
+ """
304
+ Returns an Elasticsearch DSL search object or an iterator.
305
+
306
+ .. note::
307
+
308
+ Calling ``list(search)`` over an DSL search object is the same as
309
+ calling ``search.execute().hits``. This is why an DSL search object
310
+ is compatible with DRF's paginator.
311
+ """
312
+ filters = {}
313
+ filters ['project' ] = [p .slug for p in self ._get_all_projects ()]
314
+ filters ['version' ] = self ._get_version ().slug
315
+
316
+ # Check to avoid searching all projects in case these filters are empty.
317
+ if not filters ['project' ]:
318
+ log .info ('Unable to find a project to search' )
319
+ return []
320
+ if not filters ['version' ]:
321
+ log .info ('Unable to find a version to search' )
322
+ return []
323
+
324
+ query = self .request .query_params ['q' ]
325
+ queryset = PageSearch (
326
+ query = query ,
327
+ filters = filters ,
328
+ user = self .request .user ,
329
+ # We use a permission class to control authorization
330
+ filter_by_user = False ,
331
+ )
332
+ return queryset
333
+
334
+ def get_serializer_context (self ):
335
+ context = super ().get_serializer_context ()
336
+ context ['projects_data' ] = self ._get_all_projects_data ()
337
+ return context
338
+
339
+ def get (self , request , * args , ** kwargs ):
340
+ self ._validate_query_params ()
341
+ result = self .list ()
342
+ self ._record_query (result )
343
+ return result
344
+
345
+ def list (self ):
346
+ """List the results using pagination."""
347
+ queryset = self .get_queryset ()
348
+ page = self .paginator .paginate_queryset (
349
+ queryset , self .request , view = self ,
350
+ )
351
+ serializer = self .get_serializer (page , many = True )
352
+ return self .paginator .get_paginated_response (serializer .data )
0 commit comments