Skip to content

Commit 1b8e006

Browse files
authored
Merge pull request #6761 from readthedocs/refactor-search-view
Refactor search view to make use of permission_classes
2 parents 23f9d19 + 0f93913 commit 1b8e006

File tree

4 files changed

+90
-40
lines changed

4 files changed

+90
-40
lines changed

readthedocs/api/v2/permissions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class IsAuthorizedToViewVersion(permissions.BasePermission):
9191
"""
9292
Checks if the user from the request has permissions to see the version.
9393
94-
This permission class used in the FooterHTML view.
94+
This permission class used in the FooterHTML and PageSearchAPIView views.
9595
9696
.. note::
9797

readthedocs/search/api.py

Lines changed: 65 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
import itertools
22
import logging
33

4+
from django.shortcuts import get_object_or_404
45
from django.utils import timezone
56
from rest_framework import generics, serializers
67
from rest_framework.exceptions import ValidationError
78
from rest_framework.pagination import PageNumberPagination
89

9-
from readthedocs.projects.models import HTMLFile
10+
from readthedocs.api.v2.permissions import IsAuthorizedToViewVersion
11+
from readthedocs.builds.models import Version
12+
from readthedocs.projects.models import HTMLFile, Project
1013
from readthedocs.search import tasks, utils
1114
from readthedocs.search.faceted_search import PageSearch
1215

13-
1416
log = logging.getLogger(__name__)
1517

1618

@@ -60,11 +62,50 @@ def get_inner_hits(self, obj):
6062

6163
class PageSearchAPIView(generics.ListAPIView):
6264

63-
"""Main entry point to perform a search using Elasticsearch."""
65+
"""
66+
Main entry point to perform a search using Elasticsearch.
67+
68+
Required query params:
69+
- q (search term)
70+
- project
71+
- version
72+
73+
.. note::
74+
75+
The methods `_get_project` and `_get_version`
76+
are called many times, so a basic cache is implemented.
77+
"""
6478

79+
permission_classes = [IsAuthorizedToViewVersion]
6580
pagination_class = SearchPagination
6681
serializer_class = PageSearchSerializer
6782

83+
def _get_project(self):
84+
cache_key = '_cached_project'
85+
project = getattr(self, cache_key, None)
86+
87+
if not project:
88+
project_slug = self.request.GET.get('project', None)
89+
project = get_object_or_404(Project, slug=project_slug)
90+
setattr(self, cache_key, project)
91+
92+
return project
93+
94+
def _get_version(self):
95+
cache_key = '_cached_version'
96+
version = getattr(self, cache_key, None)
97+
98+
if not version:
99+
version_slug = self.request.GET.get('version', None)
100+
project = self._get_project()
101+
version = get_object_or_404(
102+
project.versions.all(),
103+
slug=version_slug,
104+
)
105+
setattr(self, cache_key, version)
106+
107+
return version
108+
68109
def get_queryset(self):
69110
"""
70111
Return Elasticsearch DSL Search object instead of Django Queryset.
@@ -78,7 +119,8 @@ def get_queryset(self):
78119
query = self.request.query_params.get('q', '')
79120
kwargs = {'filter_by_user': False, 'filters': {}}
80121
kwargs['filters']['project'] = [p.slug for p in self.get_all_projects()]
81-
kwargs['filters']['version'] = self.request.query_params.get('version')
122+
kwargs['filters']['version'] = self._get_version().slug
123+
# Check to avoid searching all projects in case project is empty.
82124
if not kwargs['filters']['project']:
83125
log.info("Unable to find a project to search")
84126
return HTMLFile.objects.none()
@@ -118,19 +160,26 @@ def get_serializer_context(self):
118160

