diff --git a/readthedocs/search/api.py b/readthedocs/search/api.py index af771cb8ff6..110af55e535 100644 --- a/readthedocs/search/api.py +++ b/readthedocs/search/api.py @@ -1,4 +1,4 @@ -from functools import lru_cache, namedtuple +from functools import lru_cache, namedtuple, cached_property from math import ceil import structlog @@ -15,6 +15,8 @@ from readthedocs.projects.models import Feature, Project from readthedocs.search import tasks from readthedocs.search.faceted_search import PageSearch +from readthedocs.search.query import SearchQueryParser +from readthedocs.search.backends import BackendV1, BackendV2 from .serializers import PageSearchSerializer, ProjectData, VersionData @@ -117,10 +119,83 @@ def paginate_queryset(self, queryset, request, view=None): return result -class PageSearchAPIView(CDNCacheTagsMixin, GenericAPIView): +class SearchAPIBase(GenericAPIView): + + http_method_names = ['get'] + pagination_class = SearchPagination + search_backend = None + + def _validate_query_params(self): + """ + Validate all query params that are passed in the request. + + :raises: ValidationError if one of them is missing. + """ + raise NotImplementedError + + def _get_search_query(self): + raise NotImplementedError + + def _use_advanced_query(self): + raise NotImplementedError + + def _record_query(self, response): + raise NotImplementedError + + def get_queryset(self): + """ + Returns an Elasticsearch DSL search object or an iterator. + + .. note:: + + Calling ``list(search)`` over an DSL search object is the same as + calling ``search.execute().hits``. This is why an DSL search object + is compatible with DRF's paginator. + """ + backend = self.search_backend() + projects = { + project.slug: version.slug + for project, version in self._get_projects_to_search() + } + # Check to avoid searching all projects in case it's empty. 
+ if not projects: + log.info('Unable to find a version to search') + return [] + + query = self._get_search_query() + queryset = PageSearch( + query=query, + projects=projects, + aggregate_results=False, + use_advanced_query=self._use_advanced_query(), + ) + return queryset + + def get_serializer_context(self): + context = super().get_serializer_context() + context['projects_data'] = self._get_all_projects_data() + return context + + def get(self, request, *args, **kwargs): + self._validate_query_params() + result = self.list() + self._record_query(result) + return result + + def list(self): + """List the results using pagination.""" + queryset = self.get_queryset() + page = self.paginator.paginate_queryset( + queryset, self.request, view=self, + ) + serializer = self.get_serializer(page, many=True) + return self.paginator.get_paginated_response(serializer.data) + + +class PageSearchAPIView(CDNCacheTagsMixin, SearchAPIBase): """ - Server side search API. + Server side search API V3. Required query parameters: @@ -128,12 +203,10 @@ class PageSearchAPIView(CDNCacheTagsMixin, GenericAPIView): - **project**: Project to search. - **version**: Version to search. - Check our [docs](https://docs.readthedocs.io/en/stable/server-side-search.html#api) for more information. + Check our [docs](https://docs.readthedocs.io/page/server-side-search.html#api) for more information. """ # noqa - http_method_names = ['get'] permission_classes = [IsAuthorizedToViewVersion] - pagination_class = SearchPagination serializer_class = PageSearchSerializer project_cache_tag = 'rtd-search' @@ -154,15 +227,6 @@ def _get_version(self): return version def _validate_query_params(self): - """ - Validate all required query params are passed on the request. - - Query params required are: ``q``, ``project`` and ``version``. - - :rtype: None - - :raises: ValidationError if one of them is missing. 
- """ errors = {} required_query_params = {'q', 'project', 'version'} request_params = set(self.request.query_params.keys()) @@ -173,9 +237,47 @@ def _validate_query_params(self): raise ValidationError(errors) @lru_cache(maxsize=1) + def _get_projects_to_search(self): + main_version = self._get_version() + main_project = self._get_project() + + if not self._has_permission(self.request.user, main_version): + return {} + + projects = [(main_project, main_version)] + projects.extend(self._get_subprojects(main_project, version_slug=main_version.slug)) + return projects + + def _get_search_query(self): + return self.request.query_params["q"] + + def _record_query(self, response): + project_slug = self._get_project().slug + version_slug = self._get_version().slug + total_results = response.data.get('count', 0) + time = timezone.now() + + query = self._get_search_query().lower().strip() + + # Record the query with a celery task + tasks.record_search_query.delay( + project_slug, + version_slug, + query, + total_results, + time.isoformat(), + ) + + def _use_advanced_query(self): + main_project = self._get_project() + return not main_project.has_feature(Feature.DEFAULT_TO_FUZZY_SEARCH) + + +class ProjectsToSearchMixin: + def _get_all_projects_data(self): """ - Return a dictionary of the project itself and all its subprojects. + Return a dictionary of all projects/versions that will be used in the search. Example: @@ -198,41 +300,12 @@ def _get_all_projects_data(self): ), } - .. note:: The response is cached into the instance. - - :rtype: A dictionary of project slugs mapped to a `VersionData` object. + :returns: A dictionary of project slugs mapped to a `VersionData` object. 
""" - main_version = self._get_version() - main_project = self._get_project() - - if not self._has_permission(self.request.user, main_version): - return {} - projects_data = { - main_project.slug: self._get_project_data(main_project, main_version), + project.slug: self._get_project_data(project, version) + for project, version in self._get_projects_to_search() } - - subprojects = Project.objects.filter(superprojects__parent_id=main_project.id) - for subproject in subprojects: - version = self._get_project_version( - project=subproject, - version_slug=main_version.slug, - include_hidden=False, - ) - - # Fallback to the default version of the subproject. - if not version and subproject.default_version: - version = self._get_project_version( - project=subproject, - version_slug=subproject.default_version, - include_hidden=False, - ) - - if version and self._has_permission(self.request.user, version): - projects_data[subproject.slug] = self._get_project_data( - subproject, version - ) - return projects_data def _get_project_data(self, project, version): @@ -248,104 +321,59 @@ def _get_project_data(self, project, version): version=version_data, ) - def _get_project_version(self, project, version_slug, include_hidden=True): - """ - Get a version from a given project. - - :param project: A `Project` object. - :param version_slug: The version slug. - :param include_hidden: If hidden versions should be considered. - """ - return ( - Version.internal - .public( - user=self.request.user, - project=project, - only_built=True, - include_hidden=include_hidden, - ) - .filter(slug=version_slug) - .first() - ) - def _has_permission(self, user, version): - """ - Check if `user` is authorized to access `version`. +class SearchAPIV3(SearchAPIBase, ProjectsToSearchMixin): - The queryset from `_get_subproject_version` already filters public - projects. This is mainly to be overridden in .com to make use of - the auth backends in the proxied API. 
- """ - return True - - def _get_search_query(self): - return self.request.query_params["q"] - - def _record_query(self, response): - project_slug = self._get_project().slug - version_slug = self._get_version().slug - total_results = response.data.get('count', 0) - time = timezone.now() + """ + Server side search API V3. - query = self._get_search_query().lower().strip() + Required query parameters: - # Record the query with a celery task - tasks.record_search_query.delay( - project_slug, - version_slug, - query, - total_results, - time.isoformat(), - ) + - **q**: Search term. - def _use_advanced_query(self): - main_project = self._get_project() - return not main_project.has_feature(Feature.DEFAULT_TO_FUZZY_SEARCH) + Check our [docs](https://docs.readthedocs.io/page/server-side-search.html#api) for more information. + """ # noqa - def get_queryset(self): - """ - Returns an Elasticsearch DSL search object or an iterator. + serializer_class = PageSearchSerializer - .. note:: + def get_view_name(self): + return "Search API V3" - Calling ``list(search)`` over an DSL search object is the same as - calling ``search.execute().hits``. This is why an DSL search object - is compatible with DRF's paginator. - """ - projects = { - project: project_data.version.slug - for project, project_data in self._get_all_projects_data().items() - } - # Check to avoid searching all projects in case it's empty. 
- if not projects: - log.info('Unable to find a version to search') - return [] + def _validate_query_params(self): + if "q" not in self.request.query_params: + raise ValidationError({"q": [_("This query parameter is required")]}) - query = self._get_search_query() - queryset = PageSearch( - query=query, - projects=projects, - aggregate_results=False, - use_advanced_query=self._use_advanced_query(), - ) - return queryset + def _get_search_query(self): + return self._parser.query - def get_serializer_context(self): - context = super().get_serializer_context() - context['projects_data'] = self._get_all_projects_data() - return context + def _use_advanced_query(self): + # TODO: we should make this a parameter in the API, + # we are checking if the first project has this feature for now. + project = self._get_projects_to_search()[0][0] + return not project.has_feature(Feature.DEFAULT_TO_FUZZY_SEARCH) - def get(self, request, *args, **kwargs): - self._validate_query_params() - result = self.list() - self._record_query(result) - return result + def _record_query(self, response): + total_results = response.data.get('count', 0) + time = timezone.now() + query = self._get_search_query().lower().strip() + # NOTE: I think this may be confusing, + # since the number of results is the total + # of searching on all projects, this specific project + # could have had 0 results. 
+        for project, version in self._get_projects_to_search():
+            tasks.record_search_query.delay(
+                project.slug,
+                version.slug,
+                query,
+                total_results,
+                time.isoformat(),
+            )
 
     def list(self):
