1
1
import itertools
2
2
import logging
3
3
4
+ from django .shortcuts import get_object_or_404
4
5
from django .utils import timezone
5
6
from rest_framework import generics , serializers
6
7
from rest_framework .exceptions import ValidationError
7
8
from rest_framework .pagination import PageNumberPagination
8
9
9
- from readthedocs .projects .models import HTMLFile
10
+ from readthedocs .api .v2 .permissions import IsAuthorizedToViewVersion
11
+ from readthedocs .builds .models import Version
12
+ from readthedocs .projects .models import HTMLFile , Project
10
13
from readthedocs .search import tasks , utils
11
14
from readthedocs .search .faceted_search import PageSearch
12
15
13
-
14
16
log = logging .getLogger (__name__ )
15
17
16
18
@@ -60,11 +62,50 @@ def get_inner_hits(self, obj):
60
62
61
63
class PageSearchAPIView (generics .ListAPIView ):
62
64
63
- """Main entry point to perform a search using Elasticsearch."""
65
+ """
66
+ Main entry point to perform a search using Elasticsearch.
67
+
68
+ Required query params:
69
+ - q (search term)
70
+ - project
71
+ - version
72
+
73
+ .. note::
74
+
75
+ The methods `_get_project` and `_get_version`
76
+ are called many times, so a basic cache is implemented.
77
+ """
64
78
79
+ permission_classes = [IsAuthorizedToViewVersion ]
65
80
pagination_class = SearchPagination
66
81
serializer_class = PageSearchSerializer
67
82
83
+ def _get_project (self ):
84
+ cache_key = '_cached_project'
85
+ project = getattr (self , cache_key , None )
86
+
87
+ if not project :
88
+ project_slug = self .request .GET .get ('project' , None )
89
+ project = get_object_or_404 (Project , slug = project_slug )
90
+ setattr (self , cache_key , project )
91
+
92
+ return project
93
+
94
+ def _get_version (self ):
95
+ cache_key = '_cached_version'
96
+ version = getattr (self , cache_key , None )
97
+
98
+ if not version :
99
+ version_slug = self .request .GET .get ('version' , None )
100
+ project = self ._get_project ()
101
+ version = get_object_or_404 (
102
+ project .versions .all (),
103
+ slug = version_slug ,
104
+ )
105
+ setattr (self , cache_key , version )
106
+
107
+ return version
108
+
68
109
def get_queryset (self ):
69
110
"""
70
111
Return Elasticsearch DSL Search object instead of Django Queryset.
@@ -78,13 +119,7 @@ def get_queryset(self):
78
119
query = self .request .query_params .get ('q' , '' )
79
120
kwargs = {'filter_by_user' : False , 'filters' : {}}
80
121
kwargs ['filters' ]['project' ] = [p .slug for p in self .get_all_projects ()]
81
- kwargs ['filters' ]['version' ] = self .request .query_params .get ('version' )
82
- if not kwargs ['filters' ]['project' ]:
83
- log .info ("Unable to find a project to search" )
84
- return HTMLFile .objects .none ()
85
- if not kwargs ['filters' ]['version' ]:
86
- log .info ("Unable to find a version to search" )
87
- return HTMLFile .objects .none ()
122
+ kwargs ['filters' ]['version' ] = self ._get_version ().slug
88
123
user = self .request .user
89
124
queryset = PageSearch (
90
125
query = query , user = user , ** kwargs
@@ -120,17 +155,24 @@ def get_all_projects(self):
120
155
"""
121
156
Return a list containing the project itself and all its subprojects.
122
157
123
- The project slug is retrieved from ``project`` query param.
124
-
125
158
:rtype: list
126
-
127
- :raises: Http404 if project is not found
128
159
"""
129
- project_slug = self .request .query_params .get ('project' )
130
- version_slug = self .request .query_params .get ('version' )
131
- all_projects = utils .get_project_list_or_404 (
132
- project_slug = project_slug , user = self .request .user , version_slug = version_slug ,
160
+ main_version = self ._get_version ()
161
+ main_project = self ._get_project ()
162
+
163
+ subprojects = Project .objects .filter (
164
+ superprojects__parent_id = main_project .id ,
133
165
)
166
+ all_projects = []
167
+ for project in list (subprojects ) + [main_project ]:
168
+ version = (
169
+ Version .objects
170
+ .public (user = self .request .user , project = project )
171
+ .filter (slug = main_version .slug )
172
+ .first ()
173
+ )
174
+ if version :
175
+ all_projects .append (version .project )
134
176
return all_projects
135
177
136
178
def get_all_projects_url (self ):
@@ -151,7 +193,7 @@ def get_all_projects_url(self):
151
193
:rtype: dict
152
194
"""
153
195
all_projects = self .get_all_projects ()
154
- version_slug = self .request . query_params . get ( 'version' )
196
+ version_slug = self ._get_version (). slug
155
197
projects_url = {}
156
198
for project in all_projects :
157
199
projects_url [project .slug ] = project .get_docs_url (version_slug = version_slug )
@@ -162,8 +204,8 @@ def list(self, request, *args, **kwargs):
162
204
163
205
response = super ().list (request , * args , ** kwargs )
164
206
165
- project_slug = self .request . query_params . get ( 'project' , None )
166
- version_slug = self .request . query_params . get ( 'version' , None )
207
+ project_slug = self ._get_project (). slug
208
+ version_slug = self ._get_version (). slug
167
209
total_results = response .data .get ('count' , 0 )
168
210
time = timezone .now ()
169
211
0 commit comments