diff --git a/elasticsearch_dsl/_async/__init__.py b/elasticsearch_dsl/_async/__init__.py new file mode 100644 index 000000000..2a87d183f --- /dev/null +++ b/elasticsearch_dsl/_async/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/elasticsearch_dsl/_async/document.py b/elasticsearch_dsl/_async/document.py new file mode 100644 index 000000000..98dc56893 --- /dev/null +++ b/elasticsearch_dsl/_async/document.py @@ -0,0 +1,477 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
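For orientation before the new module: the document.py file added below mirrors the synchronous Document API, but every method that talks to Elasticsearch is awaitable and guarded by ensure_async_connection. The following is an illustrative sketch only, not part of the patch; it assumes an async-capable connection has been registered under the "default" alias (the exact registration call may differ) and uses a hypothetical BlogPost document and index name.

    from elasticsearch_dsl import Date, Text, connections
    from elasticsearch_dsl._async.document import Document

    # Assumption: the registered connection must be async-capable; the exact
    # registration call for an async client may differ from this sketch.
    connections.create_connection(hosts=["localhost"])

    class BlogPost(Document):
        title = Text()
        published = Date()

        class Index:
            name = "blog-posts"  # hypothetical index name

    async def demo():
        await BlogPost.init()                        # create the index and mappings
        post = BlogPost(title="Hello", meta={"id": 1})
        await post.save()                            # index the document
        fetched = await BlogPost.get(id=1)           # single-document GET
        await fetched.update(title="Hello, async")   # partial update
        await fetched.delete()                       # remove the document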
+ +try: + import collections.abc as collections_abc # only works on python 3.3+ +except ImportError: + import collections as collections_abc + +from fnmatch import fnmatch + +from elasticsearch.exceptions import NotFoundError, RequestError +from six import add_metaclass, iteritems + +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException +from elasticsearch_dsl.field import Field +from elasticsearch_dsl.index import Index +from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.utils import DOC_META_FIELDS, META_FIELDS, ObjectBase, merge + +from .utils import ensure_async_connection + + +class MetaField(object): + def __init__(self, *args, **kwargs): + self.args, self.kwargs = args, kwargs + + +class DocumentMeta(type): + def __new__(cls, name, bases, attrs): + # DocumentMeta filters attrs in place + attrs["_doc_type"] = DocumentOptions(name, bases, attrs) + return super(DocumentMeta, cls).__new__(cls, name, bases, attrs) + + +class IndexMeta(DocumentMeta): + # global flag to guard us from associating an Index with the base Document + # class, only user defined subclasses should have an _index attr + _document_initialized = False + + def __new__(cls, name, bases, attrs): + new_cls = super(IndexMeta, cls).__new__(cls, name, bases, attrs) + if cls._document_initialized: + index_opts = attrs.pop("Index", None) + index = cls.construct_index(index_opts, bases) + new_cls._index = index + index.document(new_cls) + cls._document_initialized = True + return new_cls + + @classmethod + def construct_index(cls, opts, bases): + if opts is None: + for b in bases: + if hasattr(b, "_index"): + return b._index + + # Set None as Index name so it will set _all while making the query + return Index(name=None) + + i = Index(getattr(opts, "name", "*"), using=getattr(opts, "using", "default")) + i.settings(**getattr(opts, "settings", {})) + i.aliases(**getattr(opts, "aliases", {})) + for a in getattr(opts, "analyzers", ()): + i.analyzer(a) + return i + + +class DocumentOptions(object): + def __init__(self, name, bases, attrs): + meta = attrs.pop("Meta", None) + + # create the mapping instance + self.mapping = getattr(meta, "mapping", Mapping()) + + # register all declared fields into the mapping + for name, value in list(iteritems(attrs)): + if isinstance(value, Field): + self.mapping.field(name, value) + del attrs[name] + + # add all the mappings for meta fields + for name in dir(meta): + if isinstance(getattr(meta, name, None), MetaField): + params = getattr(meta, name) + self.mapping.meta(name, *params.args, **params.kwargs) + + # document inheritance - include the fields from parents' mappings + for b in bases: + if hasattr(b, "_doc_type") and hasattr(b._doc_type, "mapping"): + self.mapping.update(b._doc_type.mapping, update_only=True) + + @property + def name(self): + return self.mapping.properties.name + + +@add_metaclass(DocumentMeta) +class InnerDoc(ObjectBase): + """ + Common class for inner documents like Object or Nested + """ + + @classmethod + def from_es(cls, data, data_only=False): + if data_only: + data = {"_source": data} + return super(InnerDoc, cls).from_es(data) + + +@add_metaclass(IndexMeta) +class Document(ObjectBase): + """ + Model-like class for persisting documents in elasticsearch. 
+ """ + + @classmethod + def _matches(cls, hit): + if cls._index._name is None: + return True + return fnmatch(hit.get("_index", ""), cls._index._name) + + @classmethod + def _get_using(cls, using=None): + return using or cls._index._using + + @classmethod + def _get_connection(cls, using=None): + return get_connection(cls._get_using(using)) + + @classmethod + def _default_index(cls, index=None): + return index or cls._index._name + + @classmethod + async def init(cls, index=None, using=None): + """ + Create the index and populate the mappings in elasticsearch. + """ + i = cls._index + if index: + i = i.clone(name=index) + await i.save(using=using) + + def _get_index(self, index=None, required=True): + if index is None: + index = getattr(self.meta, "index", None) + if index is None: + index = getattr(self._index, "_name", None) + if index is None and required: + raise ValidationException("No index") + if index and "*" in index: + raise ValidationException("You cannot write to a wildcard index.") + return index + + def __repr__(self): + return "{}({})".format( + self.__class__.__name__, + ", ".join( + "{}={!r}".format(key, getattr(self.meta, key)) + for key in ("index", "id") + if key in self.meta + ), + ) + + @classmethod + def search(cls, using=None, index=None): + """ + Create an :class:`~elasticsearch_dsl.Search` instance that will search + over this ``Document``. + """ + return Search( + using=cls._get_using(using), index=cls._default_index(index), doc_type=[cls] + ) + + @classmethod + async def get(cls, id, using=None, index=None, **kwargs): + """ + Retrieve a single document from elasticsearch using its ``id``. + + :arg id: ``id`` of the document to be retrieved + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``Elasticsearch.get`` unchanged. + """ + es = cls._get_connection(using) + ensure_async_connection(es, "Document.get") + + doc = await es.get(index=cls._default_index(index), id=id, **kwargs) + if not doc.get("found", False): + return None + return cls.from_es(doc) + + @classmethod + async def mget( + cls, docs, using=None, index=None, raise_on_error=True, missing="none", **kwargs + ): + r""" + Retrieve multiple document by their ``id``\s. Returns a list of instances + in the same order as requested. + + :arg docs: list of ``id``\s of the documents to be retrieved or a list + of document specifications as per + https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-multi-get.html + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg missing: what to do when one of the documents requested is not + found. Valid options are ``'none'`` (use ``None``), ``'raise'`` (raise + ``NotFoundError``) or ``'skip'`` (ignore the missing document). + + Any additional keyword arguments will be passed to + ``Elasticsearch.mget`` unchanged. 
+ """ + if missing not in ("raise", "skip", "none"): + raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") + + es = cls._get_connection(using) + ensure_async_connection(es, "Document.mget") + + body = { + "docs": [ + doc if isinstance(doc, collections_abc.Mapping) else {"_id": doc} + for doc in docs + ] + } + results = await es.mget(body, index=cls._default_index(index), **kwargs) + + objs, error_docs, missing_docs = [], [], [] + for doc in results["docs"]: + if doc.get("found"): + if error_docs or missing_docs: + # We're going to raise an exception anyway, so avoid an + # expensive call to cls.from_es(). + continue + + objs.append(cls.from_es(doc)) + + elif doc.get("error"): + if raise_on_error: + error_docs.append(doc) + if missing == "none": + objs.append(None) + + # The doc didn't cause an error, but the doc also wasn't found. + elif missing == "raise": + missing_docs.append(doc) + elif missing == "none": + objs.append(None) + + if error_docs: + error_ids = [doc["_id"] for doc in error_docs] + message = "Required routing not provided for documents %s." + message %= ", ".join(error_ids) + raise RequestError(400, message, error_docs) + if missing_docs: + missing_ids = [doc["_id"] for doc in missing_docs] + message = "Documents %s not found." % ", ".join(missing_ids) + raise NotFoundError(404, message, {"docs": missing_docs}) + return objs + + async def delete(self, using=None, index=None, **kwargs): + """ + Delete the instance in elasticsearch. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``Elasticsearch.delete`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Document.delete") + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) + await es.delete(index=self._get_index(index), **doc_meta) + + def to_dict(self, include_meta=False, skip_empty=True): + """ + Serialize the instance into a dictionary so that it can be saved in elasticsearch. + + :arg include_meta: if set to ``True`` will include all the metadata + (``_index``, ``_id`` etc). Otherwise just the document's + data is serialized. This is useful when passing multiple instances into + ``elasticsearch.helpers.bulk``. + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. 
+ """ + d = super(Document, self).to_dict(skip_empty=skip_empty) + if not include_meta: + return d + + meta = {"_" + k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # in case of to_dict include the index unlike save/update/delete + index = self._get_index(required=False) + if index is not None: + meta["_index"] = index + + meta["_source"] = d + return meta + + async def update( + self, + using=None, + index=None, + detect_noop=True, + doc_as_upsert=False, + refresh=False, + retry_on_conflict=None, + script=None, + script_id=None, + scripted_upsert=False, + upsert=None, + **fields + ): + """ + Partial update of the document, specify fields you wish to update and + both the instance and the document in elasticsearch will be updated:: + + doc = MyDocument(title='Document Title!') + doc.save() + doc.update(title='New Document Title!') + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg detect_noop: Set to ``False`` to disable noop detection. + :arg refresh: Control when the changes made by this request are visible + to search. Set to ``True`` for immediate effect. + :arg retry_on_conflict: In between the get and indexing phases of the + update, it is possible that another process might have already + updated the same document. By default, the update will fail with a + version conflict exception. The retry_on_conflict parameter + controls how many times to retry the update before finally throwing + an exception. + :arg doc_as_upsert: Instead of sending a partial doc plus an upsert + doc, setting doc_as_upsert to true will use the contents of doc as + the upsert value + + :return operation result noop/updated + """ + body = { + "doc_as_upsert": doc_as_upsert, + "detect_noop": detect_noop, + } + + # scripted update + if script or script_id: + if upsert is not None: + body["upsert"] = upsert + + if script: + script = {"source": script} + else: + script = {"id": script_id} + + script["params"] = fields + + body["script"] = script + body["scripted_upsert"] = scripted_upsert + + # partial document update + else: + if not fields: + raise IllegalOperation( + "You cannot call update() without updating individual fields or a script. " + "If you wish to update the entire object use save()." + ) + + # update given fields locally + merge(self, fields) + + # prepare data for ES + values = self.to_dict() + + # if fields were given: partial update + body["doc"] = {k: values.get(k) for k in fields.keys()} + + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + if retry_on_conflict is not None: + doc_meta["retry_on_conflict"] = retry_on_conflict + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + es = self._get_connection(using) + ensure_async_connection(es, "Document.update") + + meta = await es.update( + index=self._get_index(index), body=body, refresh=refresh, **doc_meta + ) + # update meta information from ES + for k in META_FIELDS: + if "_" + k in meta: + setattr(self.meta, k, meta["_" + k]) + + return meta["result"] + + async def save( + self, using=None, index=None, validate=True, skip_empty=True, **kwargs + ): + """ + Save the document into elasticsearch. If the document doesn't exist it + is created, it is overwritten otherwise. 
Returns ``True`` if this + operations resulted in new document being created. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg validate: set to ``False`` to skip validating the document + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.index`` unchanged. + + :return operation result created/updated + """ + if validate: + self.full_clean() + + es = self._get_connection(using) + ensure_async_connection(es, "Document.save") + + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) + meta = await es.index( + index=self._get_index(index), + body=self.to_dict(skip_empty=skip_empty), + **doc_meta + ) + # update meta information from ES + for k in META_FIELDS: + if "_" + k in meta: + setattr(self.meta, k, meta["_" + k]) + + return meta["result"] diff --git a/elasticsearch_dsl/_async/faceted_search.py b/elasticsearch_dsl/_async/faceted_search.py new file mode 100644 index 000000000..535613832 --- /dev/null +++ b/elasticsearch_dsl/_async/faceted_search.py @@ -0,0 +1,426 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime, timedelta + +from six import iteritems, itervalues + +from elasticsearch_dsl.aggs import A +from elasticsearch_dsl.query import MatchAll, Nested, Range, Terms +from elasticsearch_dsl.response import Response +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.utils import AttrDict + +__all__ = [ + "FacetedSearch", + "HistogramFacet", + "TermsFacet", + "DateHistogramFacet", + "RangeFacet", + "NestedFacet", +] + + +class Facet(object): + """ + A facet on faceted search. Wraps and aggregation and provides functionality + to create a filter for selected values and return a list of facet values + from the result of the aggregation. + """ + + agg_type = None + + def __init__(self, metric=None, metric_sort="desc", **kwargs): + self.filter_values = () + self._params = kwargs + self._metric = metric + if metric and metric_sort: + self._params["order"] = {"metric": metric_sort} + + def get_aggregation(self): + """ + Return the aggregation object. 
+ """ + agg = A(self.agg_type, **self._params) + if self._metric: + agg.metric("metric", self._metric) + return agg + + def add_filter(self, filter_values): + """ + Construct a filter. + """ + if not filter_values: + return + + f = self.get_value_filter(filter_values[0]) + for v in filter_values[1:]: + f |= self.get_value_filter(v) + return f + + def get_value_filter(self, filter_value): + """ + Construct a filter for an individual value + """ + pass + + def is_filtered(self, key, filter_values): + """ + Is a filter active on the given key. + """ + return key in filter_values + + def get_value(self, bucket): + """ + return a value representing a bucket. Its key as default. + """ + return bucket["key"] + + def get_metric(self, bucket): + """ + Return a metric, by default doc_count for a bucket. + """ + if self._metric: + return bucket["metric"]["value"] + return bucket["doc_count"] + + def get_values(self, data, filter_values): + """ + Turn the raw bucket data into a list of tuples containing the key, + number of documents and a flag indicating whether this value has been + selected or not. + """ + out = [] + for bucket in data.buckets: + key = self.get_value(bucket) + out.append( + (key, self.get_metric(bucket), self.is_filtered(key, filter_values)) + ) + return out + + +class TermsFacet(Facet): + agg_type = "terms" + + def add_filter(self, filter_values): + """ Create a terms filter instead of bool containing term filters. """ + if filter_values: + return Terms( + _expand__to_dot=False, **{self._params["field"]: filter_values} + ) + + +class RangeFacet(Facet): + agg_type = "range" + + def _range_to_dict(self, range): + key, range = range + out = {"key": key} + if range[0] is not None: + out["from"] = range[0] + if range[1] is not None: + out["to"] = range[1] + return out + + def __init__(self, ranges, **kwargs): + super(RangeFacet, self).__init__(**kwargs) + self._params["ranges"] = list(map(self._range_to_dict, ranges)) + self._params["keyed"] = False + self._ranges = dict(ranges) + + def get_value_filter(self, filter_value): + f, t = self._ranges[filter_value] + limits = {} + if f is not None: + limits["gte"] = f + if t is not None: + limits["lt"] = t + + return Range(_expand__to_dot=False, **{self._params["field"]: limits}) + + +class HistogramFacet(Facet): + agg_type = "histogram" + + def get_value_filter(self, filter_value): + return Range( + _expand__to_dot=False, + **{ + self._params["field"]: { + "gte": filter_value, + "lt": filter_value + self._params["interval"], + } + } + ) + + +class DateHistogramFacet(Facet): + agg_type = "date_histogram" + + DATE_INTERVALS = { + "month": lambda d: (d + timedelta(days=32)).replace(day=1), + "week": lambda d: d + timedelta(days=7), + "day": lambda d: d + timedelta(days=1), + "hour": lambda d: d + timedelta(hours=1), + } + + def __init__(self, **kwargs): + kwargs.setdefault("min_doc_count", 0) + super(DateHistogramFacet, self).__init__(**kwargs) + + def get_value(self, bucket): + if not isinstance(bucket["key"], datetime): + # Elasticsearch returns key=None instead of 0 for date 1970-01-01, + # so we need to set key to 0 to avoid TypeError exception + if bucket["key"] is None: + bucket["key"] = 0 + # Preserve milliseconds in the datetime + return datetime.utcfromtimestamp(int(bucket["key"]) / 1000.0) + else: + return bucket["key"] + + def get_value_filter(self, filter_value): + return Range( + _expand__to_dot=False, + **{ + self._params["field"]: { + "gte": filter_value, + "lt": self.DATE_INTERVALS[self._params["interval"]](filter_value), + } + 
} + ) + + +class NestedFacet(Facet): + agg_type = "nested" + + def __init__(self, path, nested_facet): + self._path = path + self._inner = nested_facet + super(NestedFacet, self).__init__( + path=path, aggs={"inner": nested_facet.get_aggregation()} + ) + + def get_values(self, data, filter_values): + return self._inner.get_values(data.inner, filter_values) + + def add_filter(self, filter_values): + inner_q = self._inner.add_filter(filter_values) + if inner_q: + return Nested(path=self._path, query=inner_q) + + +class FacetedResponse(Response): + @property + def query_string(self): + return self._faceted_search._query + + @property + def facets(self): + if not hasattr(self, "_facets"): + super(AttrDict, self).__setattr__("_facets", AttrDict({})) + for name, facet in iteritems(self._faceted_search.facets): + self._facets[name] = facet.get_values( + getattr(getattr(self.aggregations, "_filter_" + name), name), + self._faceted_search.filter_values.get(name, ()), + ) + return self._facets + + +class FacetedSearch(object): + """ + Abstraction for creating faceted navigation searches that takes care of + composing the queries, aggregations and filters as needed as well as + presenting the results in an easy-to-consume fashion:: + + class BlogSearch(FacetedSearch): + index = 'blogs' + doc_types = [Blog, Post] + fields = ['title^5', 'category', 'description', 'body'] + + facets = { + 'type': TermsFacet(field='_type'), + 'category': TermsFacet(field='category'), + 'weekly_posts': DateHistogramFacet(field='published_from', interval='week') + } + + def search(self): + ' Override search to add your own filters ' + s = super(BlogSearch, self).search() + return s.filter('term', published=True) + + # when using: + blog_search = BlogSearch("web framework", filters={"category": "python"}) + + # supports pagination + blog_search[10:20] + + response = blog_search.execute() + + # easy access to aggregation results: + for category, hit_count, is_selected in response.facets.category: + print( + "Category %s has %d hits%s." % ( + category, + hit_count, + ' and is chosen' if is_selected else '' + ) + ) + + """ + + index = None + doc_types = None + fields = None + facets = {} + using = "default" + + def __init__(self, query=None, filters={}, sort=()): + """ + :arg query: the text to search for + :arg filters: facet values to filter + :arg sort: sort information to be passed to :class:`~elasticsearch_dsl.Search` + """ + self._query = query + self._filters = {} + self._sort = sort + self.filter_values = {} + for name, value in iteritems(filters): + self.add_filter(name, value) + + self._s = self.build_search() + + def count(self): + return self._s.count() + + def __getitem__(self, k): + self._s = self._s[k] + return self + + def __iter__(self): + return iter(self._s) + + def add_filter(self, name, filter_values): + """ + Add a filter for a facet. + """ + # normalize the value into a list + if not isinstance(filter_values, (tuple, list)): + if filter_values is None: + return + filter_values = [ + filter_values, + ] + + # remember the filter values for use in FacetedResponse + self.filter_values[name] = filter_values + + # get the filter from the facet + f = self.facets[name].add_filter(filter_values) + if f is None: + return + + self._filters[name] = f + + def search(self): + """ + Returns the base Search object to which the facets are added. + + You can customize the query by overriding this method and returning a + modified search object. 
+ """ + s = Search(doc_type=self.doc_types, index=self.index, using=self.using) + return s.response_class(FacetedResponse) + + def query(self, search, query): + """ + Add query part to ``search``. + + Override this if you wish to customize the query used. + """ + if query: + if self.fields: + return search.query("multi_match", fields=self.fields, query=query) + else: + return search.query("multi_match", query=query) + return search + + def aggregate(self, search): + """ + Add aggregations representing the facets selected, including potential + filters. + """ + for f, facet in iteritems(self.facets): + agg = facet.get_aggregation() + agg_filter = MatchAll() + for field, filter in iteritems(self._filters): + if f == field: + continue + agg_filter &= filter + search.aggs.bucket("_filter_" + f, "filter", filter=agg_filter).bucket( + f, agg + ) + + def filter(self, search): + """ + Add a ``post_filter`` to the search request narrowing the results based + on the facet filters. + """ + if not self._filters: + return search + + post_filter = MatchAll() + for f in itervalues(self._filters): + post_filter &= f + return search.post_filter(post_filter) + + def highlight(self, search): + """ + Add highlighting for all the fields + """ + return search.highlight( + *(f if "^" not in f else f.split("^", 1)[0] for f in self.fields) + ) + + def sort(self, search): + """ + Add sorting information to the request. + """ + if self._sort: + search = search.sort(*self._sort) + return search + + def build_search(self): + """ + Construct the ``Search`` object. + """ + s = self.search() + s = self.query(s, self._query) + s = self.filter(s) + if self.fields: + s = self.highlight(s) + s = self.sort(s) + self.aggregate(s) + return s + + async def execute(self): + """ + Execute the search and return the response. + """ + r = await self._s.execute() + r._faceted_search = self + return r diff --git a/elasticsearch_dsl/_async/index.py b/elasticsearch_dsl/_async/index.py new file mode 100644 index 000000000..0aa61701a --- /dev/null +++ b/elasticsearch_dsl/_async/index.py @@ -0,0 +1,721 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from elasticsearch_dsl import analysis +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation +from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.update_by_query import UpdateByQuery +from elasticsearch_dsl.utils import merge + +from .utils import ensure_async_connection + + +class IndexTemplate(object): + def __init__(self, name, template, index=None, order=None, **kwargs): + if index is None: + self._index = Index(template, **kwargs) + else: + if kwargs: + raise ValueError( + "You cannot specify options for Index when" + " passing an Index instance." + ) + self._index = index.clone() + self._index._name = template + self._template_name = name + self.order = order + + def __getattr__(self, attr_name): + return getattr(self._index, attr_name) + + def to_dict(self): + d = self._index.to_dict() + d["index_patterns"] = [self._index._name] + if self.order is not None: + d["order"] = self.order + return d + + async def save(self, using=None): + es = get_connection(using or self._index._using) + ensure_async_connection(es, "IndexTemplate.save") + + return await es.indices.put_template( + name=self._template_name, body=self.to_dict() + ) + + +class Index(object): + def __init__(self, name, using="default"): + """ + :arg name: name of the index + :arg using: connection alias to use, defaults to ``'default'`` + """ + self._name = name + self._doc_types = [] + self._using = using + self._settings = {} + self._aliases = {} + self._analysis = {} + self._mapping = None + + def get_or_create_mapping(self): + if self._mapping is None: + self._mapping = Mapping() + return self._mapping + + def as_template(self, template_name, pattern=None, order=None): + # TODO: should we allow pattern to be a top-level arg? + # or maybe have an IndexPattern that allows for it and have + # Document._index be that? + return IndexTemplate( + template_name, pattern or self._name, index=self, order=order + ) + + def resolve_nested(self, field_path): + for doc in self._doc_types: + nested, field = doc._doc_type.mapping.resolve_nested(field_path) + if field is not None: + return nested, field + if self._mapping: + return self._mapping.resolve_nested(field_path) + return (), None + + def resolve_field(self, field_path): + for doc in self._doc_types: + field = doc._doc_type.mapping.resolve_field(field_path) + if field is not None: + return field + if self._mapping: + return self._mapping.resolve_field(field_path) + return None + + async def load_mappings(self, using=None): + mapping = self.get_or_create_mapping() + + await mapping.update_from_es(self._name, using=using or self._using) + + def clone(self, name=None, using=None): + """ + Create a copy of the instance with another name or connection alias. 
+ Useful for creating multiple indices with shared configuration:: + + i = Index('base-index') + i.settings(number_of_shards=1) + i.create() + + i2 = i.clone('other-index') + i2.create() + + :arg name: name of the index + :arg using: connection alias to use, defaults to ``'default'`` + """ + i = Index(name or self._name, using=using or self._using) + i._settings = self._settings.copy() + i._aliases = self._aliases.copy() + i._analysis = self._analysis.copy() + i._doc_types = self._doc_types[:] + if self._mapping is not None: + i._mapping = self._mapping._clone() + return i + + def _get_connection(self, using=None): + if self._name is None: + raise ValueError("You cannot perform API calls on the default index.") + return get_connection(using or self._using) + + connection = property(_get_connection) + + def mapping(self, mapping): + """ + Associate a mapping (an instance of + :class:`~elasticsearch_dsl.Mapping`) with this index. + This means that, when this index is created, it will contain the + mappings for the document type defined by those mappings. + """ + self.get_or_create_mapping().update(mapping) + + def document(self, document): + """ + Associate a :class:`~elasticsearch_dsl.Document` subclass with an index. + This means that, when this index is created, it will contain the + mappings for the ``Document``. If the ``Document`` class doesn't have a + default index yet (by defining ``class Index``), this instance will be + used. Can be used as a decorator:: + + i = Index('blog') + + @i.document + class Post(Document): + title = Text() + + # create the index, including Post mappings + i.create() + + # .search() will now return a Search object that will return + # properly deserialized Post instances + s = i.search() + """ + self._doc_types.append(document) + + # If the document index does not have any name, that means the user + # did not set any index already to the document. + # So set this index as document index + if document._index._name is None: + document._index = self + + return document + + def settings(self, **kwargs): + """ + Add settings to the index:: + + i = Index('i') + i.settings(number_of_shards=1, number_of_replicas=0) + + Multiple calls to ``settings`` will merge the keys, later overriding + the earlier. + """ + self._settings.update(kwargs) + return self + + def aliases(self, **kwargs): + """ + Add aliases to the index definition:: + + i = Index('blog-v2') + i.aliases(blog={}, published={'filter': Q('term', published=True)}) + """ + self._aliases.update(kwargs) + return self + + def analyzer(self, *args, **kwargs): + """ + Explicitly add an analyzer to an index. Note that all custom analyzers + defined in mappings will also be created. This is useful for search analyzers. 
+ + Example:: + + from elasticsearch_dsl import analyzer, tokenizer + + my_analyzer = analyzer('my_analyzer', + tokenizer=tokenizer('trigram', 'nGram', min_gram=3, max_gram=3), + filter=['lowercase'] + ) + + i = Index('blog') + i.analyzer(my_analyzer) + + """ + analyzer = analysis.analyzer(*args, **kwargs) + d = analyzer.get_analysis_definition() + # empty custom analyzer, probably already defined out of our control + if not d: + return + + # merge the definition + merge(self._analysis, d, True) + + def to_dict(self): + out = {} + if self._settings: + out["settings"] = self._settings + if self._aliases: + out["aliases"] = self._aliases + mappings = self._mapping.to_dict() if self._mapping else {} + analysis = self._mapping._collect_analysis() if self._mapping else {} + for d in self._doc_types: + mapping = d._doc_type.mapping + merge(mappings, mapping.to_dict(), True) + merge(analysis, mapping._collect_analysis(), True) + if mappings: + out["mappings"] = mappings + if analysis or self._analysis: + merge(analysis, self._analysis) + out.setdefault("settings", {})["analysis"] = analysis + return out + + def search(self, using=None): + """ + Return a :class:`~elasticsearch_dsl.Search` object searching over the + index (or all the indices belonging to this template) and its + ``Document``\\s. + """ + return Search( + using=using or self._using, index=self._name, doc_type=self._doc_types + ) + + def updateByQuery(self, using=None): + """ + Return a :class:`~elasticsearch_dsl.UpdateByQuery` object searching over the index + (or all the indices belonging to this template) and updating Documents that match + the search criteria. + + For more information, see here: + https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-update-by-query.html + """ + return UpdateByQuery( + using=using or self._using, + index=self._name, + ) + + async def create(self, using=None, **kwargs): + """ + Creates the index in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.create`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.create") + + return await es.indices.create(index=self._name, body=self.to_dict(), **kwargs) + + async def is_closed(self, using=None): + es = self._get_connection(using) + ensure_async_connection(es, "Index.is_closed") + + state = await es.cluster.state( + index=self._name, + metric="metadata", + ) + + return state["metadata"]["indices"][self._name]["state"] == "close" + + async def save(self, using=None): + """ + Sync the index definition with elasticsearch, creating the index if it + doesn't exist and updating its settings and mappings if it does. + + Note some settings and mapping changes cannot be done on an open + index (or at all on an existing index) and for those this method will + fail with the underlying exception. 
+ """ + if not await self.exists(using=using): + return await self.create(using=using) + + body = self.to_dict() + settings = body.pop("settings", {}) + analysis = settings.pop("analysis", None) + current_settings = self.get_settings(using=using)[self._name]["settings"][ + "index" + ] + if analysis: + if await self.is_closed(using=using): + # closed index, update away + settings["analysis"] = analysis + else: + # compare analysis definition, if all analysis objects are + # already defined as requested, skip analysis update and + # proceed, otherwise raise IllegalOperation + existing_analysis = current_settings.get("analysis", {}) + if any( + existing_analysis.get(section, {}).get(k, None) + != analysis[section][k] + for section in analysis + for k in analysis[section] + ): + raise IllegalOperation( + "You cannot update analysis configuration on an open index, " + "you need to close index %s first." % self._name + ) + + # try and update the settings + if settings: + settings = settings.copy() + for k, v in list(settings.items()): + if k in current_settings and current_settings[k] == str(v): + del settings[k] + + if settings: + await self.put_settings(using=using, body=settings) + + # update the mappings, any conflict in the mappings will result in an + # exception + mappings = body.pop("mappings", {}) + if mappings: + await self.put_mapping(using=using, body=mappings) + + async def analyze(self, using=None, **kwargs): + """ + Perform the analysis process on a text and return the tokens breakdown + of the text. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.analyze`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.analyze") + + return await es.indices.analyze(index=self._name, **kwargs) + + async def refresh(self, using=None, **kwargs): + """ + Performs a refresh operation on the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.refresh`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.refresh") + + return await es.indices.refresh(index=self._name, **kwargs) + + async def flush(self, using=None, **kwargs): + """ + Performs a flush operation on the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.flush`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.flush") + + return await es.indices.flush(index=self._name, **kwargs) + + async def get(self, using=None, **kwargs): + """ + The get index API allows to retrieve information about the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.get`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.get") + + return await es.indices.get(index=self._name, **kwargs) + + async def open(self, using=None, **kwargs): + """ + Opens the index in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.open`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.open") + + return await es.indices.open(index=self._name, **kwargs) + + async def close(self, using=None, **kwargs): + """ + Closes the index in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.close`` unchanged. 
+ """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.close") + + return await es.indices.close(index=self._name, **kwargs) + + async def delete(self, using=None, **kwargs): + """ + Deletes the index in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.delete`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.delete") + + return await es.indices.delete(index=self._name, **kwargs) + + async def exists(self, using=None, **kwargs): + """ + Returns ``True`` if the index already exists in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.exists`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.exists") + + return await es.indices.exists(index=self._name, **kwargs) + + async def exists_type(self, using=None, **kwargs): + """ + Check if a type/types exists in the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.exists_type`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.exists_type") + + return await es.indices.exists_type(index=self._name, **kwargs) + + async def put_mapping(self, using=None, **kwargs): + """ + Register specific mapping definition for a specific type. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.put_mapping`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.put_mapping") + + return await es.indices.put_mapping(index=self._name, **kwargs) + + async def get_mapping(self, using=None, **kwargs): + """ + Retrieve specific mapping definition for a specific type. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.get_mapping`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.get_mapping") + + return await es.indices.get_mapping(index=self._name, **kwargs) + + async def get_field_mapping(self, using=None, **kwargs): + """ + Retrieve mapping definition of a specific field. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.get_field_mapping`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.get_field_mapping") + + return await es.indices.get_field_mapping(index=self._name, **kwargs) + + async def put_alias(self, using=None, **kwargs): + """ + Create an alias for the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.put_alias`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.put_alias") + + return await es.indices.put_alias(index=self._name, **kwargs) + + def exists_alias(self, using=None, **kwargs): + """ + Return a boolean indicating whether given alias exists for this index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.exists_alias`` unchanged. + """ + return self._get_connection(using).indices.exists_alias( + index=self._name, **kwargs + ) + + async def get_alias(self, using=None, **kwargs): + """ + Retrieve a specified alias. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.get_alias`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.get_alias") + + return await es.indices.get_alias(index=self._name, **kwargs) + + async def delete_alias(self, using=None, **kwargs): + """ + Delete specific alias. 
+ + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.delete_alias`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.delete_alias") + + return await es.indices.delete_alias(index=self._name, **kwargs) + + async def get_settings(self, using=None, **kwargs): + """ + Retrieve settings for the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.get_settings`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.get_settings") + + return await es.indices.get_settings(index=self._name, **kwargs) + + async def put_settings(self, using=None, **kwargs): + """ + Change specific index level settings in real time. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.put_settings`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.put_settings") + + return await es.indices.put_settings(index=self._name, **kwargs) + + async def stats(self, using=None, **kwargs): + """ + Retrieve statistics on different operations happening on the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.stats`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.stats") + + return await es.indices.stats(index=self._name, **kwargs) + + async def segments(self, using=None, **kwargs): + """ + Provide low level segments information that a Lucene index (shard + level) is built with. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.segments`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.segments") + + return await es.indices.segments(index=self._name, **kwargs) + + async def validate_query(self, using=None, **kwargs): + """ + Validate a potentially expensive query without executing it. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.validate_query`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.validate_query") + + return await es.indices.validate_query(index=self._name, **kwargs) + + async def clear_cache(self, using=None, **kwargs): + """ + Clear all caches or specific cached associated with the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.clear_cache`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.clear_cache") + + return await es.indices.clear_cache(index=self._name, **kwargs) + + async def recovery(self, using=None, **kwargs): + """ + The indices recovery API provides insight into on-going shard + recoveries for the index. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.recovery`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.recovery") + + return await es.indices.recovery(index=self._name, **kwargs) + + async def upgrade(self, using=None, **kwargs): + """ + Upgrade the index to the latest format. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.upgrade`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.upgrade") + + return await es.indices.upgrade(index=self._name, **kwargs) + + async def get_upgrade(self, using=None, **kwargs): + """ + Monitor how much of the index is upgraded. 
+ + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.get_upgrade`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.get_upgrade") + + return await es.indices.get_upgrade(index=self._name, **kwargs) + + async def flush_synced(self, using=None, **kwargs): + """ + Perform a normal flush, then add a generated unique marker (sync_id) to + all shards. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.flush_synced`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.flush_synced") + + return await es.indices.flush_synced(index=self._name, **kwargs) + + async def shard_stores(self, using=None, **kwargs): + """ + Provides store information for shard copies of the index. Store + information reports on which nodes shard copies exist, the shard copy + version, indicating how recent they are, and any exceptions encountered + while opening the shard index or from earlier engine failure. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.shard_stores`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.shard_stores") + + return await es.indices.shard_stores(index=self._name, **kwargs) + + async def forcemerge(self, using=None, **kwargs): + """ + The force merge API allows to force merging of the index through an + API. The merge relates to the number of segments a Lucene index holds + within each shard. The force merge operation allows to reduce the + number of segments by merging them. + + This call will block until the merge is complete. If the http + connection is lost, the request will continue in the background, and + any new requests will block until the previous force merge is complete. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.forcemerge`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.forcemerge") + + return await es.indices.forcemerge(index=self._name, **kwargs) + + async def shrink(self, using=None, **kwargs): + """ + The shrink index API allows you to shrink an existing index into a new + index with fewer primary shards. The number of primary shards in the + target index must be a factor of the shards in the source index. For + example an index with 8 primary shards can be shrunk into 4, 2 or 1 + primary shards or an index with 15 primary shards can be shrunk into 5, + 3 or 1. If the number of shards in the index is a prime number it can + only be shrunk into a single primary shard. Before shrinking, a + (primary or replica) copy of every shard in the index must be present + on the same node. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.shrink`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.shrink") + + return await es.indices.shrink(index=self._name, **kwargs) diff --git a/elasticsearch_dsl/_async/mapping.py b/elasticsearch_dsl/_async/mapping.py new file mode 100644 index 000000000..ba099ec63 --- /dev/null +++ b/elasticsearch_dsl/_async/mapping.py @@ -0,0 +1,244 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +try: + import collections.abc as collections_abc # only works on python 3.3+ +except ImportError: + import collections as collections_abc + +from itertools import chain + +from six import iteritems, itervalues + +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.field import Nested, Text, construct_field +from elasticsearch_dsl.utils import DslBase + +from .utils import ensure_async_connection + +META_FIELDS = frozenset( + ( + "dynamic", + "transform", + "dynamic_date_formats", + "date_detection", + "numeric_detection", + "dynamic_templates", + "enabled", + ) +) + + +class Properties(DslBase): + name = "properties" + _param_defs = {"properties": {"type": "field", "hash": True}} + + def __init__(self): + super(Properties, self).__init__() + + def __repr__(self): + return "Properties()" + + def __getitem__(self, name): + return self.properties[name] + + def __contains__(self, name): + return name in self.properties + + def to_dict(self): + return super(Properties, self).to_dict()["properties"] + + def field(self, name, *args, **kwargs): + self.properties[name] = construct_field(*args, **kwargs) + return self + + def _collect_fields(self): + """ Iterate over all Field objects within, including multi fields. """ + for f in itervalues(self.properties.to_dict()): + yield f + # multi fields + if hasattr(f, "fields"): + for inner_f in itervalues(f.fields.to_dict()): + yield inner_f + # nested and inner objects + if hasattr(f, "_collect_fields"): + for inner_f in f._collect_fields(): + yield inner_f + + def update(self, other_object): + if not hasattr(other_object, "properties"): + # not an inner/nested object, no merge possible + return + + our, other = self.properties, other_object.properties + for name in other: + if name in our: + if hasattr(our[name], "update"): + our[name].update(other[name]) + continue + our[name] = other[name] + + +class Mapping(object): + def __init__(self): + self.properties = Properties() + self._meta = {} + + def __repr__(self): + return "Mapping()" + + def _clone(self): + m = Mapping() + m.properties._params = self.properties._params.copy() + return m + + @classmethod + async def from_es(cls, index, using="default"): + m = cls() + await m.update_from_es(index, using) + + return m + + def resolve_nested(self, field_path): + field = self + nested = [] + parts = field_path.split(".") + for i, step in enumerate(parts): + try: + field = field[step] + except KeyError: + return (), None + if isinstance(field, Nested): + nested.append(".".join(parts[: i + 1])) + return nested, field + + def resolve_field(self, field_path): + field = self + for step in field_path.split("."): + try: + field = field[step] + except KeyError: + return + return field + + def _collect_analysis(self): + analysis = {} + fields = [] + if "_all" in self._meta: + fields.append(Text(**self._meta["_all"])) + + for f in chain(fields, self.properties._collect_fields()): + for analyzer_name in ( + "analyzer", + "normalizer", + "search_analyzer", + "search_quote_analyzer", + ): + if not hasattr(f, analyzer_name): + continue + analyzer = getattr(f, analyzer_name) + d 
= analyzer.get_analysis_definition() + # empty custom analyzer, probably already defined out of our control + if not d: + continue + + # merge the definition + # TODO: conflict detection/resolution + for key in d: + analysis.setdefault(key, {}).update(d[key]) + + return analysis + + async def save(self, index, using="default"): + from .index import Index + + index = Index(index, using=using) + index.mapping(self) + return await index.save() + + async def update_from_es(self, index, using="default"): + es = get_connection(using) + ensure_async_connection(es, "Mapping.update_from_es") + + raw = await es.indices.get_mapping(index=index) + _, raw = raw.popitem() + self._update_from_dict(raw["mappings"]) + + def _update_from_dict(self, raw): + for name, definition in iteritems(raw.get("properties", {})): + self.field(name, definition) + + # metadata like _all etc + for name, value in iteritems(raw): + if name != "properties": + if isinstance(value, collections_abc.Mapping): + self.meta(name, **value) + else: + self.meta(name, value) + + def update(self, mapping, update_only=False): + for name in mapping: + if update_only and name in self: + # nested and inner objects, merge recursively + if hasattr(self[name], "update"): + # FIXME only merge subfields, not the settings + self[name].update(mapping[name], update_only) + continue + self.field(name, mapping[name]) + + if update_only: + for name in mapping._meta: + if name not in self._meta: + self._meta[name] = mapping._meta[name] + else: + self._meta.update(mapping._meta) + + def __contains__(self, name): + return name in self.properties.properties + + def __getitem__(self, name): + return self.properties.properties[name] + + def __iter__(self): + return iter(self.properties.properties) + + def field(self, *args, **kwargs): + self.properties.field(*args, **kwargs) + return self + + def meta(self, name, params=None, **kwargs): + if not name.startswith("_") and name not in META_FIELDS: + name = "_" + name + + if params and kwargs: + raise ValueError("Meta configs cannot have both value and a dictionary.") + + self._meta[name] = kwargs if params is None else params + return self + + def to_dict(self): + meta = self._meta + + # hard coded serialization of analyzers in _all + if "_all" in meta: + meta = meta.copy() + _all = meta["_all"] = meta["_all"].copy() + for f in ("analyzer", "search_analyzer", "search_quote_analyzer"): + if hasattr(_all.get(f, None), "to_dict"): + _all[f] = _all[f].to_dict() + meta.update(self.properties.to_dict()) + return meta diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py new file mode 100644 index 000000000..87f04ea06 --- /dev/null +++ b/elasticsearch_dsl/_async/search.py @@ -0,0 +1,826 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
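Before the search.py module below: the QueryProxy and ProxyDescriptor classes defined in it are what let callers either chain query() calls (each call is AND-ed into the existing query) or assign a ready-made Query object to s.query. A small sketch of that behaviour, not part of the patch; it assumes the async Search keeps the same synchronous query-building surface (query(), to_dict()) as the existing class, with only execution (not shown in this hunk) being awaited:

    from elasticsearch_dsl import Q
    from elasticsearch_dsl._async.search import Search

    s = Search(index="blog-posts")           # hypothetical index name
    s = s.query("match", title="async")      # QueryProxy: repeated calls are combined with &
    s = s.query("term", published=True)
    # ProxyDescriptor also accepts direct assignment of a Query object:
    s.query = Q("bool", must=[Q("match", title="async"), Q("term", published=True)])
    print(s.to_dict())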
+ +import copy + +try: + import collections.abc as collections_abc # only works on python 3.3+ +except ImportError: + import collections as collections_abc + +from elasticsearch.exceptions import TransportError +from elasticsearch.helpers import async_scan +from six import iteritems, string_types + +from elasticsearch_dsl.aggs import A, AggBase +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation +from elasticsearch_dsl.query import Bool, Q +from elasticsearch_dsl.response import Hit, Response +from elasticsearch_dsl.utils import AttrDict, DslBase + +from .utils import ensure_async_connection + + +class QueryProxy(object): + """ + Simple proxy around DSL objects (queries) that can be called + (to add query/post_filter) and also allows attribute access which is proxied to + the wrapped query. + """ + + def __init__(self, search, attr_name): + self._search = search + self._proxied = None + self._attr_name = attr_name + + def __nonzero__(self): + return self._proxied is not None + + __bool__ = __nonzero__ + + def __call__(self, *args, **kwargs): + s = self._search._clone() + + # we cannot use self._proxied since we just cloned self._search and + # need to access the new self on the clone + proxied = getattr(s, self._attr_name) + if proxied._proxied is None: + proxied._proxied = Q(*args, **kwargs) + else: + proxied._proxied &= Q(*args, **kwargs) + + # always return search to be chainable + return s + + def __getattr__(self, attr_name): + return getattr(self._proxied, attr_name) + + def __setattr__(self, attr_name, value): + if not attr_name.startswith("_"): + self._proxied = Q(self._proxied.to_dict()) + setattr(self._proxied, attr_name, value) + super(QueryProxy, self).__setattr__(attr_name, value) + + def __getstate__(self): + return self._search, self._proxied, self._attr_name + + def __setstate__(self, state): + self._search, self._proxied, self._attr_name = state + + +class ProxyDescriptor(object): + """ + Simple descriptor to enable setting of queries and filters as: + + s = Search() + s.query = Q(...) 
+ + """ + + def __init__(self, name): + self._attr_name = "_%s_proxy" % name + + def __get__(self, instance, owner): + return getattr(instance, self._attr_name) + + def __set__(self, instance, value): + proxy = getattr(instance, self._attr_name) + proxy._proxied = Q(value) + + +class AggsProxy(AggBase, DslBase): + name = "aggs" + + def __init__(self, search): + self._base = self + self._search = search + self._params = {"aggs": {}} + + def to_dict(self): + return super(AggsProxy, self).to_dict().get("aggs", {}) + + +class Request(object): + def __init__(self, using="default", index=None, doc_type=None, extra=None): + self._using = using + + self._index = None + if isinstance(index, (tuple, list)): + self._index = list(index) + elif index: + self._index = [index] + + self._doc_type = [] + self._doc_type_map = {} + if isinstance(doc_type, (tuple, list)): + self._doc_type.extend(doc_type) + elif isinstance(doc_type, collections_abc.Mapping): + self._doc_type.extend(doc_type.keys()) + self._doc_type_map.update(doc_type) + elif doc_type: + self._doc_type.append(doc_type) + + self._params = {} + self._extra = extra or {} + + def __eq__(self, other): + return ( + isinstance(other, Request) + and other._params == self._params + and other._index == self._index + and other._doc_type == self._doc_type + and other.to_dict() == self.to_dict() + ) + + def __copy__(self): + return self._clone() + + def params(self, **kwargs): + """ + Specify query params to be used when executing the search. All the + keyword arguments will override the current values. See + https://elasticsearch-py.readthedocs.io/en/master/api.html#elasticsearch.Elasticsearch.search + for all available parameters. + + Example:: + + s = Search() + s = s.params(routing='user-1', preference='local') + """ + s = self._clone() + s._params.update(kwargs) + return s + + def index(self, *index): + """ + Set the index for the search. If called empty it will remove all information. 
+ + Example: + + s = Search() + s = s.index('twitter-2015.01.01', 'twitter-2015.01.02') + s = s.index(['twitter-2015.01.01', 'twitter-2015.01.02']) + """ + # .index() resets + s = self._clone() + if not index: + s._index = None + else: + indexes = [] + for i in index: + if isinstance(i, string_types): + indexes.append(i) + elif isinstance(i, list): + indexes += i + elif isinstance(i, tuple): + indexes += list(i) + + s._index = (self._index or []) + indexes + + return s + + def _resolve_field(self, path): + for dt in self._doc_type: + if not hasattr(dt, "_index"): + continue + field = dt._index.resolve_field(path) + if field is not None: + return field + + def _resolve_nested(self, hit, parent_class=None): + doc_class = Hit + + nested_path = [] + nesting = hit["_nested"] + while nesting and "field" in nesting: + nested_path.append(nesting["field"]) + nesting = nesting.get("_nested") + nested_path = ".".join(nested_path) + + if hasattr(parent_class, "_index"): + nested_field = parent_class._index.resolve_field(nested_path) + else: + nested_field = self._resolve_field(nested_path) + + if nested_field is not None: + return nested_field._doc_class + + return doc_class + + def _get_result(self, hit, parent_class=None): + doc_class = Hit + dt = hit.get("_type") + + if "_nested" in hit: + doc_class = self._resolve_nested(hit, parent_class) + + elif dt in self._doc_type_map: + doc_class = self._doc_type_map[dt] + + else: + for doc_type in self._doc_type: + if hasattr(doc_type, "_matches") and doc_type._matches(hit): + doc_class = doc_type + break + + for t in hit.get("inner_hits", ()): + hit["inner_hits"][t] = Response( + self, hit["inner_hits"][t], doc_class=doc_class + ) + + callback = getattr(doc_class, "from_es", doc_class) + return callback(hit) + + def doc_type(self, *doc_type, **kwargs): + """ + Set the type to search through. You can supply a single value or + multiple. Values can be strings or subclasses of ``Document``. + + You can also pass in any keyword arguments, mapping a doc_type to a + callback that should be used instead of the Hit class. + + If no doc_type is supplied any information stored on the instance will + be erased. + + Example: + + s = Search().doc_type('product', 'store', User, custom=my_callback) + """ + # .doc_type() resets + s = self._clone() + if not doc_type and not kwargs: + s._doc_type = [] + s._doc_type_map = {} + else: + s._doc_type.extend(doc_type) + s._doc_type.extend(kwargs.keys()) + s._doc_type_map.update(kwargs) + return s + + def using(self, client): + """ + Associate the search request with an elasticsearch client. A fresh copy + will be returned with current instance remaining unchanged. + + :arg client: an instance of ``elasticsearch.Elasticsearch`` to use or + an alias to look up in ``elasticsearch_dsl.connections`` + + """ + s = self._clone() + s._using = client + return s + + def extra(self, **kwargs): + """ + Add extra keys to the request body. Mostly here for backwards + compatibility. + """ + s = self._clone() + if "from_" in kwargs: + kwargs["from"] = kwargs.pop("from_") + s._extra.update(kwargs) + return s + + def _clone(self): + s = self.__class__( + using=self._using, index=self._index, doc_type=self._doc_type + ) + s._doc_type_map = self._doc_type_map.copy() + s._extra = self._extra.copy() + s._params = self._params.copy() + return s + + +class Search(Request): + query = ProxyDescriptor("query") + post_filter = ProxyDescriptor("post_filter") + + def __init__(self, **kwargs): + """ + Search request to elasticsearch. 
+ + :arg using: `Elasticsearch` instance to use + :arg index: limit the search to index + :arg doc_type: only query this type. + + All the parameters supplied (or omitted) at creation time can be later + overridden by methods (`using`, `index` and `doc_type` respectively). + """ + super(Search, self).__init__(**kwargs) + + self.aggs = AggsProxy(self) + self._sort = [] + self._source = None + self._highlight = {} + self._highlight_opts = {} + self._suggest = {} + self._script_fields = {} + self._response_class = Response + + self._query_proxy = QueryProxy(self, "query") + self._post_filter_proxy = QueryProxy(self, "post_filter") + + def filter(self, *args, **kwargs): + return self.query(Bool(filter=[Q(*args, **kwargs)])) + + def exclude(self, *args, **kwargs): + return self.query(Bool(filter=[~Q(*args, **kwargs)])) + + async def __aiter__(self): + """ + Iterate over the hits. + """ + for hit in await self.execute(): + yield hit + + def __getitem__(self, n): + """ + Support slicing the `Search` instance for pagination. + + Slicing equates to the from/size parameters. E.g.:: + + s = Search().query(...)[0:25] + + is equivalent to:: + + s = Search().query(...).extra(from_=0, size=25) + + """ + s = self._clone() + + if isinstance(n, slice): + # If negative slicing, abort. + if n.start and n.start < 0 or n.stop and n.stop < 0: + raise ValueError("Search does not support negative slicing.") + # Elasticsearch won't get all results so we default to size: 10 if + # stop not given. + s._extra["from"] = n.start or 0 + s._extra["size"] = max( + 0, n.stop - (n.start or 0) if n.stop is not None else 10 + ) + return s + else: # This is an index lookup, equivalent to slicing by [n:n+1]. + # If negative index, abort. + if n < 0: + raise ValueError("Search does not support negative indexing.") + s._extra["from"] = n + s._extra["size"] = 1 + return s + + @classmethod + def from_dict(cls, d): + """ + Construct a new `Search` instance from a raw dict containing the search + body. Useful when migrating from raw dictionaries. + + Example:: + + s = Search.from_dict({ + "query": { + "bool": { + "must": [...] + } + }, + "aggs": {...} + }) + s = s.filter('term', published=True) + """ + s = cls() + s.update_from_dict(d) + return s + + def _clone(self): + """ + Return a clone of the current search request. Performs a shallow copy + of all the underlying objects. Used internally by most state modifying + APIs. + """ + s = super(Search, self)._clone() + + s._response_class = self._response_class + s._sort = self._sort[:] + s._source = copy.copy(self._source) if self._source is not None else None + s._highlight = self._highlight.copy() + s._highlight_opts = self._highlight_opts.copy() + s._suggest = self._suggest.copy() + s._script_fields = self._script_fields.copy() + for x in ("query", "post_filter"): + getattr(s, x)._proxied = getattr(self, x)._proxied + + # copy top-level bucket definitions + if self.aggs._params.get("aggs"): + s.aggs._params = {"aggs": self.aggs._params["aggs"].copy()} + return s + + def response_class(self, cls): + """ + Override the default wrapper used for the response. + """ + s = self._clone() + s._response_class = cls + return s + + def update_from_dict(self, d): + """ + Apply options from a serialized body to the current instance. Modifies + the object in-place. Used mostly by ``from_dict``.
+ """ + d = d.copy() + if "query" in d: + self.query._proxied = Q(d.pop("query")) + if "post_filter" in d: + self.post_filter._proxied = Q(d.pop("post_filter")) + + aggs = d.pop("aggs", d.pop("aggregations", {})) + if aggs: + self.aggs._params = { + "aggs": {name: A(value) for (name, value) in iteritems(aggs)} + } + if "sort" in d: + self._sort = d.pop("sort") + if "_source" in d: + self._source = d.pop("_source") + if "highlight" in d: + high = d.pop("highlight").copy() + self._highlight = high.pop("fields") + self._highlight_opts = high + if "suggest" in d: + self._suggest = d.pop("suggest") + if "text" in self._suggest: + text = self._suggest.pop("text") + for s in self._suggest.values(): + s.setdefault("text", text) + if "script_fields" in d: + self._script_fields = d.pop("script_fields") + self._extra.update(d) + return self + + def script_fields(self, **kwargs): + """ + Define script fields to be calculated on hits. See + https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-script-fields.html + for more details. + + Example:: + + s = Search() + s = s.script_fields(times_two="doc['field'].value * 2") + s = s.script_fields( + times_three={ + 'script': { + 'lang': 'painless', + 'source': "doc['field'].value * params.n", + 'params': {'n': 3} + } + } + ) + + """ + s = self._clone() + for name in kwargs: + if isinstance(kwargs[name], string_types): + kwargs[name] = {"script": kwargs[name]} + s._script_fields.update(kwargs) + return s + + def source(self, fields=None, **kwargs): + """ + Selectively control how the _source field is returned. + + :arg fields: wildcard string, array of wildcards, or dictionary of includes and excludes + + If ``fields`` is None, the entire document will be returned for + each hit. If fields is a dictionary with keys of 'includes' and/or + 'excludes' the fields will be either included or excluded appropriately. + + Calling this multiple times with the same named parameter will override the + previous values with the new ones. + + Example:: + + s = Search() + s = s.source(includes=['obj1.*'], excludes=["*.description"]) + + s = Search() + s = s.source(includes=['obj1.*']).source(excludes=["*.description"]) + + """ + s = self._clone() + + if fields and kwargs: + raise ValueError("You cannot specify fields and kwargs at the same time.") + + if fields is not None: + s._source = fields + return s + + if kwargs and not isinstance(s._source, dict): + s._source = {} + + for key, value in kwargs.items(): + if value is None: + try: + del s._source[key] + except KeyError: + pass + else: + s._source[key] = value + + return s + + def sort(self, *keys): + """ + Add sorting information to the search request. If called without + arguments it will remove all sort requirements. Otherwise it will + replace them. Acceptable arguments are:: + + 'some.field' + '-some.other.field' + {'different.field': {'any': 'dict'}} + + so for example:: + + s = Search().sort( + 'category', + '-title', + {"price" : {"order" : "asc", "mode" : "avg"}} + ) + + will sort by ``category``, ``title`` (in descending order) and + ``price`` in ascending order using the ``avg`` mode. + + The API returns a copy of the Search object and can thus be chained. 
+ """ + s = self._clone() + s._sort = [] + for k in keys: + if isinstance(k, string_types) and k.startswith("-"): + if k[1:] == "_score": + raise IllegalOperation("Sorting by `-_score` is not allowed.") + k = {k[1:]: {"order": "desc"}} + s._sort.append(k) + return s + + def highlight_options(self, **kwargs): + """ + Update the global highlighting options used for this request. For + example:: + + s = Search() + s = s.highlight_options(order='score') + """ + s = self._clone() + s._highlight_opts.update(kwargs) + return s + + def highlight(self, *fields, **kwargs): + """ + Request highlighting of some fields. All keyword arguments passed in will be + used as parameters for all the fields in the ``fields`` parameter. Example:: + + Search().highlight('title', 'body', fragment_size=50) + + will produce the equivalent of:: + + { + "highlight": { + "fields": { + "body": {"fragment_size": 50}, + "title": {"fragment_size": 50} + } + } + } + + If you want to have different options for different fields + you can call ``highlight`` twice:: + + Search().highlight('title', fragment_size=50).highlight('body', fragment_size=100) + + which will produce:: + + { + "highlight": { + "fields": { + "body": {"fragment_size": 100}, + "title": {"fragment_size": 50} + } + } + } + + """ + s = self._clone() + for f in fields: + s._highlight[f] = kwargs + return s + + def suggest(self, name, text, **kwargs): + """ + Add a suggestions request to the search. + + :arg name: name of the suggestion + :arg text: text to suggest on + + All keyword arguments will be added to the suggestions body. For example:: + + s = Search() + s = s.suggest('suggestion-1', 'Elasticsearch', term={'field': 'body'}) + """ + s = self._clone() + s._suggest[name] = {"text": text} + s._suggest[name].update(kwargs) + return s + + def to_dict(self, count=False, **kwargs): + """ + Serialize the search into the dictionary that will be sent over as the + request's body. + + :arg count: a flag to specify if we are interested in a body for count - + no aggregations, no pagination bounds etc. + + All additional keyword arguments will be included into the dictionary. + """ + d = {} + + if self.query: + d["query"] = self.query.to_dict() + + # count request doesn't care for sorting and other things + if not count: + if self.post_filter: + d["post_filter"] = self.post_filter.to_dict() + + if self.aggs.aggs: + d.update(self.aggs.to_dict()) + + if self._sort: + d["sort"] = self._sort + + d.update(self._extra) + + if self._source not in (None, {}): + d["_source"] = self._source + + if self._highlight: + d["highlight"] = {"fields": self._highlight} + d["highlight"].update(self._highlight_opts) + + if self._suggest: + d["suggest"] = self._suggest + + if self._script_fields: + d["script_fields"] = self._script_fields + + d.update(kwargs) + return d + + def count(self): + """ + Return the number of hits matching the query and filters. Note that + only the actual number is returned. + """ + if hasattr(self, "_response") and self._response.hits.total.relation == "eq": + return self._response.hits.total.value + + es = get_connection(self._using) + + d = self.to_dict(count=True) + # TODO: failed shards detection + return es.count(index=self._index, body=d, **self._params)["count"] + + async def execute(self, ignore_cache=False): + """ + Execute the search and return an instance of ``Response`` wrapping all + the data. + + :arg ignore_cache: if set to ``True``, consecutive calls will hit + ES, while cached result will be ignored. 
Defaults to `False` + """ + if ignore_cache or not hasattr(self, "_response"): + es = get_connection(self._using) + ensure_async_connection(es, "Search.execute") + + self._response = self._response_class( + self, + await es.search(index=self._index, body=self.to_dict(), **self._params), + ) + + return self._response + + async def scan(self): + """ + Turn the search into a scan search and return an async generator that will + iterate over all the documents matching the query. + + Use ``params`` method to specify any additional arguments you wish to + pass to the underlying ``scan`` helper from ``elasticsearch-py`` - + https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan + + """ + es = get_connection(self._using) + ensure_async_connection(es, "Search.scan") + + async for hit in async_scan( + es, query=self.to_dict(), index=self._index, **self._params + ): + yield self._get_result(hit) + + async def delete(self): + """ + delete() executes the query by delegating to delete_by_query() + """ + es = get_connection(self._using) + ensure_async_connection(es, "Search.delete") + + return AttrDict( + await es.delete_by_query( + index=self._index, body=self.to_dict(), **self._params + ) + ) + + +class MultiSearch(Request): + """ + Combine multiple :class:`~elasticsearch_dsl.Search` objects into a single + request. + """ + + def __init__(self, **kwargs): + super(MultiSearch, self).__init__(**kwargs) + self._searches = [] + + def __getitem__(self, key): + return self._searches[key] + + def __iter__(self): + return iter(self._searches) + + def _clone(self): + ms = super(MultiSearch, self)._clone() + ms._searches = self._searches[:] + return ms + + def add(self, search): + """ + Adds a new :class:`~elasticsearch_dsl.Search` object to the request:: + + ms = MultiSearch(index='my-index') + ms = ms.add(Search(doc_type=Category).filter('term', category='python')) + ms = ms.add(Search(doc_type=Blog)) + """ + ms = self._clone() + ms._searches.append(search) + return ms + + def to_dict(self): + out = [] + for s in self._searches: + meta = {} + if s._index: + meta["index"] = s._index + meta.update(s._params) + + out.append(meta) + out.append(s.to_dict()) + + return out + + async def execute(self, ignore_cache=False, raise_on_error=True): + """ + Execute the multi search request and return a list of search results. + """ + if ignore_cache or not hasattr(self, "_response"): + es = get_connection(self._using) + ensure_async_connection(es, "MultiSearch.execute") + + responses = await es.msearch( + index=self._index, body=self.to_dict(), **self._params + ) + + out = [] + for s, r in zip(self._searches, responses["responses"]): + if r.get("error", False): + if raise_on_error: + raise TransportError("N/A", r["error"]["type"], r["error"]) + r = None + else: + r = Response(s, r) + out.append(r) + + self._response = out + + return self._response diff --git a/elasticsearch_dsl/_async/update_by_query.py b/elasticsearch_dsl/_async/update_by_query.py new file mode 100644 index 000000000..5f1bf873b --- /dev/null +++ b/elasticsearch_dsl/_async/update_by_query.py @@ -0,0 +1,165 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.query import Bool, Q +from elasticsearch_dsl.response import UpdateByQueryResponse +from elasticsearch_dsl.search import ProxyDescriptor, QueryProxy, Request + +from .utils import ensure_async_connection + + +class UpdateByQuery(Request): + + query = ProxyDescriptor("query") + + def __init__(self, **kwargs): + """ + Update by query request to elasticsearch. + + :arg using: `Elasticsearch` instance to use + :arg index: limit the search to index + :arg doc_type: only query this type. + + All the parameters supplied (or omitted) at creation time can be later + overridden by methods (`using`, `index` and `doc_type` respectively). + + """ + super(UpdateByQuery, self).__init__(**kwargs) + self._response_class = UpdateByQueryResponse + self._script = {} + self._query_proxy = QueryProxy(self, "query") + + def filter(self, *args, **kwargs): + return self.query(Bool(filter=[Q(*args, **kwargs)])) + + def exclude(self, *args, **kwargs): + return self.query(Bool(filter=[~Q(*args, **kwargs)])) + + @classmethod + def from_dict(cls, d): + """ + Construct a new `UpdateByQuery` instance from a raw dict containing the search + body. Useful when migrating from raw dictionaries. + + Example:: + + ubq = UpdateByQuery.from_dict({ + "query": { + "bool": { + "must": [...] + } + }, + "script": {...} + }) + ubq = ubq.filter('term', published=True) + """ + u = cls() + u.update_from_dict(d) + return u + + def _clone(self): + """ + Return a clone of the current search request. Performs a shallow copy + of all the underlying objects. Used internally by most state modifying + APIs. + """ + ubq = super(UpdateByQuery, self)._clone() + + ubq._response_class = self._response_class + ubq._script = self._script.copy() + ubq.query._proxied = self.query._proxied + return ubq + + def response_class(self, cls): + """ + Override the default wrapper used for the response. + """ + ubq = self._clone() + ubq._response_class = cls + return ubq + + def update_from_dict(self, d): + """ + Apply options from a serialized body to the current instance. Modifies + the object in-place. Used mostly by ``from_dict``. + """ + d = d.copy() + if "query" in d: + self.query._proxied = Q(d.pop("query")) + if "script" in d: + self._script = d.pop("script") + self._extra.update(d) + return self + + def script(self, **kwargs): + """ + Define the update action to take. See + https://www.elastic.co/guide/en/elasticsearch/reference/current/modules-scripting-using.html + for more details. + + Note: the API only accepts a single script, so + calling ``script`` multiple times will overwrite the previous value. + + Example:: + + ubq = UpdateByQuery() + ubq = ubq.script(source="ctx._source.likes++") + ubq = ubq.script(source="ctx._source.likes += params.f", + lang="painless", + params={'f': 3}) + """ + ubq = self._clone() + if ubq._script: + ubq._script = {} + ubq._script.update(kwargs) + return ubq + + def to_dict(self, **kwargs): + """ + Serialize the search into the dictionary that will be sent over as the + request's body.
+ + All additional keyword arguments will be included into the dictionary. + """ + d = {} + if self.query: + d["query"] = self.query.to_dict() + + if self._script: + d["script"] = self._script + + d.update(self._extra) + + d.update(kwargs) + return d + + async def execute(self): + """ + Execute the update by query request and return an instance of + ``UpdateByQueryResponse`` wrapping all the data. + """ + es = get_connection(self._using) + ensure_async_connection(es, "UpdateByQuery.execute") + + self._response = self._response_class( + self, + await es.update_by_query( + index=self._index, body=self.to_dict(), **self._params + ), + ) + return self._response diff --git a/elasticsearch_dsl/_async/utils.py b/elasticsearch_dsl/_async/utils.py new file mode 100644 index 000000000..6b91de6c9 --- /dev/null +++ b/elasticsearch_dsl/_async/utils.py @@ -0,0 +1,26 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from elasticsearch import AsyncElasticsearch + + +def ensure_async_connection(es, fn_label): + if not isinstance(es, AsyncElasticsearch): + raise TypeError( + f"{fn_label} can only be used with the elasticsearch.AsyncElasticsearch " + "client" + ) diff --git a/elasticsearch_dsl/connections.py b/elasticsearch_dsl/connections.py index 57ba46f1d..f826517be 100644 --- a/elasticsearch_dsl/connections.py +++ b/elasticsearch_dsl/connections.py @@ -75,13 +75,13 @@ def remove_connection(self, alias): if errors == 2: raise KeyError("There is no connection with alias %r." % alias) - def create_connection(self, alias="default", **kwargs): + def create_connection(self, alias="default", client=Elasticsearch, **kwargs): """ Construct an instance of ``elasticsearch.Elasticsearch`` and register it under given alias.
""" kwargs.setdefault("serializer", serializer) - conn = self._conns[alias] = Elasticsearch(**kwargs) + conn = self._conns[alias] = client(**kwargs) return conn def get_connection(self, alias="default"): diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document.py index f77667dfd..6c250e9bf 100644 --- a/elasticsearch_dsl/document.py +++ b/elasticsearch_dsl/document.py @@ -25,13 +25,15 @@ from elasticsearch.exceptions import NotFoundError, RequestError from six import add_metaclass, iteritems -from .connections import get_connection -from .exceptions import IllegalOperation, ValidationException -from .field import Field -from .index import Index -from .mapping import Mapping -from .search import Search -from .utils import DOC_META_FIELDS, META_FIELDS, ObjectBase, merge +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException +from elasticsearch_dsl.field import Field +from elasticsearch_dsl.index import Index +from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.utils import DOC_META_FIELDS, META_FIELDS, ObjectBase, merge + +from .utils import ensure_sync_connection class MetaField(object): @@ -200,6 +202,8 @@ def get(cls, id, using=None, index=None, **kwargs): ``Elasticsearch.get`` unchanged. """ es = cls._get_connection(using) + ensure_sync_connection(es, "Document.get") + doc = es.get(index=cls._default_index(index), id=id, **kwargs) if not doc.get("found", False): return None @@ -228,7 +232,10 @@ def mget( """ if missing not in ("raise", "skip", "none"): raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") + es = cls._get_connection(using) + ensure_sync_connection(es, "Document.mget") + body = { "docs": [ doc if isinstance(doc, collections_abc.Mapping) else {"_id": doc} @@ -282,6 +289,7 @@ def delete(self, using=None, index=None, **kwargs): ``Elasticsearch.delete`` unchanged. 
""" es = self._get_connection(using) + ensure_sync_connection(es, "Document.delete") # extract routing etc from meta doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} @@ -407,7 +415,10 @@ def update( doc_meta["if_seq_no"] = self.meta["seq_no"] doc_meta["if_primary_term"] = self.meta["primary_term"] - meta = self._get_connection(using).update( + es = self._get_connection(using) + ensure_sync_connection(es, "Document.update") + + meta = es.update( index=self._get_index(index), body=body, refresh=refresh, **doc_meta ) # update meta information from ES @@ -440,6 +451,8 @@ def save(self, using=None, index=None, validate=True, skip_empty=True, **kwargs) self.full_clean() es = self._get_connection(using) + ensure_sync_connection(es, "Document.save") + # extract routing etc from meta doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} diff --git a/elasticsearch_dsl/faceted_search.py b/elasticsearch_dsl/faceted_search.py index 9c653a85c..39d1b00ec 100644 --- a/elasticsearch_dsl/faceted_search.py +++ b/elasticsearch_dsl/faceted_search.py @@ -19,11 +19,11 @@ from six import iteritems, itervalues -from .aggs import A -from .query import MatchAll, Nested, Range, Terms -from .response import Response -from .search import Search -from .utils import AttrDict +from elasticsearch_dsl.aggs import A +from elasticsearch_dsl.query import MatchAll, Nested, Range, Terms +from elasticsearch_dsl.response import Response +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.utils import AttrDict __all__ = [ "FacetedSearch", diff --git a/elasticsearch_dsl/index.py b/elasticsearch_dsl/index.py index 17dd93f45..6fde05247 100644 --- a/elasticsearch_dsl/index.py +++ b/elasticsearch_dsl/index.py @@ -15,13 +15,15 @@ # specific language governing permissions and limitations # under the License. -from . import analysis -from .connections import get_connection -from .exceptions import IllegalOperation -from .mapping import Mapping -from .search import Search -from .update_by_query import UpdateByQuery -from .utils import merge +from elasticsearch_dsl import analysis +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation +from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.update_by_query import UpdateByQuery +from elasticsearch_dsl.utils import merge + +from .utils import ensure_sync_connection class IndexTemplate(object): @@ -50,8 +52,9 @@ def to_dict(self): return d def save(self, using=None): - es = get_connection(using or self._index._using) + ensure_sync_connection(es, "IndexTemplate.save") + return es.indices.put_template(name=self._template_name, body=self.to_dict()) @@ -101,9 +104,9 @@ def resolve_field(self, field_path): return None def load_mappings(self, using=None): - self.get_or_create_mapping().update_from_es( - self._name, using=using or self._using - ) + mapping = self.get_or_create_mapping() + + mapping.update_from_es(self._name, using=using or self._using) def clone(self, name=None, using=None): """ @@ -276,14 +279,20 @@ def create(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.create`` unchanged. 
""" - return self._get_connection(using).indices.create( - index=self._name, body=self.to_dict(), **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.create") + + return es.indices.create(index=self._name, body=self.to_dict(), **kwargs) def is_closed(self, using=None): - state = self._get_connection(using).cluster.state( - index=self._name, metric="metadata" + es = self._get_connection(using) + ensure_sync_connection(es, "Index.is_closed") + + state = es.cluster.state( + index=self._name, + metric="metadata", ) + return state["metadata"]["indices"][self._name]["state"] == "close" def save(self, using=None): @@ -348,7 +357,10 @@ def analyze(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.analyze`` unchanged. """ - return self._get_connection(using).indices.analyze(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.analyze") + + return es.indices.analyze(index=self._name, **kwargs) def refresh(self, using=None, **kwargs): """ @@ -357,7 +369,10 @@ def refresh(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.refresh`` unchanged. """ - return self._get_connection(using).indices.refresh(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.refresh") + + return es.indices.refresh(index=self._name, **kwargs) def flush(self, using=None, **kwargs): """ @@ -366,7 +381,10 @@ def flush(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.flush`` unchanged. """ - return self._get_connection(using).indices.flush(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.flush") + + return es.indices.flush(index=self._name, **kwargs) def get(self, using=None, **kwargs): """ @@ -375,7 +393,10 @@ def get(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get`` unchanged. """ - return self._get_connection(using).indices.get(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get") + + return es.indices.get(index=self._name, **kwargs) def open(self, using=None, **kwargs): """ @@ -384,7 +405,10 @@ def open(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.open`` unchanged. """ - return self._get_connection(using).indices.open(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.open") + + return es.indices.open(index=self._name, **kwargs) def close(self, using=None, **kwargs): """ @@ -393,7 +417,10 @@ def close(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.close`` unchanged. """ - return self._get_connection(using).indices.close(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.close") + + return es.indices.close(index=self._name, **kwargs) def delete(self, using=None, **kwargs): """ @@ -402,7 +429,10 @@ def delete(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.delete`` unchanged. 
""" - return self._get_connection(using).indices.delete(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.delete") + + return es.indices.delete(index=self._name, **kwargs) def exists(self, using=None, **kwargs): """ @@ -411,7 +441,10 @@ def exists(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.exists`` unchanged. """ - return self._get_connection(using).indices.exists(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.exists") + + return es.indices.exists(index=self._name, **kwargs) def exists_type(self, using=None, **kwargs): """ @@ -420,9 +453,10 @@ def exists_type(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.exists_type`` unchanged. """ - return self._get_connection(using).indices.exists_type( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.exists_type") + + return es.indices.exists_type(index=self._name, **kwargs) def put_mapping(self, using=None, **kwargs): """ @@ -431,9 +465,10 @@ def put_mapping(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.put_mapping`` unchanged. """ - return self._get_connection(using).indices.put_mapping( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.put_mapping") + + return es.indices.put_mapping(index=self._name, **kwargs) def get_mapping(self, using=None, **kwargs): """ @@ -442,9 +477,10 @@ def get_mapping(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_mapping`` unchanged. """ - return self._get_connection(using).indices.get_mapping( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_mapping") + + return es.indices.get_mapping(index=self._name, **kwargs) def get_field_mapping(self, using=None, **kwargs): """ @@ -453,9 +489,10 @@ def get_field_mapping(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_field_mapping`` unchanged. """ - return self._get_connection(using).indices.get_field_mapping( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_field_mapping") + + return es.indices.get_field_mapping(index=self._name, **kwargs) def put_alias(self, using=None, **kwargs): """ @@ -464,7 +501,10 @@ def put_alias(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.put_alias`` unchanged. """ - return self._get_connection(using).indices.put_alias(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.put_alias") + + return es.indices.put_alias(index=self._name, **kwargs) def exists_alias(self, using=None, **kwargs): """ @@ -484,7 +524,10 @@ def get_alias(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_alias`` unchanged. 
""" - return self._get_connection(using).indices.get_alias(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_alias") + + return es.indices.get_alias(index=self._name, **kwargs) def delete_alias(self, using=None, **kwargs): """ @@ -493,9 +536,10 @@ def delete_alias(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.delete_alias`` unchanged. """ - return self._get_connection(using).indices.delete_alias( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.delete_alias") + + return es.indices.delete_alias(index=self._name, **kwargs) def get_settings(self, using=None, **kwargs): """ @@ -504,9 +548,10 @@ def get_settings(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_settings`` unchanged. """ - return self._get_connection(using).indices.get_settings( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_settings") + + return es.indices.get_settings(index=self._name, **kwargs) def put_settings(self, using=None, **kwargs): """ @@ -515,9 +560,10 @@ def put_settings(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.put_settings`` unchanged. """ - return self._get_connection(using).indices.put_settings( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.put_settings") + + return es.indices.put_settings(index=self._name, **kwargs) def stats(self, using=None, **kwargs): """ @@ -526,7 +572,10 @@ def stats(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.stats`` unchanged. """ - return self._get_connection(using).indices.stats(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.stats") + + return es.indices.stats(index=self._name, **kwargs) def segments(self, using=None, **kwargs): """ @@ -536,7 +585,10 @@ def segments(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.segments`` unchanged. """ - return self._get_connection(using).indices.segments(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.segments") + + return es.indices.segments(index=self._name, **kwargs) def validate_query(self, using=None, **kwargs): """ @@ -545,9 +597,10 @@ def validate_query(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.validate_query`` unchanged. """ - return self._get_connection(using).indices.validate_query( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.validate_query") + + return es.indices.validate_query(index=self._name, **kwargs) def clear_cache(self, using=None, **kwargs): """ @@ -556,9 +609,10 @@ def clear_cache(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.clear_cache`` unchanged. 
""" - return self._get_connection(using).indices.clear_cache( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.clear_cache") + + return es.indices.clear_cache(index=self._name, **kwargs) def recovery(self, using=None, **kwargs): """ @@ -568,7 +622,10 @@ def recovery(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.recovery`` unchanged. """ - return self._get_connection(using).indices.recovery(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.recovery") + + return es.indices.recovery(index=self._name, **kwargs) def upgrade(self, using=None, **kwargs): """ @@ -577,7 +634,10 @@ def upgrade(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.upgrade`` unchanged. """ - return self._get_connection(using).indices.upgrade(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.upgrade") + + return es.indices.upgrade(index=self._name, **kwargs) def get_upgrade(self, using=None, **kwargs): """ @@ -586,9 +646,10 @@ def get_upgrade(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_upgrade`` unchanged. """ - return self._get_connection(using).indices.get_upgrade( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_upgrade") + + return es.indices.get_upgrade(index=self._name, **kwargs) def flush_synced(self, using=None, **kwargs): """ @@ -598,9 +659,10 @@ def flush_synced(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.flush_synced`` unchanged. """ - return self._get_connection(using).indices.flush_synced( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.flush_synced") + + return es.indices.flush_synced(index=self._name, **kwargs) def shard_stores(self, using=None, **kwargs): """ @@ -612,9 +674,10 @@ def shard_stores(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.shard_stores`` unchanged. """ - return self._get_connection(using).indices.shard_stores( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.shard_stores") + + return es.indices.shard_stores(index=self._name, **kwargs) def forcemerge(self, using=None, **kwargs): """ @@ -630,9 +693,10 @@ def forcemerge(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.forcemerge`` unchanged. """ - return self._get_connection(using).indices.forcemerge( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.forcemerge") + + return es.indices.forcemerge(index=self._name, **kwargs) def shrink(self, using=None, **kwargs): """ @@ -649,4 +713,7 @@ def shrink(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.shrink`` unchanged. 
""" - return self._get_connection(using).indices.shrink(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.shrink") + + return es.indices.shrink(index=self._name, **kwargs) diff --git a/elasticsearch_dsl/mapping.py b/elasticsearch_dsl/mapping.py index 6d1bc8bfd..da8f7d710 100644 --- a/elasticsearch_dsl/mapping.py +++ b/elasticsearch_dsl/mapping.py @@ -24,9 +24,11 @@ from six import iteritems, itervalues -from .connections import get_connection -from .field import Nested, Text, construct_field -from .utils import DslBase +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.field import Nested, Text, construct_field +from elasticsearch_dsl.utils import DslBase + +from .utils import ensure_sync_connection META_FIELDS = frozenset( ( @@ -108,6 +110,7 @@ def _clone(self): def from_es(cls, index, using="default"): m = cls() m.update_from_es(index, using) + return m def resolve_nested(self, field_path): @@ -169,6 +172,8 @@ def save(self, index, using="default"): def update_from_es(self, index, using="default"): es = get_connection(using) + ensure_sync_connection(es, "Mapping.update_from_es") + raw = es.indices.get_mapping(index=index) _, raw = raw.popitem() self._update_from_dict(raw["mappings"]) diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index b8323c180..b73ceb629 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -26,12 +26,14 @@ from elasticsearch.helpers import scan from six import iteritems, string_types -from .aggs import A, AggBase -from .connections import get_connection -from .exceptions import IllegalOperation -from .query import Bool, Q -from .response import Hit, Response -from .utils import AttrDict, DslBase +from elasticsearch_dsl.aggs import A, AggBase +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation +from elasticsearch_dsl.query import Bool, Q +from elasticsearch_dsl.response import Hit, Response +from elasticsearch_dsl.utils import AttrDict, DslBase + +from .utils import ensure_sync_connection class QueryProxy(object): @@ -710,10 +712,13 @@ def execute(self, ignore_cache=False): """ if ignore_cache or not hasattr(self, "_response"): es = get_connection(self._using) + ensure_sync_connection(es, "Search.execute") self._response = self._response_class( - self, es.search(index=self._index, body=self.to_dict(), **self._params) + self, + es.search(index=self._index, body=self.to_dict(), **self._params), ) + return self._response def scan(self): @@ -727,6 +732,7 @@ def scan(self): """ es = get_connection(self._using) + ensure_sync_connection(es, "Search.scan") for hit in scan(es, query=self.to_dict(), index=self._index, **self._params): yield self._get_result(hit) @@ -735,8 +741,8 @@ def delete(self): """ delete() executes the query by delegating to delete_by_query() """ - es = get_connection(self._using) + ensure_sync_connection(es, "Search.delete") return AttrDict( es.delete_by_query(index=self._index, body=self.to_dict(), **self._params) @@ -795,6 +801,7 @@ def execute(self, ignore_cache=False, raise_on_error=True): """ if ignore_cache or not hasattr(self, "_response"): es = get_connection(self._using) + ensure_sync_connection(es, "MultiSearch.execute") responses = es.msearch( index=self._index, body=self.to_dict(), **self._params diff --git a/elasticsearch_dsl/update_by_query.py b/elasticsearch_dsl/update_by_query.py index 1d257b92f..e0de1f13a 100644 --- 
a/elasticsearch_dsl/update_by_query.py +++ b/elasticsearch_dsl/update_by_query.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. -from .connections import get_connection -from .query import Bool, Q -from .response import UpdateByQueryResponse -from .search import ProxyDescriptor, QueryProxy, Request +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.query import Bool, Q +from elasticsearch_dsl.response import UpdateByQueryResponse +from elasticsearch_dsl.search import ProxyDescriptor, QueryProxy, Request + +from .utils import ensure_sync_connection class UpdateByQuery(Request): @@ -152,6 +154,7 @@ def execute(self): the data. """ es = get_connection(self._using) + ensure_sync_connection(es, "MultiSearch.execute") self._response = self._response_class( self, diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 50849773a..f353f05f5 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -24,6 +24,7 @@ from copy import copy +from elasticsearch import Elasticsearch from six import add_metaclass, iteritems from six.moves import map @@ -544,6 +545,19 @@ def full_clean(self): self.clean() +def ensure_sync_connection(es, fn_label): + # Allow "Mock" objects to be passed during testing. + if es.__class__.__name__ == "Mock": + return + + if not isinstance(es, Elasticsearch): + raise TypeError( + "{} can only be used with the elasticsearch.Elasticsearch client".format( + fn_label, + ) + ) + + def merge(data, new_data, raise_on_conflict=False): if not ( isinstance(data, (AttrDict, collections_abc.Mapping)) diff --git a/setup.py b/setup.py index 9815739a9..797ebbf4e 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,10 @@ "coverage<5.0.0", ] +async_requires = [ + 'aiohttp>=3,<4; python_version>="3.6"', +] + setup( name="elasticsearch-dsl", description="Python client for Elasticsearch", @@ -78,6 +82,11 @@ ], install_requires=install_requires, test_suite="test_elasticsearch_dsl.run_tests.run_all", - tests_require=tests_require, - extras_require={"develop": tests_require + ["sphinx", "sphinx_rtd_theme"]}, + tests_require=tests_require + async_requires, + extras_require={ + "async": async_requires, + "develop": ( + tests_require + async_requires + ["sphinx", "sphinx_rtd_theme", "unasync"] + ), + }, ) diff --git a/test_elasticsearch_dsl/test_connections.py b/test_elasticsearch_dsl/test_connections.py index 278760cc3..1db54bbde 100644 --- a/test_elasticsearch_dsl/test_connections.py +++ b/test_elasticsearch_dsl/test_connections.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. 
+import sys + +import pytest from elasticsearch import Elasticsearch from pytest import raises @@ -81,9 +84,22 @@ def test_create_connection_constructs_client(): c.create_connection("testing", hosts=["es.com"]) con = c.get_connection("testing") + assert isinstance(c.get_connection("testing"), Elasticsearch) assert [{"host": "es.com"}] == con.transport.hosts +@pytest.mark.skipif( + sys.version_info < (3, 6), reason="Async features require Python 3.6 or higher" +) +def test_create_connection_constructs_async_client(): + from elasticsearch import AsyncElasticsearch + + c = connections.Connections() + c.create_connection("testing", client=AsyncElasticsearch, hosts=["es.com"]) + + assert isinstance(c.get_connection("testing"), AsyncElasticsearch) + + def test_create_connection_adds_our_serializer(): c = connections.Connections() c.create_connection("testing", hosts=["es.com"]) diff --git a/utils/generate-sync.py b/utils/generate-sync.py new file mode 100644 index 000000000..df8bf2061 --- /dev/null +++ b/utils/generate-sync.py @@ -0,0 +1,68 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from pathlib import Path + +import black +import unasync +from click.testing import CliRunner + +CODE_ROOT = Path(__file__).absolute().parent.parent + + +def _blacken(filename): + runner = CliRunner() + result = runner.invoke(black.main, [str(filename)]) + assert result.exit_code == 0, result.output + + +def generate_sync(): + additional_replacements = { + "_async": "", + "async_scan": "scan", + "ensure_async_connection": "ensure_sync_connection", + } + + rules = [ + unasync.Rule( + fromdir="/_async/", + todir="/", + additional_replacements=additional_replacements, + ), + ] + + filepaths = [] + for root, _, filenames in os.walk(CODE_ROOT / "elasticsearch_dsl/_async"): + for filename in filenames: + if ( + filename.rpartition(".")[-1] + in ( + "py", + "pyi", + ) + and not filename.startswith("__init__.py") + and not filename.startswith("utils.py") + ): + filepaths.append(os.path.join(root, filename)) + + unasync.unasync_files(filepaths, rules) + _blacken(CODE_ROOT / "elasticsearch_dsl") + + +if __name__ == "__main__": + generate_sync()
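For reference, a minimal usage sketch of the async API introduced by this patch. It assumes the classes are imported directly from the new ``_async`` modules (this change adds no public re-exports) and that an ``AsyncElasticsearch`` client is registered through the extended ``create_connection``; the host, index name and field names below are placeholders, not part of the patch.

import asyncio

from elasticsearch import AsyncElasticsearch
from elasticsearch_dsl.connections import create_connection

from elasticsearch_dsl._async.search import Search
from elasticsearch_dsl._async.update_by_query import UpdateByQuery


async def main():
    # Register an async client under the default alias; the new `client`
    # argument tells create_connection() which class to instantiate.
    create_connection(client=AsyncElasticsearch, hosts=["localhost"])

    # Query building is unchanged from the sync API; only execution is awaited.
    s = Search(index="blog-posts").filter("term", published=True)[:10]
    response = await s.execute()
    for hit in response:
        print(hit.meta.id)

    # scan() is an async generator wrapping elasticsearch.helpers.async_scan.
    async for hit in s.scan():
        print(hit.meta.id)

    # UpdateByQuery.execute() is awaited the same way.
    ubq = UpdateByQuery(index="blog-posts").script(source="ctx._source.views = 0")
    await ubq.execute()


asyncio.run(main())  # Python 3.7+; on 3.6 use a loop's run_until_complete()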