-        """List the results using pagination."""
-        queryset = self.get_queryset()
-        page = self.paginator.paginate_queryset(
-            queryset, self.request, view=self,
-        )
-        serializer = self.get_serializer(page, many=True)
-        return self.paginator.get_paginated_response(serializer.data)
+        response = super().list()
+        response.data["projects"] = [
+            [project.slug, version.slug]
+            for project, version in self._get_projects_to_search()
+        ]
+        response.data["query"] = self._get_search_query()
+        return response
diff --git a/readthedocs/search/backends.py b/readthedocs/search/backends.py
new file mode 100644
index 00000000000..167477c5095
--- /dev/null
+++ b/readthedocs/search/backends.py
@@ -0,0 +1,263 @@
+from readthedocs.search.faceted_search import PageSearch
+from readthedocs.builds.models import Version
+from itertools import islice
+from functools import lru_cache, cached_property
+from readthedocs.projects.models import Project
+from readthedocs.search.query import SearchQueryParser
+
+class Backend:
+
+    """
+    Base class to resolve the projects and versions to search on.
+
+    Subclasses must set ``self.request``,
+    and implement ``search`` and ``_get_projects_to_search``.
+    """
+
+    max_projects = 100
+
+    def search(self, **kwargs):
+        raise NotImplementedError
+
+    # NOTE: this needs to be a property, subclasses iterate over
+    # ``self.projects``. A method wrapped with ``lru_cache(maxsize=1)``
+    # isn't iterable, and its single slot is shared by all instances.
+    @cached_property
+    def projects(self):
+        """Projects/versions to search on, limited to `max_projects`."""
+        return list(islice(self._get_projects_to_search(), self.max_projects))
+
+    def _get_projects_to_search(self):
+        raise NotImplementedError
+
+    def _get_projects_from_user(self):
+        """Yield the default version of the projects from ``for_user``."""
+        for project in Project.objects.for_user(user=self.request.user):
+            version = self._get_project_version(
+                project=project,
+                version_slug=project.default_version,
+                include_hidden=False,
+            )
+            if version and self._has_permission(self.request.user, version):
+                yield project, version
+
+    def _get_subprojects(self, project, version_slug=None):
+        """
+        Get a tuple project/version of all subprojects of `project`.
+
+        If `version_slug` doesn't match a version of the subproject,
+        the default version will be used.
+        If `version_slug` is None, we will always use the default version.
+        """
+        subprojects = Project.objects.filter(superprojects__parent=project)
+        for subproject in subprojects:
+            version = None
+            if version_slug:
+                version = self._get_project_version(
+                    project=subproject,
+                    version_slug=version_slug,
+                    include_hidden=False,
+                )
+
+            # Fallback to the default version of the subproject.
+            if not version and subproject.default_version:
+                version = self._get_project_version(
+                    project=subproject,
+                    version_slug=subproject.default_version,
+                    include_hidden=False,
+                )
+
+            if version and self._has_permission(self.request.user, version):
+                # The version belongs to the subproject, not to its parent.
+                yield subproject, version
+
+    def _has_permission(self, user, version):
+        """
+        Check if `user` is authorized to access `version`.
+
+        The queryset from `_get_project_version` already filters public
+        projects. This is mainly to be overridden in .com to make use of
+        the auth backends in the proxied API.
+        """
+        return True
+
+    def _get_project_version(self, project, version_slug, include_hidden=True):
+        """
+        Get a version from a given project.
+
+        :param project: A `Project` object.
+        :param version_slug: The version slug.
+        :param include_hidden: If hidden versions should be considered.
+        """
+        return (
+            Version.internal
+            .public(
+                user=self.request.user,
+                project=project,
+                only_built=True,
+                include_hidden=include_hidden,
+            )
+            .filter(slug=version_slug)
+            .first()
+        )
+
+
+class BackendV1(Backend):
+
+    """Search on a project/version and its subprojects."""
+
+    def __init__(self, *, request, query, project, version):
+        self.request = request
+        self.project = project
+        self.version = version
+        self.query = query
+
+    def search(self, **kwargs):
+        """
+        Build the search over the resolved projects.
+
+        :returns: A `PageSearch` object,
+         or `None` if there aren't projects to search on.
+        """
+        projects = {
+            project.slug: version.slug
+            for project, version in self.projects
+        }
+        if not projects:
+            return None
+
+        return PageSearch(
+            query=self.query,
+            projects=projects,
+            **kwargs,
+        )
+
+    def _get_projects_to_search(self):
+        if not self._has_permission(self.request.user, self.version):
+            return
+
+        yield self.project, self.version
+
+        yield from self._get_subprojects(self.project, version_slug=self.version.slug)
+
+
+class BackendV2(Backend):
+
+    """Search on the projects given as arguments in the query itself."""
+
+    def __init__(self, *, request, query, allow_search_all=False):
+        self.request = request
+        self.query = query
+        self.allow_search_all = allow_search_all
+
+    @property
+    def _has_arguments(self):
+        return any(self.parser.arguments.values())
+
+    def _get_default_projects(self):
+        if self.allow_search_all:
+            # Default to search all.
+            return []
+        return self._get_projects_from_user()
+
+    def search(self, **kwargs):
+        """
+        Build the search over the resolved projects.
+
+        :returns: A `PageSearch` object, or `None` if there is nothing to search on.
+        """
+        projects = {
+            project.slug: version.slug
+            for project, version in self.projects
+        }
+        # If the search is done without projects, ES will search on all projects.
+        # If we don't have projects and the user provided arguments,
+        # it means we don't have anything to search on (no results).
+        # Or if we don't have projects and we don't allow searching all,
+        # we also just return.
+        if not projects and (self._has_arguments or not self.allow_search_all):
+            return None
+
+        queryset = PageSearch(
+            query=self.parser.query,
+            projects=projects,
+            **kwargs,
+        )
+        return queryset
+
+    @cached_property
+    def parser(self):
+        parser = SearchQueryParser(self.query)
+        parser.parse()
+        return parser
+
+    def _get_projects_to_search(self):
+        if not self._has_arguments:
+            # This function is a generator, a plain ``return <value>``
+            # would discard the default projects.
+            yield from self._get_default_projects()
+            return
+
+        for value in self.parser.arguments["project"]:
+            project, version = self._get_project_and_version(value)
+            if version and self._has_permission(self.request.user, version):
+                yield project, version
+
+        for value in self.parser.arguments['subprojects']:
+            project, version = self._get_project_and_version(value)
+
+            # Add the project itself.
+            if version and self._has_permission(self.request.user, version):
+                yield project, version
+
+            # If the user didn't provide a version, version_slug will be `None`,
+            # and we add all subprojects with their default version,
+            # otherwise we will add all projects that match the given version.
+            _, version_slug = self._split_project_and_version(value)
+            if project:
+                yield from self._get_subprojects(
+                    project=project,
+                    version_slug=version_slug,
+                )
+
+        # Add all projects the user has access to.
+        if self.parser.arguments["user"] == "@me":
+            yield from self._get_projects_from_user()
+
+    def _split_project_and_version(self, term):
+        """
+        Split a term of the form ``{project}/{version}``.
+
+        :returns: A tuple of project and version.
+         If the version part isn't found, `None` will be returned in its place.
+        """
+        parts = term.split("/", maxsplit=1)
+        if len(parts) > 1:
+            return tuple(parts)
+        return parts[0], None
+
+    def _get_project_and_version(self, value):
+        """
+        Get the project and version from a ``{project}/{version}`` term.
+
+        The default version of the project is used if the term doesn't include one.
+
+        :returns: A tuple of project and version, or ``(None, None)``
+         if the project doesn't exist or there isn't a version slug to use.
+        """
+        project_slug, version_slug = self._split_project_and_version(value)
+        project = Project.objects.filter(slug=project_slug).first()
+        if not project:
+            return None, None
+
+        if not version_slug:
+            version_slug = project.default_version
+
+        if version_slug:
+            version = self._get_project_version(
+                project=project,
+                version_slug=version_slug,
+            )
+            return project, version
+
+        return None, None
diff --git a/readthedocs/search/faceted_search.py b/readthedocs/search/faceted_search.py
index 4dfb8b0622f..59f6defbe0b 100644
--- a/readthedocs/search/faceted_search.py
+++ b/readthedocs/search/faceted_search.py
@@ -4,7 +4,6 @@
 from django.conf import settings
 from elasticsearch import Elasticsearch
 from elasticsearch_dsl import FacetedSearch, TermsFacet
