Skip to content

Search: API V3 #9615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
300 changes: 164 additions & 136 deletions readthedocs/search/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import lru_cache, namedtuple
from functools import lru_cache, namedtuple, cached_property
from math import ceil

import structlog
Expand All @@ -15,6 +15,8 @@
from readthedocs.projects.models import Feature, Project
from readthedocs.search import tasks
from readthedocs.search.faceted_search import PageSearch
from readthedocs.search.query import SearchQueryParser
from readthedocs.search.backends import BackendV1, BackendV2

from .serializers import PageSearchSerializer, ProjectData, VersionData

Expand Down Expand Up @@ -117,23 +119,94 @@ def paginate_queryset(self, queryset, request, view=None):
return result


class PageSearchAPIView(CDNCacheTagsMixin, GenericAPIView):
class SearchAPIBase(GenericAPIView):

http_method_names = ['get']
pagination_class = SearchPagination
search_backend = None

def _validate_query_params(self):
"""
Validate all query params that are passed in the request.

:raises: ValidationError if one of them is missing.
"""
raise NotImplementedError

def _get_search_query(self):
raise NotImplementedError

def _use_advanced_query(self):
raise NotImplementedError

def _record_query(self, response):
raise NotImplementedError

def get_queryset(self):
"""
Returns an Elasticsearch DSL search object or an iterator.

.. note::

Calling ``list(search)`` over an DSL search object is the same as
calling ``search.execute().hits``. This is why an DSL search object
is compatible with DRF's paginator.
"""
backend = self.search_backend()
projects = {
project.slug: version.slug
for project, version in self._get_projects_to_search()
}
# Check to avoid searching all projects in case it's empty.
if not projects:
log.info('Unable to find a version to search')
return []

query = self._get_search_query()
queryset = PageSearch(
query=query,
projects=projects,
aggregate_results=False,
use_advanced_query=self._use_advanced_query(),
)
return queryset

def get_serializer_context(self):
context = super().get_serializer_context()
context['projects_data'] = self._get_all_projects_data()
return context

def get(self, request, *args, **kwargs):
self._validate_query_params()
result = self.list()
self._record_query(result)
return result

def list(self):
"""List the results using pagination."""
queryset = self.get_queryset()
page = self.paginator.paginate_queryset(
queryset, self.request, view=self,
)
serializer = self.get_serializer(page, many=True)
return self.paginator.get_paginated_response(serializer.data)


class PageSearchAPIView(CDNCacheTagsMixin, SearchAPIBase):

"""
Server side search API.
Server side search API V3.

Required query parameters:

- **q**: Search term.
- **project**: Project to search.
- **version**: Version to search.

Check our [docs](https://docs.readthedocs.io/en/stable/server-side-search.html#api) for more information.
Check our [docs](https://docs.readthedocs.io/page/server-side-search.html#api) for more information.
""" # noqa

http_method_names = ['get']
permission_classes = [IsAuthorizedToViewVersion]
pagination_class = SearchPagination
serializer_class = PageSearchSerializer
project_cache_tag = 'rtd-search'

Expand All @@ -154,15 +227,6 @@ def _get_version(self):
return version

def _validate_query_params(self):
"""
Validate all required query params are passed on the request.

Query params required are: ``q``, ``project`` and ``version``.

:rtype: None

:raises: ValidationError if one of them is missing.
"""
errors = {}
required_query_params = {'q', 'project', 'version'}
request_params = set(self.request.query_params.keys())
Expand All @@ -173,9 +237,47 @@ def _validate_query_params(self):
raise ValidationError(errors)

@lru_cache(maxsize=1)
def _get_projects_to_search(self):
main_version = self._get_version()
main_project = self._get_project()

if not self._has_permission(self.request.user, main_version):
return {}

projects = [(main_project, main_version)]
projects.extend(self._get_subprojects(main_project, version_slug=main_version.slug))
return projects

def _get_search_query(self):
return self.request.query_params["q"]

def _record_query(self, response):
project_slug = self._get_project().slug
version_slug = self._get_version().slug
total_results = response.data.get('count', 0)
time = timezone.now()

query = self._get_search_query().lower().strip()

# Record the query with a celery task
tasks.record_search_query.delay(
project_slug,
version_slug,
query,
total_results,
time.isoformat(),
)

def _use_advanced_query(self):
main_project = self._get_project()
return not main_project.has_feature(Feature.DEFAULT_TO_FUZZY_SEARCH)


class ProjectsToSearchMixin:

def _get_all_projects_data(self):
"""
Return a dictionary of the project itself and all its subprojects.
Return a dictionary of all projects/versions that will be used in the search.

Example:

Expand All @@ -198,41 +300,12 @@ def _get_all_projects_data(self):
),
}

.. note:: The response is cached into the instance.

:rtype: A dictionary of project slugs mapped to a `VersionData` object.
:returns: A dictionary of project slugs mapped to a `VersionData` object.
"""
main_version = self._get_version()
main_project = self._get_project()