119161
def get_all_projects(self):
120162
"""
121-
Return a list containing the project itself and all its subprojects.
122-
123-
The project slug is retrieved from ``project`` query param.
163+
Return a list of the project itself and all its subprojects the user has permissions over.
124164
125165
:rtype: list
126-
127-
:raises: Http404 if project is not found
128166
"""
129-
project_slug = self.request.query_params.get('project')
130-
version_slug = self.request.query_params.get('version')
131-
all_projects = utils.get_project_list_or_404(
132-
project_slug=project_slug, user=self.request.user, version_slug=version_slug,
167+
main_version = self._get_version()
168+
main_project = self._get_project()
169+
170+
subprojects = Project.objects.filter(
171+
superprojects__parent_id=main_project.id,
133172
)
173+
all_projects = []
174+
for project in list(subprojects) + [main_project]:
175+
version = (
176+
Version.internal
177+
.public(user=self.request.user, project=project)
178+
.filter(slug=main_version.slug)
179+
.first()
180+
)
181+
if version:
182+
all_projects.append(version.project)
134183
return all_projects
135184

136185
def get_all_projects_url(self):
@@ -151,7 +200,7 @@ def get_all_projects_url(self):
151200
:rtype: dict
152201
"""
153202
all_projects = self.get_all_projects()
154-
version_slug = self.request.query_params.get('version')
203+
version_slug = self._get_version().slug
155204
projects_url = {}
156205
for project in all_projects:
157206
projects_url[project.slug] = project.get_docs_url(version_slug=version_slug)
@@ -162,8 +211,8 @@ def list(self, request, *args, **kwargs):
162211

163212
response = super().list(request, *args, **kwargs)
164213

165-
project_slug = self.request.query_params.get('project', None)
166-
version_slug = self.request.query_params.get('version', None)
214+
project_slug = self._get_project().slug
215+
version_slug = self._get_version().slug
167216
total_results = response.data.get('count', 0)
168217
time = timezone.now()
169218

readthedocs/search/tests/test_api.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from unittest import mock
23

34
import pytest
45
from django.urls import reverse
@@ -7,6 +8,7 @@
78
from readthedocs.builds.models import Version
89
from readthedocs.projects.constants import PUBLIC
910
from readthedocs.projects.models import HTMLFile, Project
11+
from readthedocs.search.api import PageSearchAPIView
1012
from readthedocs.search.documents import PageDocument
1113
from readthedocs.search.tests.utils import (
1214
DOMAIN_FIELDS,
@@ -199,11 +201,17 @@ def test_doc_search_pagination(self, api_client, project):
199201
assert len(resp.data['results']) == 5
200202

201203
def test_doc_search_without_parameters(self, api_client, project):
202-
"""Hitting Document Search endpoint without query parameters should return error"""
204+
"""Hitting Document Search endpoint without project and version should return 404."""
203205
resp = self.get_search(api_client, {})
206+
assert resp.status_code == 404
207+
208+
def test_doc_search_without_query(self, api_client, project):
209+
"""Hitting Document Search endpoint without a query should return error."""
210+
resp = self.get_search(
211+
api_client, {'project': project.slug, 'version': project.versions.first().slug})
204212
assert resp.status_code == 400
205213
# Check error message is there
206-
assert sorted(['q', 'project', 'version']) == sorted(resp.data.keys())
214+
assert 'q' in resp.data.keys()
207215

208216
def test_doc_search_subprojects(self, api_client, all_projects):
209217
"""Test Document search return results from subprojects also"""
@@ -255,6 +263,20 @@ def test_doc_search_unexisting_version(self, api_client, project):
255263
'version': version,
256264
}
257265
resp = self.get_search(api_client, search_params)
266+
assert resp.status_code == 404
267+
268+
@mock.patch.object(PageSearchAPIView, 'get_all_projects', list)
269+
def test_get_all_projects_returns_empty_results(self, api_client, project):
270+
"""If there is a case where `get_all_projects` returns empty, we could be querying all projects."""
271+
272+
# `documentation` word is present both in `kuma` and `docs` files
273+
# and not in `pipeline`, so search with this phrase but filter through project
274+
search_params = {
275+
'q': 'documentation',
276+
'project': 'docs',
277+
'version': 'latest'
278+
}
279+
resp = self.get_search(api_client, search_params)
258280
assert resp.status_code == 200
259281

260282
data = resp.data['results']

readthedocs/search/utils.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -82,27 +82,6 @@ def remove_indexed_files(model, project_slug, version_slug=None, build_id=None):
8282
log.exception('Unable to delete a subset of files. Continuing.')
8383

8484

85-
# TODO: Rewrite all the views using this in Class Based View,
86-
# and move this function to a mixin
87-
def get_project_list_or_404(project_slug, user, version_slug=None):
88-
"""
89-
Return list of project and its subprojects.
90-
91-
It filters by Version privacy instead of Project privacy,
92-
so we can support public versions on private projects.
93-
"""
94-
project_list = []
95-
main_project = get_object_or_404(Project, slug=project_slug)
96-
subprojects = Project.objects.filter(superprojects__parent_id=main_project.id)
97-
for project in list(subprojects) + [main_project]:
98-
version = Version.internal.public(user).filter(
99-
project__slug=project.slug, slug=version_slug
100-
)
101-
if version.exists():
102-
project_list.append(version.first().project)
103-
return project_list
104-
105-
10685
def _get_index(indices, index_name):
10786
"""
10887
Get Index from all the indices.

0 commit comments

Comments
 (0)