-from elasticsearch_dsl.faceted_search import NestedFacet
 from elasticsearch_dsl.query import (
     Bool,
     FunctionScore,
@@ -20,8 +19,6 @@
 
 log = structlog.get_logger(__name__)
 
-ALL_FACETS = ['project', 'version', 'role_name', 'language']
-
 
 class RTDFacetedSearch(FacetedSearch):
 
@@ -268,11 +265,6 @@ def query(self, search, query):
 class PageSearch(RTDFacetedSearch):
     facets = {
         'project': TermsFacet(field='project'),
-        'version': TermsFacet(field='version'),
-        'role_name': NestedFacet(
-            'domains',
-            TermsFacet(field='domains.role_name')
-        ),
     }
     doc_types = [PageDocument]
     index = PageDocument._index._name
@@ -365,18 +357,6 @@ def _get_nested_query(self, *, query, path, fields):
             for field in fields
         ]
 
-        # The ``post_filter`` filter will only filter documents
-        # at the parent level (domains is a nested document),
-        # resulting in results with domains that don't match the current
-        # role_name being filtered, so we need to force filtering by role_name
-        # on the ``domains`` document here. See #8268.
-        # TODO: We should use a flattened document instead
-        # to avoid this kind of problems and have faster queries.
-        role_name = self.filter_values.get('role_name')
-        if path == 'domains' and role_name:
-            role_name_query = Bool(must=Terms(**{'domains.role_name': role_name}))
-            bool_query = Bool(must=[role_name_query, bool_query])
-
         highlight = dict(
             self._highlight_options,
             fields={
diff --git a/readthedocs/search/query.py b/readthedocs/search/query.py
new file mode 100644
index 00000000000..76a582c6592
--- /dev/null
+++ b/readthedocs/search/query.py
@@ -0,0 +1,110 @@
+class Token:
+
+    """Base class for the tokens of a search query."""
+
+
+class TextToken(Token):
+
+    """A word that is part of the text to search."""
+
+    def __init__(self, text):
+        self.text = text
+
+
+class ArgumentToken(Token):
+
+    """An argument of the form ``name:value``."""
+
+    def __init__(self, *, name, value, type):
+        self.name = name
+        self.value = value
+        self.type = type
+
+
+class SearchQueryParser:
+
+    """
+    Simplified and minimal parser for ``name:value`` expressions.
+
+    Words of the form ``name:value`` where ``name`` is in `allowed_arguments`
+    are extracted as arguments, everything else is kept as the text to search.
+    A colon can be escaped with a backslash (``\\:``) to keep the word as text.
+
+    After calling `parse`, the text is available in `query`,
+    and the arguments in `arguments`.
+    """
+
+    allowed_arguments = {
+        "project": list,
+        "subprojects": list,
+        "user": str,
+    }
+
+    def __init__(self, query):
+        self._query = query
+        self.query = ""
+        # Start with the empty value of each argument,
+        # so they can be read even if `parse` hasn't been called.
+        self.arguments = self._get_default_arguments()
+
+    def _get_default_arguments(self):
+        """Return a new dict with the empty value of each allowed argument."""
+        return {
+            name: type_()
+            for name, type_ in self.allowed_arguments.items()
+        }
+
+    def parse(self):
+        """
+        Parse the raw query, and set the `query` and `arguments` attributes.
+
+        :raises: ValueError if an argument has a type other than `str` or `list`.
+        """
+        # A missing query (`None`) is parsed as an empty one.
+        tokens = (
+            self._get_token(text)
+            for text in (self._query or "").split()
+        )
+        query = []
+        arguments = self._get_default_arguments()
+        for token in tokens:
+            if isinstance(token, TextToken):
+                query.append(token.text)
+            elif isinstance(token, ArgumentToken):
+                if token.type is str:
+                    # Only the last value given is kept.
+                    arguments[token.name] = token.value
+                elif token.type is list:
+                    arguments[token.name].append(token.value)
+                else:
+                    raise ValueError(f"Invalid argument type {token.type}")
+            else:
+                raise ValueError("Invalid node")
+
+        self.query = self._unescape(" ".join(query))
+        self.arguments = arguments
+
+    def _get_token(self, text):
+        """
+        Get the token of a word from the query.
+
+        :returns: An `ArgumentToken` if the word is an allowed ``name:value`` pair,
+         a `TextToken` otherwise.
+        """
+        result = text.split(":", maxsplit=1)
+        if len(result) < 2:
+            return TextToken(text)
+
+        name, value = result
+        if name in self.allowed_arguments:
+            return ArgumentToken(
+                name=name,
+                value=value,
+                type=self.allowed_arguments[name],
+            )
+
+        return TextToken(text)
+
+    def _unescape(self, text):
+        """Replace the escaped colons (``\\:``) with plain colons."""
+        return text.replace("\\:", ":")
diff --git a/readthedocs/search/views.py b/readthedocs/search/views.py
index 758eee067de..b8c4427f767 100644
--- a/readthedocs/search/views.py
+++ b/readthedocs/search/views.py
@@ -4,12 +4,12 @@
 
 from django.conf import settings
 from django.shortcuts import get_object_or_404, render
+from readthedocs.search.backends import BackendV2
 from django.views import View
 
 from readthedocs.builds.constants import LATEST
 from readthedocs.projects.models import Feature, Project
 from readthedocs.search.faceted_search import (
-    ALL_FACETS,
     PageSearch,
     ProjectSearch,
 )
@@ -21,6 +21,8 @@
     VersionData,
 )
 
+from readthedocs.search.api import ProjectsToSearchMixin
+
 log = structlog.get_logger(__name__)
 
 UserInput = collections.namedtuple(
@@ -28,10 +30,7 @@
     (
         'query',
         'type',
-        'project',
-        'version',
         'language',
-        'role_name',
     ),
 )
 
@@ -47,7 +46,7 @@ def _search(self, *, user_input, projects, use_advanced_query):
             return [], {}
 
         filters = {}
-        for avail_facet in ALL_FACETS:
+        for avail_facet in ['language']:
             value = getattr(user_input, avail_facet, None)
             if value:
                 filters[avail_facet] = value
@@ -145,6 +144,7 @@ def get(self, request, project_slug):
             'results': results,
             'facets': facets,
             'project_obj': project_obj,
+            'search_query': self._parser.query
         })
 
         return render(
@@ -154,61 +154,49 @@ def get(self, request, project_slug):
         )
 
 