if not self._has_permission(self.request.user, main_version):
return {}

projects_data = {
main_project.slug: self._get_project_data(main_project, main_version),
project.slug: self._get_project_data(project, version)
for project, version in self._get_projects_to_search()
}

subprojects = Project.objects.filter(superprojects__parent_id=main_project.id)
for subproject in subprojects:
version = self._get_project_version(
project=subproject,
version_slug=main_version.slug,
include_hidden=False,
)

# Fallback to the default version of the subproject.
if not version and subproject.default_version:
version = self._get_project_version(
project=subproject,
version_slug=subproject.default_version,
include_hidden=False,
)

if version and self._has_permission(self.request.user, version):
projects_data[subproject.slug] = self._get_project_data(
subproject, version
)

return projects_data

def _get_project_data(self, project, version):
Expand All @@ -248,104 +321,59 @@ def _get_project_data(self, project, version):
version=version_data,
)

def _get_project_version(self, project, version_slug, include_hidden=True):
"""
Get a version from a given project.

:param project: A `Project` object.
:param version_slug: The version slug.
:param include_hidden: If hidden versions should be considered.
"""
return (
Version.internal
.public(
user=self.request.user,
project=project,
only_built=True,
include_hidden=include_hidden,
)
.filter(slug=version_slug)
.first()
)

def _has_permission(self, user, version):
"""
Check if `user` is authorized to access `version`.
class SearchAPIV3(SearchAPIBase, ProjectsToSearchMixin):

The queryset from `_get_subproject_version` already filters public
projects. This is mainly to be overridden in .com to make use of
the auth backends in the proxied API.
"""
return True

def _get_search_query(self):
return self.request.query_params["q"]

def _record_query(self, response):
project_slug = self._get_project().slug
version_slug = self._get_version().slug
total_results = response.data.get('count', 0)
time = timezone.now()
"""
Server side search API V3.

query = self._get_search_query().lower().strip()
Required query parameters:

# Record the query with a celery task
tasks.record_search_query.delay(
project_slug,
version_slug,
query,
total_results,
time.isoformat(),
)
- **q**: Search term.

def _use_advanced_query(self):
main_project = self._get_project()
return not main_project.has_feature(Feature.DEFAULT_TO_FUZZY_SEARCH)
Check our [docs](https://docs.readthedocs.io/page/server-side-search.html#api) for more information.
""" # noqa

def get_queryset(self):
"""
Returns an Elasticsearch DSL search object or an iterator.
serializer_class = PageSearchSerializer

.. note::
def get_view_name(self):
return "Search API V3"

Calling ``list(search)`` over an DSL search object is the same as
calling ``search.execute().hits``. This is why an DSL search object
is compatible with DRF's paginator.
"""
projects = {
project: project_data.version.slug
for project, project_data in self._get_all_projects_data().items()
}
# Check to avoid searching all projects in case it's empty.
if not projects:
log.info('Unable to find a version to search')
return []
def _validate_query_params(self):
if "q" not in self.request.query_params:
raise ValidationError({"q": [_("This query parameter is required")]})

query = self._get_search_query()
queryset = PageSearch(
query=query,
projects=projects,
aggregate_results=False,
use_advanced_query=self._use_advanced_query(),
)
return queryset
def _get_search_query(self):
return self._parser.query

def get_serializer_context(self):
context = super().get_serializer_context()
context['projects_data'] = self._get_all_projects_data()
return context
def _use_advanced_query(self):
# TODO: we should make this a parameter in the API,
# we are checking if the first project has this feature for now.
project = self._get_projects_to_search()[0][0]
return not project.has_feature(Feature.DEFAULT_TO_FUZZY_SEARCH)

def get(self, request, *args, **kwargs):
self._validate_query_params()
result = self.list()
self._record_query(result)
return result
def _record_query(self, response):
total_results = response.data.get('count', 0)
time = timezone.now()
query = self._get_search_query().lower().strip()
# NOTE: I think this may be confusing,
# since the number of results is the total
# of searching on all projects, this specific project
# could have had 0 results.
for project, version in self._get_projects_to_search():
tasks.record_search_query.delay(
project.slug,
version.slug,
query,
total_results,
time.isoformat(),
)

def list(self):
"""List the results using pagination."""
queryset = self.get_queryset()
page = self.paginator.paginate_queryset(
queryset, self.request, view=self,
)
serializer = self.get_serializer(page, many=True)
return self.paginator.get_paginated_response(serializer.data)
response = super().list()
response.data["projects"] = [
[project.slug, version.slug]
for project, version in self._get_projects_to_search()
]
response.data["query"] = self._get_search_query()
return response
Loading