1
1
import itertools
2
2
import logging
3
3
4
+ from django .shortcuts import get_object_or_404
4
5
from django .utils import timezone
5
6
from rest_framework import generics , serializers
6
7
from rest_framework .exceptions import ValidationError
7
8
from rest_framework .pagination import PageNumberPagination
8
9
9
- from readthedocs .projects .models import HTMLFile
10
+ from readthedocs .api .v2 .permissions import IsAuthorizedToViewVersion
11
+ from readthedocs .builds .models import Version
12
+ from readthedocs .projects .models import HTMLFile , Project
10
13
from readthedocs .search import tasks , utils
11
14
from readthedocs .search .faceted_search import PageSearch
12
15
13
-
14
16
log = logging .getLogger (__name__ )
15
17
16
18
@@ -60,11 +62,50 @@ def get_inner_hits(self, obj):
60
62
61
63
class PageSearchAPIView (generics .ListAPIView ):
62
64
63
- """Main entry point to perform a search using Elasticsearch."""
65
+ """
66
+ Main entry point to perform a search using Elasticsearch.
67
+
68
+ Required query params:
69
+ - q (search term)
70
+ - project
71
+ - version
72
+
73
+ .. note::
74
+
75
+ The methods `_get_project` and `_get_version`
76
+ are called many times, so a basic cache is implemented.
77
+ """
64
78
79
+ permission_classes = [IsAuthorizedToViewVersion ]
65
80
pagination_class = SearchPagination
66
81
serializer_class = PageSearchSerializer
67
82
83
+ def _get_project (self ):
84
+ cache_key = '_cached_project'
85
+ project = getattr (self , cache_key , None )
86
+
87
+ if not project :
88
+ project_slug = self .request .GET .get ('project' , None )
89
+ project = get_object_or_404 (Project , slug = project_slug )
90
+ setattr (self , cache_key , project )
91
+
92
+ return project
93
+
94
+ def _get_version (self ):
95
+ cache_key = '_cached_version'
96
+ version = getattr (self , cache_key , None )
97
+
98
+ if not version :
99
+ version_slug = self .request .GET .get ('version' , None )
100
+ project = self ._get_project ()
101
+ version = get_object_or_404 (
102
+ project .versions .all (),
103
+ slug = version_slug ,
104
+ )
105
+ setattr (self , cache_key , version )
106
+
107
+ return version
108
+
68
109
def get_queryset (self ):
69
110
"""
70
111
Return Elasticsearch DSL Search object instead of Django Queryset.
@@ -78,7 +119,8 @@ def get_queryset(self):
78
119
query = self .request .query_params .get ('q' , '' )
79
120
kwargs = {'filter_by_user' : False , 'filters' : {}}
80
121
kwargs ['filters' ]['project' ] = [p .slug for p in self .get_all_projects ()]
81
- kwargs ['filters' ]['version' ] = self .request .query_params .get ('version' )
122
+ kwargs ['filters' ]['version' ] = self ._get_version ().slug
123
+ # Check to avoid searching all projects in case project is empty.
82
124
if not kwargs ['filters' ]['project' ]:
83
125
log .info ("Unable to find a project to search" )
84
126
return HTMLFile .objects .none ()
@@ -118,19 +160,26 @@ def get_serializer_context(self):
118
160
119
161
def get_all_projects (self ):
120
162
"""
121
- Return a list containing the project itself and all its subprojects.
122
-
123
- The project slug is retrieved from ``project`` query param.
163
+ Return a list of the project itself and all its subprojects the user has permissions over.
124
164
125
165
:rtype: list
126
-
127
- :raises: Http404 if project is not found
128
166
"""
129
- project_slug = self .request .query_params .get ('project' )
130
- version_slug = self .request .query_params .get ('version' )
131
- all_projects = utils .get_project_list_or_404 (
132
- project_slug = project_slug , user = self .request .user , version_slug = version_slug ,
167
+ main_version = self ._get_version ()
168
+ main_project = self ._get_project ()
169
+
170
+ subprojects = Project .objects .filter (
171
+ superprojects__parent_id = main_project .id ,
133
172
)
173
+ all_projects = []
174
+ for project in list (subprojects ) + [main_project ]:
175
+ version = (
176
+ Version .internal
177
+ .public (user = self .request .user , project = project )
178
+ .filter (slug = main_version .slug )
179
+ .first ()
180
+ )
181
+ if version :
182
+ all_projects .append (version .project )
134
183
return all_projects
135
184
136
185
def get_all_projects_url (self ):
@@ -151,7 +200,7 @@ def get_all_projects_url(self):
151
200
:rtype: dict
152
201
"""
153
202
all_projects = self .get_all_projects ()
154
- version_slug = self .request . query_params . get ( 'version' )
203
+ version_slug = self ._get_version (). slug
155
204
projects_url = {}
156
205
for project in all_projects :
157
206
projects_url [project .slug ] = project .get_docs_url (version_slug = version_slug )
@@ -162,8 +211,8 @@ def list(self, request, *args, **kwargs):
162
211
163
212
response = super ().list (request , * args , ** kwargs )
164
213
165
- project_slug = self .request . query_params . get ( 'project' , None )
166
- version_slug = self .request . query_params . get ( 'version' , None )
214
+ project_slug = self ._get_project (). slug
215
+ version_slug = self ._get_version (). slug
167
216
total_results = response .data .get ('count' , 0 )
168
217
time = timezone .now ()
169
218
0 commit comments