-class GlobalSearchView(SearchViewBase):
+class GlobalSearchView(SearchViewBase, ProjectsToSearchMixin):
 
     """
     Global search enabled for logged out users and anyone using the dashboard.
Query params: - - q: search term - - type: type of document to search (project or file) - - project: project to filter by - - language: project language to filter by - - version: version to filter by - - role_name: sphinx role to filter by + - q: Search query + - type: Type of document to search (project or file) """ def get(self, request): user_input = UserInput( - query=request.GET.get('q'), + query=request.GET.get('q', ""), type=request.GET.get('type', 'project'), - project=request.GET.get('project'), - version=request.GET.get('version', LATEST), language=request.GET.get('language'), - role_name=request.GET.get('role_name'), ) - projects = [] - # If we allow private projects, - # we only search on projects the user belongs or have access to. - if settings.ALLOW_PRIVATE_REPOS: - projects = list( - Project.objects.for_user(request.user) - .values_list('slug', flat=True) - ) - - # Make sure we always have projects to filter by if we allow private projects. - if settings.ALLOW_PRIVATE_REPOS and not projects: - results, facets = [], {} - else: - results, facets = self._search( - user_input=user_input, - projects=projects, - use_advanced_query=True, - ) + backend = BackendV2( + request=request, + query=user_input.query, + allow_search_all=not settings.ALLOW_PRIVATE_REPOS, + ) + search = backend.search() + results = [] + facets = {} + if search: + results = search[:self.max_search_results].execute() + facets = results.facets serializers = { 'project': ProjectSearchSerializer, 'file': PageSearchSerializer, } - serializer = serializers.get(user_input.type, ProjectSearchSerializer) + serializer = serializers.get("file", ProjectSearchSerializer) results = serializer(results, many=True).data template_context = user_input._asdict() template_context.update({ 'results': results, 'facets': facets, + "parser": backend.parser, + 'search_query': backend.parser.query, }) return render( diff --git a/readthedocs/templates/search/elastic_search.html 
b/readthedocs/templates/search/elastic_search.html index 8773b353a3b..5d169e4aa0a 100644 --- a/readthedocs/templates/search/elastic_search.html +++ b/readthedocs/templates/search/elastic_search.html @@ -36,29 +36,14 @@
{% trans 'Object Type' %}
{% if facets.project and not project_obj %}
{% trans 'Projects' %}
- {% for name, count, selected in facets.project %} -
  • - {% if project == name %} - {{ name }} - {% else %} - {{ name }} - {% endif %} - ({{ count }}) - -
  • - {% endfor %} -
    - {% endif %} - {% if facets.version %} -
    {% trans 'Version' %}
    - {% for name, count, selected in facets.version %} -
  • - {% if version == name %} - {{ name }} - {% else %} - {{ name }} - {% endif %} +
  • + Search all +
  • + + {% for name, count, selected in facets.project %} +
  • + {{ name }} ({{ count }})
  • @@ -135,7 +120,7 @@

    {% trans 'Search' %}

    - {% blocktrans with count=results.hits.total query=query|default:"" %} + {% blocktrans with count=results.hits.total query=search_query|default:"" %} {{ count }} results for `{{ query }}` {% endblocktrans %}

    diff --git a/readthedocs/urls.py b/readthedocs/urls.py index c1ffb85c534..06de64420aa 100644 --- a/readthedocs/urls.py +++ b/readthedocs/urls.py @@ -15,7 +15,7 @@ do_not_track, server_error_500, ) -from readthedocs.search.api import PageSearchAPIView +from readthedocs.search.api import PageSearchAPIView, SearchAPIV3 from readthedocs.search.views import GlobalSearchView admin.autodiscover() @@ -101,6 +101,7 @@ ), re_path(r'^api/v3/', include('readthedocs.api.v3.urls')), re_path(r'^api/v3/embed/', include('readthedocs.embed.v3.urls')), + re_path(r'^api/v3/search/$', SearchAPIV3.as_view(), name='search_api_v3'), ] i18n_urls = [