From decb6bfb17b3f167dc06c90de972f3e90e8c0107 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Mon, 5 Oct 2020 16:42:43 -0600 Subject: [PATCH 01/11] Add Async I/O Support This commit adds support for async I/O that mirrors the support added to elasticsearch-py in elastic/elasticsearch-py#1203. Changes: * A new `client` argument was added to `elasticsearch_dsl.connections.create_connection` to allow users to provide the `elasticsearch._async.AsyncElasticsearch` as their preferred client class. Passing `AsyncElasticsearch` will enable asynchronous behavior in `elasticsearch_dsl`. * Async versions of the `FacetedSearch`, `Index`, `Mapping`, `Search`, and `UpdateByQuery` classes have been added to elasticsearch_dsl._async. The paths for these classes mirror the paths for their sync versions. These classes defer to their respective sync classes for all methods that don't perform I/O. * Async versions of `Document.get`, `Document.init`, `Document.mget`, `Document.delete`, `Document.save`, and `Document.update` have been added to the `Document` class with the following names: * `Document.delete` -> `Document.delete_async` * `Document.get` -> `Document.get_async` * `Document.init` -> `Document.init_async` * `Document.mget` -> `Document.mget_async` * `Document.save` -> `Document.save_async` * `Document.update` -> `Document.update_async` * Where possible, the existing methods have been refactored to re-use their existing implementation instead of creating duplication. Closes #1355. --- elasticsearch_dsl/_async/faceted_search.py | 11 + elasticsearch_dsl/_async/index.py | 473 ++++++++++++++++++++ elasticsearch_dsl/_async/mapping.py | 27 ++ elasticsearch_dsl/_async/search.py | 80 ++++ elasticsearch_dsl/_async/update_by_query.py | 23 + elasticsearch_dsl/_async/utils.py | 9 + elasticsearch_dsl/connections.py | 4 +- elasticsearch_dsl/document.py | 430 ++++++++++++++---- elasticsearch_dsl/search.py | 31 +- test_elasticsearch_dsl/test_connections.py | 12 +- test_elasticsearch_dsl/test_document.py | 4 + 11 files changed, 998 insertions(+), 106 deletions(-) create mode 100644 elasticsearch_dsl/_async/faceted_search.py create mode 100644 elasticsearch_dsl/_async/index.py create mode 100644 elasticsearch_dsl/_async/mapping.py create mode 100644 elasticsearch_dsl/_async/search.py create mode 100644 elasticsearch_dsl/_async/update_by_query.py create mode 100644 elasticsearch_dsl/_async/utils.py diff --git a/elasticsearch_dsl/_async/faceted_search.py b/elasticsearch_dsl/_async/faceted_search.py new file mode 100644 index 000000000..caaaa8b91 --- /dev/null +++ b/elasticsearch_dsl/_async/faceted_search.py @@ -0,0 +1,11 @@ +from elasticsearch_dsl.faceted_search import FacetedSearch + + +class AsyncFacetedSearch(FacetedSearch): + async def execute(self): + """ + Asynchronously execute the search and return the response. + """ + r = await self._s.execute() + r._faceted_search = self + return r diff --git a/elasticsearch_dsl/_async/index.py b/elasticsearch_dsl/_async/index.py new file mode 100644 index 000000000..67f052df6 --- /dev/null +++ b/elasticsearch_dsl/_async/index.py @@ -0,0 +1,473 @@ +from elasticsearch import AsyncElasticsearch + +from elasticsearch_dsl._async.utils import ensure_async_connection +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.index import Index, IndexTemplate + + +class AsyncIndex(Index): + async def analyze(self, using=None, **kwargs): + """ + Asynchronously perform the analysis process on a text and return the tokens + breakdown of the text. 
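To ground the commit message above, a minimal usage sketch of the new entry points follows. The host, index name, document id, and field names are illustrative, and the ``*_async`` methods live on ``Document`` itself at this point in the series:

    import asyncio

    from elasticsearch import AsyncElasticsearch
    from elasticsearch_dsl import Document, Text, connections

    # The new ``client`` argument selects the client class registered under this alias.
    connections.create_connection(hosts=["localhost:9200"], client=AsyncElasticsearch)

    class BlogPost(Document):
        title = Text()

        class Index:
            name = "blog-posts"

    async def main():
        post = BlogPost(meta={"id": 1}, title="Async support")
        await post.save_async()                   # AsyncElasticsearch.index
        fetched = await BlogPost.get_async(id=1)  # AsyncElasticsearch.get
        await fetched.delete_async()

    asyncio.run(main())

Synchronous code can keep using a plain ``Elasticsearch`` client registered under a different alias; the registry holds one client instance per alias.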
+ + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.analyze`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.analyze") + + return await es.indices.analyze(index=self._index, **kwargs) + + async def clear_cache(self, using=None, **kwargs): + """ + Asynchronously clear all caches or specific cached associated with the index. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.clear_cache`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.clear_cache") + + return await es.indices.clear_cache(index=self._index, **kwargs) + + async def close(self, using=None, **kwargs): + """ + Asynchronously closes the index in Elasticsearch. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.close`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.close") + + return await es.indices.close(index=self._index, **kwargs) + + async def create(self, using=None, **kwargs): + """ + Asynchronously creates the index in Elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.create`` unchanged. + """ + es = get_connection(using) + ensure_async_connection(es, "AsyncIndex.create") + + return await es.indices.create( + index=self._name, + body=self.to_dict(), + **kwargs, + ) + + async def delete(self, using=None, **kwargs): + """ + Asynchronously deletes the index in Elasticsearch. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.delete`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.delete") + + return await es.indices.delete(index=self._index, **kwargs) + + async def delete_alias(self, using=None, **kwargs): + """ + Asynchronously deletes a specific alias. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.delete_alias`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.delete_alias") + + return await es.indices.delete_alias(index=self._index, **kwargs) + + async def exists(self, using=None, **kwargs): + """ + Asynchronously queries Elasticsearch for whether this index exists. Returns + ``True`` if the index already exists in Elasticsearch, otherwise ``False``. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.exists`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.exists") + + return await es.indices.exists(index=self._index, **kwargs) + + async def exists_type(self, using=None, **kwargs): + """ + Asynchronously queries Elasticsearch for whether a type or set of types exists + in the index. Returns ``True`` if the type/types exist, otherwise ``False``. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.exists_type`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.exists_type") + + return await es.indices.exists_type(index=self._index, **kwargs) + + async def flush(self, using=None, **kwargs): + """ + Asynchronously performs a flush operation on the index. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.flush`` unchanged. 
+ """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.flush") + + return await es.indices.flush(index=self._index, **kwargs) + + async def flush_synced(self, using=None, **kwargs): + """ + Asynchronously performs a normal flush, then adds a unique marker (sync_id) to + all shards. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.flush_synced`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.flush_synced") + + return await es.indices.flush_synced(index=self._index, **kwargs) + + async def forcemerge(self, using=None, **kwargs): + """ + Asynchronously calls the force merge API. + + The force merge API allows to force merging of the index through an API. The + merge relates to the number of segments a Lucene index holds within each shard. + The force merge operation allows to reduce the number of segments by merging + them. + + This call will block until the merge is complete. If the http connection is + lost, the request will continue in the background, and any new requests will + block until the previous force merge is complete. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.forcemerge`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.forcemerge") + + return await es.indices.forcemerge(index=self._index, **kwargs) + + async def get(self, using=None, **kwargs): + """ + Asynchronously retrieves information about the index from Elasticsearch. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.get`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.get") + + return await es.indices.get(index=self._index, **kwargs) + + async def get_alias(self, using=None, **kwargs): + """ + Asynchronously retrieves a specific alias. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.get_alias`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.get_alias") + + return await es.indices.get_alias(index=self._index, **kwargs) + + async def get_field_mapping(self, using=None, **kwargs): + """ + Asynchronously retrieves a mapping definition for a specific field. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.get_field_mapping`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.get_field_mapping") + + return await es.indices.get_field_mapping(index=self._index, **kwargs) + + async def get_mapping(self, using=None, **kwargs): + """ + Asynchronously retrieves a specific mapping definition for a specific type. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.get_mapping`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.get_mapping") + + return await es.indices.get_mapping(index=self._index, **kwargs) + + async def get_settings(self, using=None, **kwargs): + """ + Asynchronously retrieves the settings for the index. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.get_settings`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.get_settings") + + return await es.indices.get_settings(index=self._index, **kwargs) + + async def get_upgrade(self, using=None, **kwargs): + """ + Asynchronously monitors how much of an index is upgraded. 
+ + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.get_upgrade`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.get_upgrade") + + return await es.indices.get_upgrade(index=self._index, **kwargs) + + async def is_closed(self, using=None): + """ + Asynchronously queries Elasticsearch to determine whether this index + is closed. + """ + es = get_connection(using) + ensure_async_connection(es, "AsyncIndex.is_closed") + + state = await es.cluster.state( + index=self._name, + metric="metadata", + ) + + return state["metadata"]["indices"][self._name]["state"] == "close" + + async def load_mappings(self, using=None): + mapping = self.get_or_create_mapping() + + await mapping.update_from_es(self._name, using=using or self._using) + + async def open(self, using=None, **kwargs): + """ + Asynchronously opens the index in Elasticsearch. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.open`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.open") + + return await es.indices.open(index=self._index, **kwargs) + + async def put_alias(self, using=None, **kwargs): + """ + Asynchronously creates an alias for the index. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.put_alias`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.put_alias") + + return await es.indices.put_alias(index=self._index, **kwargs) + + async def put_mapping(self, using=None, **kwargs): + """ + Asynchronously register a specific mapping definition for a specific type. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.put_mapping`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.put_mapping") + + return await es.indices.put_mapping(index=self._index, **kwargs) + + async def put_settings(self, using=None, **kwargs): + """ + Asynchronously changes specific index-level settings. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.put_settings`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.put_settings") + + return await es.indices.put_settings(index=self._index, **kwargs) + + async def recovery(self, using=None, **kwargs): + """ + Asynchronously provides insight into ongoing shard recoveries for the index. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.recovery`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.recovery") + + return await es.indices.recovery(index=self._index, **kwargs) + + async def refresh(self, using=None, **kwargs): + """ + Asynchronously performs a refresh operation on the index. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.refresh`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.refresh") + + return await es.indices.refresh(index=self._index, **kwargs) + + async def save(self, using=None): + """ + Asynchronously sync the index definition with Elasticsearch, creating the index + if it doesn't exist and updating its settings and mappings if it does. + + Note: Some settings and mapping changes cannot be done on an open index (or at + all on an existing index) and for those this method will fail with the + underlying exception. 
+ """ + + if not await self.exists(using=using): + return await self.create(using=using) + + body = self.to_dict() + settings = body.pop("settings", {}) + analysis = settings.pop("analysis", None) + current_settings = self.get_settings(using=using)[self._name]["settings"][ + "index" + ] + + if analysis: + if await self.is_closed(using=using): + # closed index, update away + settings["analysis"] = analysis + else: + # compare analysis definition, if all analysis objects are + # already defined as requested, skip analysis update and + # proceed, otherwise raise IllegalOperation + existing_analysis = current_settings.get("analysis", {}) + if any( + existing_analysis.get(section, {}).get(k, None) + != analysis[section][k] + for section in analysis + for k in analysis[section] + ): + raise IllegalOperation( + "You cannot update analysis configuration on an open index, " + "you need to close index %s first." % self._name + ) + + # try and update the settings + if settings: + settings = settings.copy() + for k, v in list(settings.items()): + if k in current_settings and current_settings[k] == str(v): + del settings[k] + + if settings: + await self.put_settings(using=using, body=settings) + + # update the mappings, any conflict in the mappings will result in an + # exception + mappings = body.pop("mappings", {}) + if mappings: + await self.put_mapping(using=using, body=mappings) + + async def segments(self, using=None, **kwargs): + """ + Asynchronously provides low-level segments information that a Lucene index + (shard level) is built with. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.segments`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.segments") + + return await es.indices.segments(index=self._index, **kwargs) + + async def shard_stores(self, using=None, **kwargs): + """ + Asynchronously provides store information for shard copies of the index. Store + information reports on which nodes shard copies exist, the shard copy version, + indicating how recent they are, and any exceptions encountered while opening + the shard index or from earlier engine failure. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.shard_stores`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.shard_stores") + + return await es.indices.shard_stores(index=self._index, **kwargs) + + async def shrink(self, using=None, **kwargs): + """ + Asynchronously calls the shrink index API. + + The shrink index API allows you to shrink an existing index into a new + index with fewer primary shards. The number of primary shards in the + target index must be a factor of the shards in the source index. For + example an index with 8 primary shards can be shrunk into 4, 2 or 1 + primary shards or an index with 15 primary shards can be shrunk into 5, + 3 or 1. If the number of shards in the index is a prime number it can + only be shrunk into a single primary shard. Before shrinking, a + (primary or replica) copy of every shard in the index must be present + on the same node. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.shrink`` unchanged. 
+ """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.shrink") + + return await es.indices.shrink(index=self._index, **kwargs) + + async def stats(self, using=None, **kwargs): + """ + Asynchronously retrieves statistics on different operations happening on the + index. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.stats`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.stats") + + return await es.indices.stats(index=self._index, **kwargs) + + async def upgrade(self, using=None, **kwargs): + """ + Asynchronously upgrades the index to the latest format. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.upgrade`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.upgrade") + + return await es.indices.upgrade(index=self._index, **kwargs) + + async def validate_query(self, using=None, **kwargs): + """ + Asynchronously validates a potentially expensive query without executing it. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.indices.validate_query`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "AsyncIndex.validate_query") + + return await es.indices.validate_query(index=self._index, **kwargs) + + +class AsyncIndexTemplate(IndexTemplate): + async def save(self, using=None): + es = get_connection(using or self._index._using) + ensure_async_connection(es, "AsyncIndexTemplate.save") + + return await es.indices.put_template( + name=self._template_name, body=self.to_dict() + ) diff --git a/elasticsearch_dsl/_async/mapping.py b/elasticsearch_dsl/_async/mapping.py new file mode 100644 index 000000000..a9ce2f1bb --- /dev/null +++ b/elasticsearch_dsl/_async/mapping.py @@ -0,0 +1,27 @@ +from elasticsearch_dsl._async.utils import ensure_async_connection +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.mapping import Mapping + + +class AsyncMapping(Mapping): + @classmethod + async def from_es(cls, index, using="default"): + m = cls() + await m.update_from_es(index, using) + + return m + + async def save(self, index, using="default"): + from elasticsearch_dsl._async.index import AsyncIndex + + index = AsyncIndex(index, using=using) + index.mapping(self) + return await index.save() + + async def update_from_es(self, index, using="default"): + es = get_connection(using) + ensure_async_connection(es, "AsyncMapping.update_from_es") + + raw = await es.indices.get_mapping(index=index) + _, raw = raw.popitem() + self._update_from_dict(raw["mappings"]) diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py new file mode 100644 index 000000000..6d9249396 --- /dev/null +++ b/elasticsearch_dsl/_async/search.py @@ -0,0 +1,80 @@ +from elasticsearch import AsyncElasticsearch +from elasticsearch._async.helpers import async_scan + +from elasticsearch_dsl._async.utils import ensure_async_connection +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.search import MultiSearch, Search +from elasticsearch_dsl.utils import AttrDict + + +class AsyncMultiSearch(MultiSearch): + async def execute(self, ignore_cache=False, raise_on_error=True): + """ + Execute the multi search request and return a list of search results. 
+ """ + if ignore_cache or not hasattr(self, "_response"): + es = get_connection(self._using) + ensure_async_connection(es, "AsyncMultiSearch.execute") + + responses = await es.msearch( + index=self._index, + body=self.to_dict(), + **self.params, + ) + + self._response = self._process_responses( + responses, raise_on_error=raise_on_error + ) + + return self._response + + +class AsyncSearch(Search): + async def __aiter__(self): + """ + Asynchronously iterates over the hits. + """ + return iter(self.execute()) + + async def execute(self, ignore_cache=False): + if ignore_cache or not hasattr(self, "_response"): + es = get_connection(self._using) + ensure_async_connection(es, "AsyncSearch.execute") + + self._response = self._response_class( + self, + await es.search(index=self._index, body=self.to_dict(), **self._params), + ) + + return self._response + + async def scan(self): + """ + Turn the search into a scan search and return a generator that will + iterate over all the documents matching the query. + + Use ``params`` method to specify any additional arguments you with to + pass to the underlying ``scan`` helper from ``elasticsearch-py`` - + https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan + + """ + es = get_connection(self._using) + ensure_async_connection(es, "AsyncSearch.scan") + + for hit in await async_scan( + es, query=self.to_dict(), index=self._index, **self._params + ): + yield self._get_result(hit) + + async def delete(self): + """ + delete() executes the query by delegating to delete_by_query() + """ + es = get_connection(self._using) + ensure_async_connection(es, "AsyncSearch.delete") + + return AttrDict( + await es.delete_by_query( + index=self._index, body=self.to_dict(), **self._params + ) + ) diff --git a/elasticsearch_dsl/_async/update_by_query.py b/elasticsearch_dsl/_async/update_by_query.py new file mode 100644 index 000000000..a743119f4 --- /dev/null +++ b/elasticsearch_dsl/_async/update_by_query.py @@ -0,0 +1,23 @@ +from elasticsearch import AsyncElasticsearch + +from elasticsearch_dsl._async.utils import ensure_async_connection +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.update_by_query import UpdateByQuery + + +class AsyncUpdateByQuery(UpdateByQuery): + async def execute(self): + """ + Execute the search and return an instance of ``Response`` wrapping all + the data. + """ + es = get_connection(self._using) + ensure_async_connection(es, "AsyncMultiSearch.execute") + + self._response = self._response_class( + self, + await es.update_by_query( + index=self._index, body=self.to_dict(), **self._params + ), + ) + return self._response diff --git a/elasticsearch_dsl/_async/utils.py b/elasticsearch_dsl/_async/utils.py new file mode 100644 index 000000000..e9abe4e1f --- /dev/null +++ b/elasticsearch_dsl/_async/utils.py @@ -0,0 +1,9 @@ +from elasticsearch import AsyncElasticsearch + + +def ensure_async_connection(es, fn_label): + if not isinstance(es, AsyncElasticsearch): + raise TypeError( + f"{fn_label} can only be used with the elasticsearch.AsyncElasticsearch " + "client" + ) diff --git a/elasticsearch_dsl/connections.py b/elasticsearch_dsl/connections.py index 57ba46f1d..f826517be 100644 --- a/elasticsearch_dsl/connections.py +++ b/elasticsearch_dsl/connections.py @@ -75,13 +75,13 @@ def remove_connection(self, alias): if errors == 2: raise KeyError("There is no connection with alias %r." 
% alias) - def create_connection(self, alias="default", **kwargs): + def create_connection(self, alias="default", client=Elasticsearch, **kwargs): """ Construct an instance of ``elasticsearch.Elasticsearch`` and register it under given alias. """ kwargs.setdefault("serializer", serializer) - conn = self._conns[alias] = Elasticsearch(**kwargs) + conn = self._conns[alias] = client(**kwargs) return conn def get_connection(self, alias="default"): diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document.py index f77667dfd..4d257a7ab 100644 --- a/elasticsearch_dsl/document.py +++ b/elasticsearch_dsl/document.py @@ -22,9 +22,14 @@ from fnmatch import fnmatch +from elasticsearch import AsyncElasticsearch from elasticsearch.exceptions import NotFoundError, RequestError from six import add_metaclass, iteritems +from elasticsearch_dsl._async.utils import ensure_async_connection + +from ._async.search import AsyncSearch +from ._async.utils import ensure_async_connection from .connections import get_connection from .exceptions import IllegalOperation, ValidationException from .field import Field @@ -182,9 +187,16 @@ def search(cls, using=None, index=None): Create an :class:`~elasticsearch_dsl.Search` instance that will search over this ``Document``. """ - return Search( - using=cls._get_using(using), index=cls._default_index(index), doc_type=[cls] - ) + es = cls._get_using(using) + kwargs = { + "doc_type": [cls], + "index": cls._default_index(index), + } + + if isinstance(es, AsyncElasticsearch): + return AsyncSearch(using=es, **kwargs) + + return Search(using=es, **kwargs) @classmethod def get(cls, id, using=None, index=None, **kwargs): @@ -228,47 +240,20 @@ def mget( """ if missing not in ("raise", "skip", "none"): raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") - es = cls._get_connection(using) - body = { - "docs": [ - doc if isinstance(doc, collections_abc.Mapping) else {"_id": doc} - for doc in docs - ] - } - results = es.mget(body, index=cls._default_index(index), **kwargs) - - objs, error_docs, missing_docs = [], [], [] - for doc in results["docs"]: - if doc.get("found"): - if error_docs or missing_docs: - # We're going to raise an exception anyway, so avoid an - # expensive call to cls.from_es(). - continue - - objs.append(cls.from_es(doc)) - elif doc.get("error"): - if raise_on_error: - error_docs.append(doc) - if missing == "none": - objs.append(None) + es = cls._get_connection(using) - # The doc didn't cause an error, but the doc also wasn't found. - elif missing == "raise": - missing_docs.append(doc) - elif missing == "none": - objs.append(None) + results = es.mget( + cls._build_mget_body(docs), + index=cls._default_index(index), + **kwargs, + ) - if error_docs: - error_ids = [doc["_id"] for doc in error_docs] - message = "Required routing not provided for documents %s." - message %= ", ".join(error_ids) - raise RequestError(400, message, error_docs) - if missing_docs: - missing_ids = [doc["_id"] for doc in missing_docs] - message = "Documents %s not found." % ", ".join(missing_ids) - raise NotFoundError(404, message, {"docs": missing_docs}) - return objs + return cls._parse_mget_results( + results, + missing=missing, + raise_on_error=raise_on_error, + ) def delete(self, using=None, index=None, **kwargs): """ @@ -282,15 +267,7 @@ def delete(self, using=None, index=None, **kwargs): ``Elasticsearch.delete`` unchanged. 
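A short sketch of querying through the async search class added earlier in this patch; it reuses the async default connection from the first sketch, and the index and field names are made up:

    from elasticsearch_dsl._async.search import AsyncSearch

    async def find_posts():
        s = AsyncSearch(index="blog-posts").query("match", title="async")
        response = await s.execute()
        print(response.hits.total)

        # scan() is an async generator, so results stream with ``async for``.
        async for hit in AsyncSearch(index="blog-posts").scan():
            print(hit.meta.id, hit.title)

Note that as written, ``Document.search()`` only returns an ``AsyncSearch`` when the ``using`` value it receives is itself an ``AsyncElasticsearch`` instance, so passing the client object explicitly is the reliable way to get the async class from that helper.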
""" es = self._get_connection(using) - # extract routing etc from meta - doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - - # Optimistic concurrency control - if "seq_no" in self.meta and "primary_term" in self.meta: - doc_meta["if_seq_no"] = self.meta["seq_no"] - doc_meta["if_primary_term"] = self.meta["primary_term"] - - doc_meta.update(kwargs) + doc_meta = self._build_delete_doc_meta(**kwargs) es.delete(index=self._get_index(index), **doc_meta) def to_dict(self, include_meta=False, skip_empty=True): @@ -359,6 +336,278 @@ def update( :return operation result noop/updated """ + body, doc_meta = self._build_update_body_and_meta( + detect_noop=detect_noop, + doc_as_upsert=doc_as_upsert, + retry_on_conflict=retry_on_conflict, + script=script, + script_id=script_id, + scripted_upsert=scripted_upsert, + upsert=upsert, + **fields, + ) + + meta = self._get_connection(using).update( + index=self._get_index(index), body=body, refresh=refresh, **doc_meta + ) + self._update_doc_meta(meta) + + return meta["result"] + + def save(self, using=None, index=None, validate=True, skip_empty=True, **kwargs): + """ + Save the document into elasticsearch. If the document doesn't exist it + is created, it is overwritten otherwise. Returns ``True`` if this + operations resulted in new document being created. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg validate: set to ``False`` to skip validating the document + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.index`` unchanged. + + :return operation result created/updated + """ + if validate: + self.full_clean() + + es = self._get_connection(using) + doc_meta = self._build_save_doc_meta(**kwargs) + meta = es.index( + index=self._get_index(index), + body=self.to_dict(skip_empty=skip_empty), + **doc_meta, + ) + self._update_doc_meta(meta) + + return meta["result"] + + @classmethod + async def get_async(cls, id, using=None, index=None, **kwargs): + """ + Asynchronously retrieves a single document from elasticsearch using its ``id``. + + :arg id: ``id`` of the document to be retrieved + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.get`` unchanged. + """ + es = cls._get_connection(using) + ensure_async_connection(es, "Document.get_async") + + doc = await es.get(index=cls._default_index(index), id=id, **kwargs) + if not doc.get("found", False): + return None + return cls.from_es(doc) + + @classmethod + async def init_async(cls, index=None, using=None): + """ + Asynchronously creates the index and populates the mappings in Elasticsearch. + """ + i = cls._index + if index: + i = i.clone(name=index) + await i.save(using=using) + + @classmethod + async def mget_async( + cls, docs, using=None, index=None, raise_on_error=True, missing="none", **kwargs + ): + r""" + Asynchronously retrieves multiple documents by their ``id``\s. Returns a list + of instances in the same order as requested. 
+ + :arg docs: list of ``id``\s of the documents to be retrieved or a list + of document specifications as per + https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-multi-get.html + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg missing: what to do when one of the documents requested is not + found. Valid options are ``'none'`` (use ``None``), ``'raise'`` (raise + ``NotFoundError``) or ``'skip'`` (ignore the missing document). + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.mget`` unchanged. + """ + if missing not in ("raise", "skip", "none"): + raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") + + es = cls._get_connection(using) + ensure_async_connection(es, "Document.mget_async") + + results = await es.mget( + cls._build_mget_body(docs), + index=cls._default_index(index), + **kwargs, + ) + + return cls._parse_mget_results( + results, + missing=missing, + raise_on_error=raise_on_error, + ) + + async def delete_async(self, using=None, index=None, **kwargs): + """ + Asynchronously deletes the instance in Elasticsearch. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.delete`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Document.delete_async") + doc_meta = self._build_delete_doc_meta(**kwargs) + await es.delete(index=self._get_index(index), **doc_meta) + + async def save_async( + self, using=None, index=None, validate=True, skip_empty=True, **kwargs + ): + """ + Asyncrhonously saves the document into Elasticsearch. If the document doesn't + exist it is created, otherwise it is overwritten. Returns ``True`` if this + operation resulted in new document being created. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg validate: set to ``False`` to skip validating the document + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.index`` unchanged. + + :return operation result created/updated + """ + if validate: + self.full_clean() + + es = self._get_connection(using) + ensure_async_connection(es, "Document.save_async") + + doc_meta = self._build_save_doc_meta(**kwargs) + meta = await es.index( + index=self._get_index(index), + body=self.to_dict(skip_empty=skip_empty), + **doc_meta, + ) + self._update_doc_meta(meta) + + return meta["result"] + + async def update_async( + self, + using=None, + index=None, + detect_noop=True, + doc_as_upsert=False, + refresh=False, + retry_on_conflict=None, + script=None, + script_id=None, + scripted_upsert=False, + upsert=None, + **fields + ): + """ + Asynchronously performs a partial update of the document using the provided + fields. + + doc = MyDocument(title='Document Title!') + doc.save() + doc.update(title='New Document Title!') + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. 
+ :arg using: connection alias to use, defaults to ``'default'`` + :arg detect_noop: Set to ``False`` to disable noop detection. + :arg refresh: Control when the changes made by this request are visible + to search. Set to ``True`` for immediate effect. + :arg retry_on_conflict: In between the get and indexing phases of the + update, it is possible that another process might have already + updated the same document. By default, the update will fail with a + version conflict exception. The retry_on_conflict parameter + controls how many times to retry the update before finally throwing + an exception. + :arg doc_as_upsert: Instead of sending a partial doc plus an upsert + doc, setting doc_as_upsert to true will use the contents of doc as + the upsert value + + :return operation result noop/updated + """ + body, doc_meta = self._build_update_body_and_meta( + detect_noop=detect_noop, + doc_as_upsert=doc_as_upsert, + retry_on_conflict=retry_on_conflict, + script=script, + script_id=script_id, + scripted_upsert=scripted_upsert, + upsert=upsert, + **fields, + ) + + es = self._get_connection(using) + ensure_async_connection(es, "Document.update_async") + + meta = await es.update( + index=self._get_index(index), body=body, refresh=refresh, **doc_meta + ) + self._update_doc_meta(meta) + + return meta["result"] + + def _build_delete_doc_meta(self, **kwargs): + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) + + return doc_meta + + def _build_save_doc_meta(self, **kwargs): + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) + + return doc_meta + + def _build_update_body_and_meta( + self, + detect_noop=True, + doc_as_upsert=False, + retry_on_conflict=None, + script=None, + script_id=None, + scripted_upsert=False, + upsert=None, + **fields + ): body = { "doc_as_upsert": doc_as_upsert, "detect_noop": detect_noop, @@ -407,56 +656,55 @@ def update( doc_meta["if_seq_no"] = self.meta["seq_no"] doc_meta["if_primary_term"] = self.meta["primary_term"] - meta = self._get_connection(using).update( - index=self._get_index(index), body=body, refresh=refresh, **doc_meta - ) - # update meta information from ES - for k in META_FIELDS: - if "_" + k in meta: - setattr(self.meta, k, meta["_" + k]) + return body, doc_meta - return meta["result"] + @classmethod + def _build_mget_body(cls, docs): + return { + "docs": [ + doc if isinstance(doc, collections_abc.Mapping) else {"_id": doc} + for doc in docs + ] + } - def save(self, using=None, index=None, validate=True, skip_empty=True, **kwargs): - """ - Save the document into elasticsearch. If the document doesn't exist it - is created, it is overwritten otherwise. Returns ``True`` if this - operations resulted in new document being created. 
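The ``_build_save_doc_meta`` and ``_build_delete_doc_meta`` helpers above carry the optimistic-concurrency metadata; a small sketch of the effect, again using the illustrative ``BlogPost`` document from a coroutine:

    post = await BlogPost.get_async(id=1)
    # get_async filled in post.meta.seq_no and post.meta.primary_term, so the
    # helpers add if_seq_no/if_primary_term to the write request and
    # Elasticsearch answers with a 409 ConflictError if the document was
    # changed by someone else in the meantime.
    await post.save_async()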
+ @classmethod + def _parse_mget_results(cls, results, missing="none", raise_on_error=True): + objs, error_docs, missing_docs = [], [], [] + for doc in results["docs"]: + if doc.get("found"): + if error_docs or missing_docs: + # We're going to raise an exception anyway, so avoid an + # expensive call to cls.from_es(). + continue - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - :arg validate: set to ``False`` to skip validating the document - :arg skip_empty: if set to ``False`` will cause empty values (``None``, - ``[]``, ``{}``) to be left on the document. Those values will be - stripped out otherwise as they make no difference in elasticsearch. + objs.append(cls.from_es(doc)) - Any additional keyword arguments will be passed to - ``Elasticsearch.index`` unchanged. + elif doc.get("error"): + if raise_on_error: + error_docs.append(doc) + if missing == "none": + objs.append(None) - :return operation result created/updated - """ - if validate: - self.full_clean() + # The doc didn't cause an error, but the doc also wasn't found. + elif missing == "raise": + missing_docs.append(doc) + elif missing == "none": + objs.append(None) - es = self._get_connection(using) - # extract routing etc from meta - doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + if error_docs: + error_ids = [doc["_id"] for doc in error_docs] + message = "Required routing not provided for documents %s." + message %= ", ".join(error_ids) + raise RequestError(400, message, error_docs) + if missing_docs: + missing_ids = [doc["_id"] for doc in missing_docs] + message = "Documents %s not found." % ", ".join(missing_ids) + raise NotFoundError(404, message, {"docs": missing_docs}) - # Optimistic concurrency control - if "seq_no" in self.meta and "primary_term" in self.meta: - doc_meta["if_seq_no"] = self.meta["seq_no"] - doc_meta["if_primary_term"] = self.meta["primary_term"] + return objs - doc_meta.update(kwargs) - meta = es.index( - index=self._get_index(index), - body=self.to_dict(skip_empty=skip_empty), - **doc_meta - ) + def _update_doc_meta(self, meta): # update meta information from ES for k in META_FIELDS: if "_" + k in meta: setattr(self.meta, k, meta["_" + k]) - - return meta["result"] diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index b8323c180..9fef79f11 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -797,19 +797,26 @@ def execute(self, ignore_cache=False, raise_on_error=True): es = get_connection(self._using) responses = es.msearch( - index=self._index, body=self.to_dict(), **self._params + index=self._index, body=self.to_dict(), **self.params ) - out = [] - for s, r in zip(self._searches, responses["responses"]): - if r.get("error", False): - if raise_on_error: - raise TransportError("N/A", r["error"]["type"], r["error"]) - r = None - else: - r = Response(s, r) - out.append(r) - - self._response = out + self._response = self._process_responses( + responses, raise_on_error=raise_on_error + ) return self._response + + def _process_responses(self, responses, raise_on_error=True): + out = [] + + for s, r in zip(self._searches, responses["responses"]): + if r.get("error", False): + if raise_on_error: + raise TransportError("N/A", r["error"]["type"], r["error"]) + r = None + else: + r = Response(s, r) + + out.append(r) + + return out diff --git a/test_elasticsearch_dsl/test_connections.py 
b/test_elasticsearch_dsl/test_connections.py index 278760cc3..5325dae89 100644 --- a/test_elasticsearch_dsl/test_connections.py +++ b/test_elasticsearch_dsl/test_connections.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch import Elasticsearch +import asyncio + +from elasticsearch import AsyncElasticsearch, Elasticsearch from pytest import raises from elasticsearch_dsl import connections, serializer @@ -81,9 +83,17 @@ def test_create_connection_constructs_client(): c.create_connection("testing", hosts=["es.com"]) con = c.get_connection("testing") + assert isinstance(c.get_connection("testing"), Elasticsearch) assert [{"host": "es.com"}] == con.transport.hosts +def test_create_connection_constructs_async_client(): + c = connections.Connections() + c.create_connection("testing", client=AsyncElasticsearch, hosts=["es.com"]) + + assert isinstance(c.get_connection("testing"), AsyncElasticsearch) + + def test_create_connection_adds_our_serializer(): c = connections.Connections() c.create_connection("testing", hosts=["es.com"]) diff --git a/test_elasticsearch_dsl/test_document.py b/test_elasticsearch_dsl/test_document.py index 5e34f0dbb..3fcb50eda 100644 --- a/test_elasticsearch_dsl/test_document.py +++ b/test_elasticsearch_dsl/test_document.py @@ -21,6 +21,7 @@ from datetime import datetime from hashlib import md5 +from elasticsearch import AsyncElasticsearch from pytest import raises from elasticsearch_dsl import ( @@ -29,11 +30,14 @@ Mapping, Range, analyzer, + connections, document, field, utils, ) +from elasticsearch_dsl._async.search import AsyncSearch from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException +from elasticsearch_dsl.search import Search class MyInner(InnerDoc): From 5017105b62f206154f115f319387a7aeffd55367 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Mon, 5 Oct 2020 17:12:31 -0600 Subject: [PATCH 02/11] Typo: self.params -> self._params --- elasticsearch_dsl/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 9fef79f11..7591a1584 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -797,7 +797,7 @@ def execute(self, ignore_cache=False, raise_on_error=True): es = get_connection(self._using) responses = es.msearch( - index=self._index, body=self.to_dict(), **self.params + index=self._index, body=self.to_dict(), **self._params ) self._response = self._process_responses( From 3f96db6c48f65ac4f6dd12c3edcea14b98f2f288 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Mon, 5 Oct 2020 17:26:00 -0600 Subject: [PATCH 03/11] Add Python version handling and "elasticsearch[async]" extra --- elasticsearch_dsl/__init__.py | 30 ++++++++++++++++++++++++++++++ setup.py | 11 +++++++++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/elasticsearch_dsl/__init__.py b/elasticsearch_dsl/__init__.py index facddf7bd..f1d9454ad 100644 --- a/elasticsearch_dsl/__init__.py +++ b/elasticsearch_dsl/__init__.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import sys + from . 
import connections from .aggs import A from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer @@ -159,3 +161,31 @@ "token_filter", "tokenizer", ] + + +try: + # Asyncio only supported in Python 3.6+ + if sys.version_info < (3, 6): + raise ImportError + + from elasticsearch_dsl._async.faceted_search import AsyncFacetedSearch + from elasticsearch_dsl._async.index import AsyncIndex, AsyncIndexTemplate + from elasticsearch_dsl._async.mapping import AsyncMapping + from elasticsearch_dsl._async.search import AsyncMultiSearch, AsyncSearch + from elasticsearch_dsl._async.update_by_query import AsyncUpdateByQuery + + __all__ = sorted( + __all__ + + [ + "AsyncFacetedSearch", + "AsyncIndex", + "AsyncIndexTemplate", + "AsyncMapping", + "AsyncMultiSearch", + "AsyncSearch", + "AsyncUpdateByQuery", + ] + ) + +except (ImportError, SyntaxError): + pass diff --git a/setup.py b/setup.py index 9815739a9..064d90222 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,10 @@ "coverage<5.0.0", ] +async_requires = [ + "aiohttp>=3,<4", +] + setup( name="elasticsearch-dsl", description="Python client for Elasticsearch", @@ -78,6 +82,9 @@ ], install_requires=install_requires, test_suite="test_elasticsearch_dsl.run_tests.run_all", - tests_require=tests_require, - extras_require={"develop": tests_require + ["sphinx", "sphinx_rtd_theme"]}, + tests_require=tests_require + async_requires, + extras_require={ + "async": async_requires, + "develop": tests_require + async_requires + ["sphinx", "sphinx_rtd_theme"], + }, ) From 3c347c1e80f0afc63830643ee50cfed94508e3bb Mon Sep 17 00:00:00 2001 From: James Brewer Date: Mon, 5 Oct 2020 17:57:35 -0600 Subject: [PATCH 04/11] Make the linter happy again --- elasticsearch_dsl/_async/index.py | 3 +-- elasticsearch_dsl/_async/search.py | 1 - elasticsearch_dsl/_async/update_by_query.py | 2 -- elasticsearch_dsl/document.py | 2 -- test_elasticsearch_dsl/test_connections.py | 2 -- test_elasticsearch_dsl/test_document.py | 4 ---- 6 files changed, 1 insertion(+), 13 deletions(-) diff --git a/elasticsearch_dsl/_async/index.py b/elasticsearch_dsl/_async/index.py index 67f052df6..880437116 100644 --- a/elasticsearch_dsl/_async/index.py +++ b/elasticsearch_dsl/_async/index.py @@ -1,7 +1,6 @@ -from elasticsearch import AsyncElasticsearch - from elasticsearch_dsl._async.utils import ensure_async_connection from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation from elasticsearch_dsl.index import Index, IndexTemplate diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py index 6d9249396..d054460aa 100644 --- a/elasticsearch_dsl/_async/search.py +++ b/elasticsearch_dsl/_async/search.py @@ -1,4 +1,3 @@ -from elasticsearch import AsyncElasticsearch from elasticsearch._async.helpers import async_scan from elasticsearch_dsl._async.utils import ensure_async_connection diff --git a/elasticsearch_dsl/_async/update_by_query.py b/elasticsearch_dsl/_async/update_by_query.py index a743119f4..b52386b21 100644 --- a/elasticsearch_dsl/_async/update_by_query.py +++ b/elasticsearch_dsl/_async/update_by_query.py @@ -1,5 +1,3 @@ -from elasticsearch import AsyncElasticsearch - from elasticsearch_dsl._async.utils import ensure_async_connection from elasticsearch_dsl.connections import get_connection from elasticsearch_dsl.update_by_query import UpdateByQuery diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document.py index 4d257a7ab..0e0374d4b 100644 --- a/elasticsearch_dsl/document.py +++ 
b/elasticsearch_dsl/document.py @@ -26,8 +26,6 @@ from elasticsearch.exceptions import NotFoundError, RequestError from six import add_metaclass, iteritems -from elasticsearch_dsl._async.utils import ensure_async_connection - from ._async.search import AsyncSearch from ._async.utils import ensure_async_connection from .connections import get_connection diff --git a/test_elasticsearch_dsl/test_connections.py b/test_elasticsearch_dsl/test_connections.py index 5325dae89..f91471c15 100644 --- a/test_elasticsearch_dsl/test_connections.py +++ b/test_elasticsearch_dsl/test_connections.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -import asyncio - from elasticsearch import AsyncElasticsearch, Elasticsearch from pytest import raises diff --git a/test_elasticsearch_dsl/test_document.py b/test_elasticsearch_dsl/test_document.py index 3fcb50eda..5e34f0dbb 100644 --- a/test_elasticsearch_dsl/test_document.py +++ b/test_elasticsearch_dsl/test_document.py @@ -21,7 +21,6 @@ from datetime import datetime from hashlib import md5 -from elasticsearch import AsyncElasticsearch from pytest import raises from elasticsearch_dsl import ( @@ -30,14 +29,11 @@ Mapping, Range, analyzer, - connections, document, field, utils, ) -from elasticsearch_dsl._async.search import AsyncSearch from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException -from elasticsearch_dsl.search import Search class MyInner(InnerDoc): From d7dae5ac5848e20e1df5c95397a2e934f3ff6171 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Fri, 9 Oct 2020 07:12:54 -0600 Subject: [PATCH 05/11] Move async/await into another file to prevent syntax errors --- elasticsearch_dsl/_async/document.py | 188 ++++++++++++++++++ elasticsearch_dsl/document.py | 217 +++------------------ setup.py | 2 +- test_elasticsearch_dsl/test_connections.py | 10 +- 4 files changed, 221 insertions(+), 196 deletions(-) create mode 100644 elasticsearch_dsl/_async/document.py diff --git a/elasticsearch_dsl/_async/document.py b/elasticsearch_dsl/_async/document.py new file mode 100644 index 000000000..a315ccec8 --- /dev/null +++ b/elasticsearch_dsl/_async/document.py @@ -0,0 +1,188 @@ +from elasticsearch_dsl._async.utils import ensure_async_connection +from elasticsearch_dsl.document import Document + + +class AsyncDocument(Document): + @classmethod + async def get_async(cls, id, using=None, index=None, **kwargs): + """ + Asynchronously retrieves a single document from elasticsearch using its ``id``. + + :arg id: ``id`` of the document to be retrieved + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.get`` unchanged. + """ + es = cls._get_connection(using) + ensure_async_connection(es, "Document.get_async") + + doc = await es.get(index=cls._default_index(index), id=id, **kwargs) + if not doc.get("found", False): + return None + return cls.from_es(doc) + + @classmethod + async def init_async(cls, index=None, using=None): + """ + Asynchronously creates the index and populates the mappings in Elasticsearch. + """ + i = cls._index + if index: + i = i.clone(name=index) + await i.save(using=using) + + @classmethod + async def mget_async( + cls, docs, using=None, index=None, raise_on_error=True, missing="none", **kwargs + ): + r""" + Asynchronously retrieves multiple documents by their ``id``\s. 
Returns a list + of instances in the same order as requested. + + :arg docs: list of ``id``\s of the documents to be retrieved or a list + of document specifications as per + https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-multi-get.html + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg missing: what to do when one of the documents requested is not + found. Valid options are ``'none'`` (use ``None``), ``'raise'`` (raise + ``NotFoundError``) or ``'skip'`` (ignore the missing document). + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.mget`` unchanged. + """ + if missing not in ("raise", "skip", "none"): + raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") + + es = cls._get_connection(using) + ensure_async_connection(es, "Document.mget_async") + + results = await es.mget( + cls._build_mget_body(docs), + index=cls._default_index(index), + **kwargs, + ) + + return cls._parse_mget_results( + results, + missing=missing, + raise_on_error=raise_on_error, + ) + + async def delete_async(self, using=None, index=None, **kwargs): + """ + Asynchronously deletes the instance in Elasticsearch. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.delete`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Document.delete_async") + doc_meta = self._build_delete_doc_meta(**kwargs) + await es.delete(index=self._get_index(index), **doc_meta) + + async def save_async( + self, using=None, index=None, validate=True, skip_empty=True, **kwargs + ): + """ + Asyncrhonously saves the document into Elasticsearch. If the document doesn't + exist it is created, otherwise it is overwritten. Returns ``True`` if this + operation resulted in new document being created. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg validate: set to ``False`` to skip validating the document + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. + + Any additional keyword arguments will be passed to + ``AsyncElasticsearch.index`` unchanged. + + :return operation result created/updated + """ + if validate: + self.full_clean() + + es = self._get_connection(using) + ensure_async_connection(es, "Document.save_async") + + doc_meta = self._build_save_doc_meta(**kwargs) + meta = await es.index( + index=self._get_index(index), + body=self.to_dict(skip_empty=skip_empty), + **doc_meta, + ) + self._update_doc_meta(meta) + + return meta["result"] + + async def update_async( + self, + using=None, + index=None, + detect_noop=True, + doc_as_upsert=False, + refresh=False, + retry_on_conflict=None, + script=None, + script_id=None, + scripted_upsert=False, + upsert=None, + **fields + ): + """ + Asynchronously performs a partial update of the document using the provided + fields. 
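Once the async methods live on their own class, usage looks like the sketch below; the index, field names, and Painless script are illustrative, and an ``AsyncElasticsearch`` connection is assumed to be registered:

    from elasticsearch_dsl import Integer, Text
    from elasticsearch_dsl._async.document import AsyncDocument

    class Article(AsyncDocument):
        title = Text()
        views = Integer()

        class Index:
            name = "articles"

    async def bump_views(article_id):
        article = await Article.get_async(id=article_id)
        # Extra keyword arguments become script params when ``script`` is given.
        await article.update_async(script="ctx._source.views += params.n", n=1)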
+ + doc = MyDocument(title='Document Title!') + doc.save() + doc.update(title='New Document Title!') + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg detect_noop: Set to ``False`` to disable noop detection. + :arg refresh: Control when the changes made by this request are visible + to search. Set to ``True`` for immediate effect. + :arg retry_on_conflict: In between the get and indexing phases of the + update, it is possible that another process might have already + updated the same document. By default, the update will fail with a + version conflict exception. The retry_on_conflict parameter + controls how many times to retry the update before finally throwing + an exception. + :arg doc_as_upsert: Instead of sending a partial doc plus an upsert + doc, setting doc_as_upsert to true will use the contents of doc as + the upsert value + + :return operation result noop/updated + """ + body, doc_meta = self._build_update_body_and_meta( + detect_noop=detect_noop, + doc_as_upsert=doc_as_upsert, + retry_on_conflict=retry_on_conflict, + script=script, + script_id=script_id, + scripted_upsert=scripted_upsert, + upsert=upsert, + **fields, + ) + + es = self._get_connection(using) + ensure_async_connection(es, "Document.update_async") + + meta = await es.update( + index=self._get_index(index), body=body, refresh=refresh, **doc_meta + ) + self._update_doc_meta(meta) + + return meta["result"] diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document.py index 0e0374d4b..cd14a8526 100644 --- a/elasticsearch_dsl/document.py +++ b/elasticsearch_dsl/document.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import sys + try: import collections.abc as collections_abc # only works on python 3.3+ except ImportError: @@ -22,12 +24,9 @@ from fnmatch import fnmatch -from elasticsearch import AsyncElasticsearch from elasticsearch.exceptions import NotFoundError, RequestError from six import add_metaclass, iteritems -from ._async.search import AsyncSearch -from ._async.utils import ensure_async_connection from .connections import get_connection from .exceptions import IllegalOperation, ValidationException from .field import Field @@ -36,6 +35,19 @@ from .search import Search from .utils import DOC_META_FIELDS, META_FIELDS, ObjectBase, merge +try: + from elasticsearch import AsyncElasticsearch + + from elasticsearch_dsl._async.search import AsyncSearch +except ImportError: + # Async is not support for one of two reasons: + # + # 1. The Python version is less than 3.6, so elasticsearch-py doesn't expose + # it's async features. + # 2. The aiohttp package isn't installed, so elasticsearch-py doesn't expose + # it's async features. 
+ pass + class MetaField(object): def __init__(self, *args, **kwargs): @@ -191,8 +203,11 @@ def search(cls, using=None, index=None): "index": cls._default_index(index), } - if isinstance(es, AsyncElasticsearch): - return AsyncSearch(using=es, **kwargs) + try: + if isinstance(es, AsyncElasticsearch): + return AsyncSearch(using=es, **kwargs) + except NameError: + pass return Search(using=es, **kwargs) @@ -242,9 +257,7 @@ def mget( es = cls._get_connection(using) results = es.mget( - cls._build_mget_body(docs), - index=cls._default_index(index), - **kwargs, + cls._build_mget_body(docs), index=cls._default_index(index), **kwargs ) return cls._parse_mget_results( @@ -342,7 +355,7 @@ def update( script_id=script_id, scripted_upsert=scripted_upsert, upsert=upsert, - **fields, + **fields ) meta = self._get_connection(using).update( @@ -379,191 +392,7 @@ def save(self, using=None, index=None, validate=True, skip_empty=True, **kwargs) meta = es.index( index=self._get_index(index), body=self.to_dict(skip_empty=skip_empty), - **doc_meta, - ) - self._update_doc_meta(meta) - - return meta["result"] - - @classmethod - async def get_async(cls, id, using=None, index=None, **kwargs): - """ - Asynchronously retrieves a single document from elasticsearch using its ``id``. - - :arg id: ``id`` of the document to be retrieved - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.get`` unchanged. - """ - es = cls._get_connection(using) - ensure_async_connection(es, "Document.get_async") - - doc = await es.get(index=cls._default_index(index), id=id, **kwargs) - if not doc.get("found", False): - return None - return cls.from_es(doc) - - @classmethod - async def init_async(cls, index=None, using=None): - """ - Asynchronously creates the index and populates the mappings in Elasticsearch. - """ - i = cls._index - if index: - i = i.clone(name=index) - await i.save(using=using) - - @classmethod - async def mget_async( - cls, docs, using=None, index=None, raise_on_error=True, missing="none", **kwargs - ): - r""" - Asynchronously retrieves multiple documents by their ``id``\s. Returns a list - of instances in the same order as requested. - - :arg docs: list of ``id``\s of the documents to be retrieved or a list - of document specifications as per - https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-multi-get.html - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - :arg missing: what to do when one of the documents requested is not - found. Valid options are ``'none'`` (use ``None``), ``'raise'`` (raise - ``NotFoundError``) or ``'skip'`` (ignore the missing document). - - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.mget`` unchanged. 
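As an aside, the client-class dispatch above means callers opt into async behaviour purely through the connection they register; an illustrative sketch (the ``Post`` document class and host are made up)::

    from elasticsearch import AsyncElasticsearch
    from elasticsearch_dsl import connections

    # register an async client as the default connection
    connections.create_connection(client=AsyncElasticsearch, hosts=["localhost"])

    s = Post.search()      # returns an AsyncSearch because the client is async
    response = await s.execute()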
- """ - if missing not in ("raise", "skip", "none"): - raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") - - es = cls._get_connection(using) - ensure_async_connection(es, "Document.mget_async") - - results = await es.mget( - cls._build_mget_body(docs), - index=cls._default_index(index), - **kwargs, - ) - - return cls._parse_mget_results( - results, - missing=missing, - raise_on_error=raise_on_error, - ) - - async def delete_async(self, using=None, index=None, **kwargs): - """ - Asynchronously deletes the instance in Elasticsearch. - - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.delete`` unchanged. - """ - es = self._get_connection(using) - ensure_async_connection(es, "Document.delete_async") - doc_meta = self._build_delete_doc_meta(**kwargs) - await es.delete(index=self._get_index(index), **doc_meta) - - async def save_async( - self, using=None, index=None, validate=True, skip_empty=True, **kwargs - ): - """ - Asyncrhonously saves the document into Elasticsearch. If the document doesn't - exist it is created, otherwise it is overwritten. Returns ``True`` if this - operation resulted in new document being created. - - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - :arg validate: set to ``False`` to skip validating the document - :arg skip_empty: if set to ``False`` will cause empty values (``None``, - ``[]``, ``{}``) to be left on the document. Those values will be - stripped out otherwise as they make no difference in elasticsearch. - - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.index`` unchanged. - - :return operation result created/updated - """ - if validate: - self.full_clean() - - es = self._get_connection(using) - ensure_async_connection(es, "Document.save_async") - - doc_meta = self._build_save_doc_meta(**kwargs) - meta = await es.index( - index=self._get_index(index), - body=self.to_dict(skip_empty=skip_empty), - **doc_meta, - ) - self._update_doc_meta(meta) - - return meta["result"] - - async def update_async( - self, - using=None, - index=None, - detect_noop=True, - doc_as_upsert=False, - refresh=False, - retry_on_conflict=None, - script=None, - script_id=None, - scripted_upsert=False, - upsert=None, - **fields - ): - """ - Asynchronously performs a partial update of the document using the provided - fields. - - doc = MyDocument(title='Document Title!') - doc.save() - doc.update(title='New Document Title!') - - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - :arg detect_noop: Set to ``False`` to disable noop detection. - :arg refresh: Control when the changes made by this request are visible - to search. Set to ``True`` for immediate effect. - :arg retry_on_conflict: In between the get and indexing phases of the - update, it is possible that another process might have already - updated the same document. By default, the update will fail with a - version conflict exception. The retry_on_conflict parameter - controls how many times to retry the update before finally throwing - an exception. 
- :arg doc_as_upsert: Instead of sending a partial doc plus an upsert - doc, setting doc_as_upsert to true will use the contents of doc as - the upsert value - - :return operation result noop/updated - """ - body, doc_meta = self._build_update_body_and_meta( - detect_noop=detect_noop, - doc_as_upsert=doc_as_upsert, - retry_on_conflict=retry_on_conflict, - script=script, - script_id=script_id, - scripted_upsert=scripted_upsert, - upsert=upsert, - **fields, - ) - - es = self._get_connection(using) - ensure_async_connection(es, "Document.update_async") - - meta = await es.update( - index=self._get_index(index), body=body, refresh=refresh, **doc_meta + **doc_meta ) self._update_doc_meta(meta) diff --git a/setup.py b/setup.py index 064d90222..92fc7e73e 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ ] async_requires = [ - "aiohttp>=3,<4", + 'aiohttp>=3,<4; python_version>="3.6"', ] setup( diff --git a/test_elasticsearch_dsl/test_connections.py b/test_elasticsearch_dsl/test_connections.py index f91471c15..1db54bbde 100644 --- a/test_elasticsearch_dsl/test_connections.py +++ b/test_elasticsearch_dsl/test_connections.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch import AsyncElasticsearch, Elasticsearch +import sys + +import pytest +from elasticsearch import Elasticsearch from pytest import raises from elasticsearch_dsl import connections, serializer @@ -85,7 +88,12 @@ def test_create_connection_constructs_client(): assert [{"host": "es.com"}] == con.transport.hosts +@pytest.mark.skipif( + sys.version_info < (3, 6), reason="Async features require Python 3.6 or higher" +) def test_create_connection_constructs_async_client(): + from elasticsearch import AsyncElasticsearch + c = connections.Connections() c.create_connection("testing", client=AsyncElasticsearch, hosts=["es.com"]) From 00784c9513536ba34044c8d225f41ae6e8049c67 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Fri, 9 Oct 2020 07:50:22 -0600 Subject: [PATCH 06/11] Fix flake8 errors --- elasticsearch_dsl/__init__.py | 9 +++++++++ elasticsearch_dsl/document.py | 2 -- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/elasticsearch_dsl/__init__.py b/elasticsearch_dsl/__init__.py index f1d9454ad..74a3c59c8 100644 --- a/elasticsearch_dsl/__init__.py +++ b/elasticsearch_dsl/__init__.py @@ -15,6 +15,13 @@ # specific language governing permissions and limitations # under the License. +# flake8: noqa:F401 +# +# The dynamic sorting of `__all__` at the bottom of the file breaks flake8 +# because flake8 is a static analysis tool. The alternative to ignoring these +# "unused import" errors would be to duplicate `__all__` statically after the +# async files are imported. + import sys from . import connections @@ -168,6 +175,7 @@ if sys.version_info < (3, 6): raise ImportError + from elasticsearch_dsl._async.document import AsyncDocument from elasticsearch_dsl._async.faceted_search import AsyncFacetedSearch from elasticsearch_dsl._async.index import AsyncIndex, AsyncIndexTemplate from elasticsearch_dsl._async.mapping import AsyncMapping @@ -177,6 +185,7 @@ __all__ = sorted( __all__ + [ + "AsyncDocument", "AsyncFacetedSearch", "AsyncIndex", "AsyncIndexTemplate", diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document.py index cd14a8526..aa7aba1f4 100644 --- a/elasticsearch_dsl/document.py +++ b/elasticsearch_dsl/document.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. 
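For orientation, the environment marker added to ``async_requires`` makes pip skip aiohttp entirely on interpreters older than 3.6; wiring it up as an optional extra typically looks something like this (the extra's name is an assumption, it is not shown in the patch)::

    # setup.py (sketch)
    extras_require={
        "async": ['aiohttp>=3,<4; python_version>="3.6"'],
    },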
-import sys - try: import collections.abc as collections_abc # only works on python 3.3+ except ImportError: From 17a134278f23bd04345f31d59af825de06db8755 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Fri, 9 Oct 2020 07:56:13 -0600 Subject: [PATCH 07/11] Add license headers to _async directory files --- elasticsearch_dsl/_async/document.py | 17 +++++++++++++++++ elasticsearch_dsl/_async/faceted_search.py | 17 +++++++++++++++++ elasticsearch_dsl/_async/index.py | 17 +++++++++++++++++ elasticsearch_dsl/_async/mapping.py | 17 +++++++++++++++++ elasticsearch_dsl/_async/search.py | 17 +++++++++++++++++ elasticsearch_dsl/_async/update_by_query.py | 17 +++++++++++++++++ elasticsearch_dsl/_async/utils.py | 17 +++++++++++++++++ 7 files changed, 119 insertions(+) diff --git a/elasticsearch_dsl/_async/document.py b/elasticsearch_dsl/_async/document.py index a315ccec8..f2821eb5b 100644 --- a/elasticsearch_dsl/_async/document.py +++ b/elasticsearch_dsl/_async/document.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch_dsl._async.utils import ensure_async_connection from elasticsearch_dsl.document import Document diff --git a/elasticsearch_dsl/_async/faceted_search.py b/elasticsearch_dsl/_async/faceted_search.py index caaaa8b91..6e5d4bcf6 100644 --- a/elasticsearch_dsl/_async/faceted_search.py +++ b/elasticsearch_dsl/_async/faceted_search.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch_dsl.faceted_search import FacetedSearch diff --git a/elasticsearch_dsl/_async/index.py b/elasticsearch_dsl/_async/index.py index 880437116..f99517db5 100644 --- a/elasticsearch_dsl/_async/index.py +++ b/elasticsearch_dsl/_async/index.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch_dsl._async.utils import ensure_async_connection from elasticsearch_dsl.connections import get_connection from elasticsearch_dsl.exceptions import IllegalOperation diff --git a/elasticsearch_dsl/_async/mapping.py b/elasticsearch_dsl/_async/mapping.py index a9ce2f1bb..df9fc7818 100644 --- a/elasticsearch_dsl/_async/mapping.py +++ b/elasticsearch_dsl/_async/mapping.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch_dsl._async.utils import ensure_async_connection from elasticsearch_dsl.connections import get_connection from elasticsearch_dsl.mapping import Mapping diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py index d054460aa..b5c1cc170 100644 --- a/elasticsearch_dsl/_async/search.py +++ b/elasticsearch_dsl/_async/search.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch._async.helpers import async_scan from elasticsearch_dsl._async.utils import ensure_async_connection diff --git a/elasticsearch_dsl/_async/update_by_query.py b/elasticsearch_dsl/_async/update_by_query.py index b52386b21..7a035e75a 100644 --- a/elasticsearch_dsl/_async/update_by_query.py +++ b/elasticsearch_dsl/_async/update_by_query.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch_dsl._async.utils import ensure_async_connection from elasticsearch_dsl.connections import get_connection from elasticsearch_dsl.update_by_query import UpdateByQuery diff --git a/elasticsearch_dsl/_async/utils.py b/elasticsearch_dsl/_async/utils.py index e9abe4e1f..6b91de6c9 100644 --- a/elasticsearch_dsl/_async/utils.py +++ b/elasticsearch_dsl/_async/utils.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch import AsyncElasticsearch From eab87e0f4fe476cfeca8f2f57857dd9a3e296c28 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Fri, 9 Oct 2020 08:01:23 -0600 Subject: [PATCH 08/11] Move changes to __all__ out of try/except block --- elasticsearch_dsl/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/elasticsearch_dsl/__init__.py b/elasticsearch_dsl/__init__.py index 74a3c59c8..36f2f3c53 100644 --- a/elasticsearch_dsl/__init__.py +++ b/elasticsearch_dsl/__init__.py @@ -181,7 +181,9 @@ from elasticsearch_dsl._async.mapping import AsyncMapping from elasticsearch_dsl._async.search import AsyncMultiSearch, AsyncSearch from elasticsearch_dsl._async.update_by_query import AsyncUpdateByQuery - +except (ImportError, SyntaxError): + pass +else: __all__ = sorted( __all__ + [ @@ -195,6 +197,3 @@ "AsyncUpdateByQuery", ] ) - -except (ImportError, SyntaxError): - pass From 114d1b415f8c29502f168bac048267e5cbab5458 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Fri, 23 Oct 2020 09:31:42 -0600 Subject: [PATCH 09/11] Refactor to use unasync package --- elasticsearch_dsl/__init__.py | 38 - elasticsearch_dsl/_async/__init__.py | 0 elasticsearch_dsl/_async/document.py | 430 +++++++++-- elasticsearch_dsl/_async/faceted_search.py | 404 +++++++++- elasticsearch_dsl/_async/index.py | 728 +++++++++++------- elasticsearch_dsl/_async/mapping.py | 212 +++++- elasticsearch_dsl/_async/search.py | 778 +++++++++++++++++++- elasticsearch_dsl/_async/update_by_query.py | 135 +++- elasticsearch_dsl/document.py | 280 +++---- elasticsearch_dsl/faceted_search.py | 10 +- elasticsearch_dsl/index.py | 213 ++++-- elasticsearch_dsl/mapping.py | 11 +- elasticsearch_dsl/search.py | 64 +- elasticsearch_dsl/update_by_query.py | 15 +- elasticsearch_dsl/utils.py 
| 9 + setup.py | 4 +- utils/generate-sync.py | 42 ++ 17 files changed, 2690 insertions(+), 683 deletions(-) create mode 100644 elasticsearch_dsl/_async/__init__.py create mode 100644 utils/generate-sync.py diff --git a/elasticsearch_dsl/__init__.py b/elasticsearch_dsl/__init__.py index 36f2f3c53..facddf7bd 100644 --- a/elasticsearch_dsl/__init__.py +++ b/elasticsearch_dsl/__init__.py @@ -15,15 +15,6 @@ # specific language governing permissions and limitations # under the License. -# flake8: noqa:F401 -# -# The dynamic sorting of `__all__` at the bottom of the file breaks flake8 -# because flake8 is a static analysis tool. The alternative to ignoring these -# "unused import" errors would be to duplicate `__all__` statically after the -# async files are imported. - -import sys - from . import connections from .aggs import A from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer @@ -168,32 +159,3 @@ "token_filter", "tokenizer", ] - - -try: - # Asyncio only supported in Python 3.6+ - if sys.version_info < (3, 6): - raise ImportError - - from elasticsearch_dsl._async.document import AsyncDocument - from elasticsearch_dsl._async.faceted_search import AsyncFacetedSearch - from elasticsearch_dsl._async.index import AsyncIndex, AsyncIndexTemplate - from elasticsearch_dsl._async.mapping import AsyncMapping - from elasticsearch_dsl._async.search import AsyncMultiSearch, AsyncSearch - from elasticsearch_dsl._async.update_by_query import AsyncUpdateByQuery -except (ImportError, SyntaxError): - pass -else: - __all__ = sorted( - __all__ - + [ - "AsyncDocument", - "AsyncFacetedSearch", - "AsyncIndex", - "AsyncIndexTemplate", - "AsyncMapping", - "AsyncMultiSearch", - "AsyncSearch", - "AsyncUpdateByQuery", - ] - ) diff --git a/elasticsearch_dsl/_async/__init__.py b/elasticsearch_dsl/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/elasticsearch_dsl/_async/document.py b/elasticsearch_dsl/_async/document.py index f2821eb5b..77cb0d68a 100644 --- a/elasticsearch_dsl/_async/document.py +++ b/elasticsearch_dsl/_async/document.py @@ -15,15 +15,183 @@ # specific language governing permissions and limitations # under the License. 
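The body of ``utils/generate-sync.py`` is not included in this hunk; as a rough sketch only, driving the ``unasync`` package to produce the sync modules from the ``_async`` sources generally looks like the following (directory names and replacements here are assumptions)::

    import os

    import unasync

    rules = [
        unasync.Rule(
            fromdir="/elasticsearch_dsl/_async/",
            todir="/elasticsearch_dsl/",
            additional_replacements={"AsyncElasticsearch": "Elasticsearch"},
        )
    ]

    # collect every Python file under the async package and rewrite it
    filepaths = []
    for root, _, filenames in os.walk("elasticsearch_dsl/_async"):
        filepaths.extend(
            os.path.join(root, f) for f in filenames if f.endswith(".py")
        )

    unasync.unasync_files(filepaths, rules)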
-from elasticsearch_dsl._async.utils import ensure_async_connection -from elasticsearch_dsl.document import Document +try: + import collections.abc as collections_abc # only works on python 3.3+ +except ImportError: + import collections as collections_abc +from fnmatch import fnmatch + +from elasticsearch.exceptions import NotFoundError, RequestError +from six import add_metaclass, iteritems + +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException +from elasticsearch_dsl.field import Field +from elasticsearch_dsl.index import Index +from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.utils import DOC_META_FIELDS, META_FIELDS, ObjectBase, merge + +from .utils import ensure_async_connection + + +class MetaField(object): + def __init__(self, *args, **kwargs): + self.args, self.kwargs = args, kwargs + + +class DocumentMeta(type): + def __new__(cls, name, bases, attrs): + # DocumentMeta filters attrs in place + attrs["_doc_type"] = DocumentOptions(name, bases, attrs) + return super(DocumentMeta, cls).__new__(cls, name, bases, attrs) + + +class IndexMeta(DocumentMeta): + # global flag to guard us from associating an Index with the base Document + # class, only user defined subclasses should have an _index attr + _document_initialized = False + + def __new__(cls, name, bases, attrs): + new_cls = super(IndexMeta, cls).__new__(cls, name, bases, attrs) + if cls._document_initialized: + index_opts = attrs.pop("Index", None) + index = cls.construct_index(index_opts, bases) + new_cls._index = index + index.document(new_cls) + cls._document_initialized = True + return new_cls + + @classmethod + def construct_index(cls, opts, bases): + if opts is None: + for b in bases: + if hasattr(b, "_index"): + return b._index + + # Set None as Index name so it will set _all while making the query + return Index(name=None) + + i = Index(getattr(opts, "name", "*"), using=getattr(opts, "using", "default")) + i.settings(**getattr(opts, "settings", {})) + i.aliases(**getattr(opts, "aliases", {})) + for a in getattr(opts, "analyzers", ()): + i.analyzer(a) + return i + + +class DocumentOptions(object): + def __init__(self, name, bases, attrs): + meta = attrs.pop("Meta", None) + + # create the mapping instance + self.mapping = getattr(meta, "mapping", Mapping()) + + # register all declared fields into the mapping + for name, value in list(iteritems(attrs)): + if isinstance(value, Field): + self.mapping.field(name, value) + del attrs[name] + + # add all the mappings for meta fields + for name in dir(meta): + if isinstance(getattr(meta, name, None), MetaField): + params = getattr(meta, name) + self.mapping.meta(name, *params.args, **params.kwargs) + + # document inheritance - include the fields from parents' mappings + for b in bases: + if hasattr(b, "_doc_type") and hasattr(b._doc_type, "mapping"): + self.mapping.update(b._doc_type.mapping, update_only=True) + + @property + def name(self): + return self.mapping.properties.name + + +@add_metaclass(DocumentMeta) +class InnerDoc(ObjectBase): + """ + Common class for inner documents like Object or Nested + """ + + @classmethod + def from_es(cls, data, data_only=False): + if data_only: + data = {"_source": data} + return super(InnerDoc, cls).from_es(data) + + +@add_metaclass(IndexMeta) +class Document(ObjectBase): + """ + Model-like class for persisting documents in elasticsearch. 
+ """ + + @classmethod + def _matches(cls, hit): + if cls._index._name is None: + return True + return fnmatch(hit.get("_index", ""), cls._index._name) + + @classmethod + def _get_using(cls, using=None): + return using or cls._index._using -class AsyncDocument(Document): @classmethod - async def get_async(cls, id, using=None, index=None, **kwargs): + def _get_connection(cls, using=None): + return get_connection(cls._get_using(using)) + + @classmethod + def _default_index(cls, index=None): + return index or cls._index._name + + @classmethod + async def init(cls, index=None, using=None): + """ + Create the index and populate the mappings in elasticsearch. + """ + i = cls._index + if index: + i = i.clone(name=index) + await i.save(using=using) + + def _get_index(self, index=None, required=True): + if index is None: + index = getattr(self.meta, "index", None) + if index is None: + index = getattr(self._index, "_name", None) + if index is None and required: + raise ValidationException("No index") + if index and "*" in index: + raise ValidationException("You cannot write to a wildcard index.") + return index + + def __repr__(self): + return "{}({})".format( + self.__class__.__name__, + ", ".join( + "{}={!r}".format(key, getattr(self.meta, key)) + for key in ("index", "id") + if key in self.meta + ), + ) + + @classmethod + def search(cls, using=None, index=None): """ - Asynchronously retrieves a single document from elasticsearch using its ``id``. + Create an :class:`~elasticsearch_dsl.Search` instance that will search + over this ``Document``. + """ + return Search( + using=cls._get_using(using), index=cls._default_index(index), doc_type=[cls] + ) + + @classmethod + async def get(cls, id, using=None, index=None, **kwargs): + """ + Retrieve a single document from elasticsearch using its ``id``. :arg id: ``id`` of the document to be retrieved :arg index: elasticsearch index to use, if the ``Document`` is @@ -31,10 +199,10 @@ async def get_async(cls, id, using=None, index=None, **kwargs): :arg using: connection alias to use, defaults to ``'default'`` Any additional keyword arguments will be passed to - ``AsyncElasticsearch.get`` unchanged. + ``Elasticsearch.get`` unchanged. """ es = cls._get_connection(using) - ensure_async_connection(es, "Document.get_async") + ensure_async_connection(es, "Document.get") doc = await es.get(index=cls._default_index(index), id=id, **kwargs) if not doc.get("found", False): @@ -42,22 +210,12 @@ async def get_async(cls, id, using=None, index=None, **kwargs): return cls.from_es(doc) @classmethod - async def init_async(cls, index=None, using=None): - """ - Asynchronously creates the index and populates the mappings in Elasticsearch. - """ - i = cls._index - if index: - i = i.clone(name=index) - await i.save(using=using) - - @classmethod - async def mget_async( + async def mget( cls, docs, using=None, index=None, raise_on_error=True, missing="none", **kwargs ): r""" - Asynchronously retrieves multiple documents by their ``id``\s. Returns a list - of instances in the same order as requested. + Retrieve multiple document by their ``id``\s. Returns a list of instances + in the same order as requested. :arg docs: list of ``id``\s of the documents to be retrieved or a list of document specifications as per @@ -70,80 +228,106 @@ async def mget_async( ``NotFoundError``) or ``'skip'`` (ignore the missing document). Any additional keyword arguments will be passed to - ``AsyncElasticsearch.mget`` unchanged. + ``Elasticsearch.mget`` unchanged. 
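By way of illustration (not part of the patch), the async variants keep the sync call shape and only add ``await``; the ``Post`` subclass and ids below are made up, and the registered connection must be an ``AsyncElasticsearch``::

    post = await Post.get(id=42)
    posts = await Post.mget([42, 47], missing="skip")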
""" if missing not in ("raise", "skip", "none"): raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") - + es = cls._get_connection(using) - ensure_async_connection(es, "Document.mget_async") + ensure_async_connection(es, "Document.mget") - results = await es.mget( - cls._build_mget_body(docs), - index=cls._default_index(index), - **kwargs, - ) + body = { + "docs": [ + doc if isinstance(doc, collections_abc.Mapping) else {"_id": doc} + for doc in docs + ] + } + results = await es.mget(body, index=cls._default_index(index), **kwargs) - return cls._parse_mget_results( - results, - missing=missing, - raise_on_error=raise_on_error, - ) + objs, error_docs, missing_docs = [], [], [] + for doc in results["docs"]: + if doc.get("found"): + if error_docs or missing_docs: + # We're going to raise an exception anyway, so avoid an + # expensive call to cls.from_es(). + continue + + objs.append(cls.from_es(doc)) - async def delete_async(self, using=None, index=None, **kwargs): + elif doc.get("error"): + if raise_on_error: + error_docs.append(doc) + if missing == "none": + objs.append(None) + + # The doc didn't cause an error, but the doc also wasn't found. + elif missing == "raise": + missing_docs.append(doc) + elif missing == "none": + objs.append(None) + + if error_docs: + error_ids = [doc["_id"] for doc in error_docs] + message = "Required routing not provided for documents %s." + message %= ", ".join(error_ids) + raise RequestError(400, message, error_docs) + if missing_docs: + missing_ids = [doc["_id"] for doc in missing_docs] + message = "Documents %s not found." % ", ".join(missing_ids) + raise NotFoundError(404, message, {"docs": missing_docs}) + return objs + + async def delete(self, using=None, index=None, **kwargs): """ - Asynchronously deletes the instance in Elasticsearch. + Delete the instance in elasticsearch. :arg index: elasticsearch index to use, if the ``Document`` is associated with an index this can be omitted. :arg using: connection alias to use, defaults to ``'default'`` Any additional keyword arguments will be passed to - ``AsyncElasticsearch.delete`` unchanged. + ``Elasticsearch.delete`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "Document.delete_async") - doc_meta = self._build_delete_doc_meta(**kwargs) + ensure_async_connection(es, "Document.delete") + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) await es.delete(index=self._get_index(index), **doc_meta) - async def save_async( - self, using=None, index=None, validate=True, skip_empty=True, **kwargs - ): + def to_dict(self, include_meta=False, skip_empty=True): """ - Asyncrhonously saves the document into Elasticsearch. If the document doesn't - exist it is created, otherwise it is overwritten. Returns ``True`` if this - operation resulted in new document being created. + Serialize the instance into a dictionary so that it can be saved in elasticsearch. - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. - :arg using: connection alias to use, defaults to ``'default'`` - :arg validate: set to ``False`` to skip validating the document + :arg include_meta: if set to ``True`` will include all the metadata + (``_index``, ``_id`` etc). 
Otherwise just the document's + data is serialized. This is useful when passing multiple instances into + ``elasticsearch.helpers.bulk``. :arg skip_empty: if set to ``False`` will cause empty values (``None``, ``[]``, ``{}``) to be left on the document. Those values will be stripped out otherwise as they make no difference in elasticsearch. - - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.index`` unchanged. - - :return operation result created/updated """ - if validate: - self.full_clean() + d = super(Document, self).to_dict(skip_empty=skip_empty) + if not include_meta: + return d - es = self._get_connection(using) - ensure_async_connection(es, "Document.save_async") + meta = {"_" + k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - doc_meta = self._build_save_doc_meta(**kwargs) - meta = await es.index( - index=self._get_index(index), - body=self.to_dict(skip_empty=skip_empty), - **doc_meta, - ) - self._update_doc_meta(meta) + # in case of to_dict include the index unlike save/update/delete + index = self._get_index(required=False) + if index is not None: + meta["_index"] = index - return meta["result"] + meta["_source"] = d + return meta - async def update_async( + async def update( self, using=None, index=None, @@ -158,8 +342,8 @@ async def update_async( **fields ): """ - Asynchronously performs a partial update of the document using the provided - fields. + Partial update of the document, specify fields you wish to update and + both the instance and the document in elasticsearch will be updated:: doc = MyDocument(title='Document Title!') doc.save() @@ -183,23 +367,109 @@ async def update_async( :return operation result noop/updated """ - body, doc_meta = self._build_update_body_and_meta( - detect_noop=detect_noop, - doc_as_upsert=doc_as_upsert, - retry_on_conflict=retry_on_conflict, - script=script, - script_id=script_id, - scripted_upsert=scripted_upsert, - upsert=upsert, - **fields, - ) + body = { + "doc_as_upsert": doc_as_upsert, + "detect_noop": detect_noop, + } + + # scripted update + if script or script_id: + if upsert is not None: + body["upsert"] = upsert + + if script: + script = {"source": script} + else: + script = {"id": script_id} + + script["params"] = fields + + body["script"] = script + body["scripted_upsert"] = scripted_upsert + + # partial document update + else: + if not fields: + raise IllegalOperation( + "You cannot call update() without updating individual fields or a script. " + "If you wish to update the entire object use save()." 
+ ) + + # update given fields locally + merge(self, fields) + + # prepare data for ES + values = self.to_dict() + + # if fields were given: partial update + body["doc"] = {k: values.get(k) for k in fields.keys()} + + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + if retry_on_conflict is not None: + doc_meta["retry_on_conflict"] = retry_on_conflict + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] es = self._get_connection(using) - ensure_async_connection(es, "Document.update_async") + ensure_async_connection(es, "Document.update") meta = await es.update( index=self._get_index(index), body=body, refresh=refresh, **doc_meta ) - self._update_doc_meta(meta) + # update meta information from ES + for k in META_FIELDS: + if "_" + k in meta: + setattr(self.meta, k, meta["_" + k]) + + return meta["result"] + + async def save(self, using=None, index=None, validate=True, skip_empty=True, **kwargs): + """ + Save the document into elasticsearch. If the document doesn't exist it + is created, it is overwritten otherwise. Returns ``True`` if this + operations resulted in new document being created. + + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. + :arg using: connection alias to use, defaults to ``'default'`` + :arg validate: set to ``False`` to skip validating the document + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.index`` unchanged. + + :return operation result created/updated + """ + if validate: + self.full_clean() + + es = self._get_connection(using) + ensure_async_connection(es, "Document.save") + + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) + meta = await es.index( + index=self._get_index(index), + body=self.to_dict(skip_empty=skip_empty), + **doc_meta + ) + # update meta information from ES + for k in META_FIELDS: + if "_" + k in meta: + setattr(self.meta, k, meta["_" + k]) return meta["result"] diff --git a/elasticsearch_dsl/_async/faceted_search.py b/elasticsearch_dsl/_async/faceted_search.py index 6e5d4bcf6..535613832 100644 --- a/elasticsearch_dsl/_async/faceted_search.py +++ b/elasticsearch_dsl/_async/faceted_search.py @@ -15,13 +15,411 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch_dsl.faceted_search import FacetedSearch +from datetime import datetime, timedelta +from six import iteritems, itervalues + +from elasticsearch_dsl.aggs import A +from elasticsearch_dsl.query import MatchAll, Nested, Range, Terms +from elasticsearch_dsl.response import Response +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.utils import AttrDict + +__all__ = [ + "FacetedSearch", + "HistogramFacet", + "TermsFacet", + "DateHistogramFacet", + "RangeFacet", + "NestedFacet", +] + + +class Facet(object): + """ + A facet on faceted search. 
Wraps and aggregation and provides functionality + to create a filter for selected values and return a list of facet values + from the result of the aggregation. + """ + + agg_type = None + + def __init__(self, metric=None, metric_sort="desc", **kwargs): + self.filter_values = () + self._params = kwargs + self._metric = metric + if metric and metric_sort: + self._params["order"] = {"metric": metric_sort} + + def get_aggregation(self): + """ + Return the aggregation object. + """ + agg = A(self.agg_type, **self._params) + if self._metric: + agg.metric("metric", self._metric) + return agg + + def add_filter(self, filter_values): + """ + Construct a filter. + """ + if not filter_values: + return + + f = self.get_value_filter(filter_values[0]) + for v in filter_values[1:]: + f |= self.get_value_filter(v) + return f + + def get_value_filter(self, filter_value): + """ + Construct a filter for an individual value + """ + pass + + def is_filtered(self, key, filter_values): + """ + Is a filter active on the given key. + """ + return key in filter_values + + def get_value(self, bucket): + """ + return a value representing a bucket. Its key as default. + """ + return bucket["key"] + + def get_metric(self, bucket): + """ + Return a metric, by default doc_count for a bucket. + """ + if self._metric: + return bucket["metric"]["value"] + return bucket["doc_count"] + + def get_values(self, data, filter_values): + """ + Turn the raw bucket data into a list of tuples containing the key, + number of documents and a flag indicating whether this value has been + selected or not. + """ + out = [] + for bucket in data.buckets: + key = self.get_value(bucket) + out.append( + (key, self.get_metric(bucket), self.is_filtered(key, filter_values)) + ) + return out + + +class TermsFacet(Facet): + agg_type = "terms" + + def add_filter(self, filter_values): + """ Create a terms filter instead of bool containing term filters. 
""" + if filter_values: + return Terms( + _expand__to_dot=False, **{self._params["field"]: filter_values} + ) + + +class RangeFacet(Facet): + agg_type = "range" + + def _range_to_dict(self, range): + key, range = range + out = {"key": key} + if range[0] is not None: + out["from"] = range[0] + if range[1] is not None: + out["to"] = range[1] + return out + + def __init__(self, ranges, **kwargs): + super(RangeFacet, self).__init__(**kwargs) + self._params["ranges"] = list(map(self._range_to_dict, ranges)) + self._params["keyed"] = False + self._ranges = dict(ranges) + + def get_value_filter(self, filter_value): + f, t = self._ranges[filter_value] + limits = {} + if f is not None: + limits["gte"] = f + if t is not None: + limits["lt"] = t + + return Range(_expand__to_dot=False, **{self._params["field"]: limits}) + + +class HistogramFacet(Facet): + agg_type = "histogram" + + def get_value_filter(self, filter_value): + return Range( + _expand__to_dot=False, + **{ + self._params["field"]: { + "gte": filter_value, + "lt": filter_value + self._params["interval"], + } + } + ) + + +class DateHistogramFacet(Facet): + agg_type = "date_histogram" + + DATE_INTERVALS = { + "month": lambda d: (d + timedelta(days=32)).replace(day=1), + "week": lambda d: d + timedelta(days=7), + "day": lambda d: d + timedelta(days=1), + "hour": lambda d: d + timedelta(hours=1), + } + + def __init__(self, **kwargs): + kwargs.setdefault("min_doc_count", 0) + super(DateHistogramFacet, self).__init__(**kwargs) + + def get_value(self, bucket): + if not isinstance(bucket["key"], datetime): + # Elasticsearch returns key=None instead of 0 for date 1970-01-01, + # so we need to set key to 0 to avoid TypeError exception + if bucket["key"] is None: + bucket["key"] = 0 + # Preserve milliseconds in the datetime + return datetime.utcfromtimestamp(int(bucket["key"]) / 1000.0) + else: + return bucket["key"] + + def get_value_filter(self, filter_value): + return Range( + _expand__to_dot=False, + **{ + self._params["field"]: { + "gte": filter_value, + "lt": self.DATE_INTERVALS[self._params["interval"]](filter_value), + } + } + ) + + +class NestedFacet(Facet): + agg_type = "nested" + + def __init__(self, path, nested_facet): + self._path = path + self._inner = nested_facet + super(NestedFacet, self).__init__( + path=path, aggs={"inner": nested_facet.get_aggregation()} + ) + + def get_values(self, data, filter_values): + return self._inner.get_values(data.inner, filter_values) + + def add_filter(self, filter_values): + inner_q = self._inner.add_filter(filter_values) + if inner_q: + return Nested(path=self._path, query=inner_q) + + +class FacetedResponse(Response): + @property + def query_string(self): + return self._faceted_search._query + + @property + def facets(self): + if not hasattr(self, "_facets"): + super(AttrDict, self).__setattr__("_facets", AttrDict({})) + for name, facet in iteritems(self._faceted_search.facets): + self._facets[name] = facet.get_values( + getattr(getattr(self.aggregations, "_filter_" + name), name), + self._faceted_search.filter_values.get(name, ()), + ) + return self._facets + + +class FacetedSearch(object): + """ + Abstraction for creating faceted navigation searches that takes care of + composing the queries, aggregations and filters as needed as well as + presenting the results in an easy-to-consume fashion:: + + class BlogSearch(FacetedSearch): + index = 'blogs' + doc_types = [Blog, Post] + fields = ['title^5', 'category', 'description', 'body'] + + facets = { + 'type': TermsFacet(field='_type'), + 'category': 
TermsFacet(field='category'), + 'weekly_posts': DateHistogramFacet(field='published_from', interval='week') + } + + def search(self): + ' Override search to add your own filters ' + s = super(BlogSearch, self).search() + return s.filter('term', published=True) + + # when using: + blog_search = BlogSearch("web framework", filters={"category": "python"}) + + # supports pagination + blog_search[10:20] + + response = blog_search.execute() + + # easy access to aggregation results: + for category, hit_count, is_selected in response.facets.category: + print( + "Category %s has %d hits%s." % ( + category, + hit_count, + ' and is chosen' if is_selected else '' + ) + ) + + """ + + index = None + doc_types = None + fields = None + facets = {} + using = "default" + + def __init__(self, query=None, filters={}, sort=()): + """ + :arg query: the text to search for + :arg filters: facet values to filter + :arg sort: sort information to be passed to :class:`~elasticsearch_dsl.Search` + """ + self._query = query + self._filters = {} + self._sort = sort + self.filter_values = {} + for name, value in iteritems(filters): + self.add_filter(name, value) + + self._s = self.build_search() + + def count(self): + return self._s.count() + + def __getitem__(self, k): + self._s = self._s[k] + return self + + def __iter__(self): + return iter(self._s) + + def add_filter(self, name, filter_values): + """ + Add a filter for a facet. + """ + # normalize the value into a list + if not isinstance(filter_values, (tuple, list)): + if filter_values is None: + return + filter_values = [ + filter_values, + ] + + # remember the filter values for use in FacetedResponse + self.filter_values[name] = filter_values + + # get the filter from the facet + f = self.facets[name].add_filter(filter_values) + if f is None: + return + + self._filters[name] = f + + def search(self): + """ + Returns the base Search object to which the facets are added. + + You can customize the query by overriding this method and returning a + modified search object. + """ + s = Search(doc_type=self.doc_types, index=self.index, using=self.using) + return s.response_class(FacetedResponse) + + def query(self, search, query): + """ + Add query part to ``search``. + + Override this if you wish to customize the query used. + """ + if query: + if self.fields: + return search.query("multi_match", fields=self.fields, query=query) + else: + return search.query("multi_match", query=query) + return search + + def aggregate(self, search): + """ + Add aggregations representing the facets selected, including potential + filters. + """ + for f, facet in iteritems(self.facets): + agg = facet.get_aggregation() + agg_filter = MatchAll() + for field, filter in iteritems(self._filters): + if f == field: + continue + agg_filter &= filter + search.aggs.bucket("_filter_" + f, "filter", filter=agg_filter).bucket( + f, agg + ) + + def filter(self, search): + """ + Add a ``post_filter`` to the search request narrowing the results based + on the facet filters. + """ + if not self._filters: + return search + + post_filter = MatchAll() + for f in itervalues(self._filters): + post_filter &= f + return search.post_filter(post_filter) + + def highlight(self, search): + """ + Add highlighting for all the fields + """ + return search.highlight( + *(f if "^" not in f else f.split("^", 1)[0] for f in self.fields) + ) + + def sort(self, search): + """ + Add sorting information to the request. 
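Tying the pieces together, an illustrative run of the ``BlogSearch`` example from the class docstring (``execute()`` is the coroutine defined just below; awaiting it assumes the ``'default'`` connection is an ``AsyncElasticsearch``)::

    bs = BlogSearch("web framework", filters={"category": "python"})
    response = await bs.execute()

    for category, hit_count, selected in response.facets.category:
        print(category, hit_count, selected)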
+ """ + if self._sort: + search = search.sort(*self._sort) + return search + + def build_search(self): + """ + Construct the ``Search`` object. + """ + s = self.search() + s = self.query(s, self._query) + s = self.filter(s) + if self.fields: + s = self.highlight(s) + s = self.sort(s) + self.aggregate(s) + return s -class AsyncFacetedSearch(FacetedSearch): async def execute(self): """ - Asynchronously execute the search and return the response. + Execute the search and return the response. """ r = await self._s.execute() r._faceted_search = self diff --git a/elasticsearch_dsl/_async/index.py b/elasticsearch_dsl/_async/index.py index f99517db5..570dcdc6b 100644 --- a/elasticsearch_dsl/_async/index.py +++ b/elasticsearch_dsl/_async/index.py @@ -15,59 +15,274 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch_dsl._async.utils import ensure_async_connection +from elasticsearch_dsl import analysis from elasticsearch_dsl.connections import get_connection from elasticsearch_dsl.exceptions import IllegalOperation -from elasticsearch_dsl.index import Index, IndexTemplate +from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.update_by_query import UpdateByQuery +from elasticsearch_dsl.utils import merge + +from .utils import ensure_async_connection + + +class IndexTemplate(object): + def __init__(self, name, template, index=None, order=None, **kwargs): + if index is None: + self._index = Index(template, **kwargs) + else: + if kwargs: + raise ValueError( + "You cannot specify options for Index when" + " passing an Index instance." + ) + self._index = index.clone() + self._index._name = template + self._template_name = name + self.order = order + + def __getattr__(self, attr_name): + return getattr(self._index, attr_name) + + def to_dict(self): + d = self._index.to_dict() + d["index_patterns"] = [self._index._name] + if self.order is not None: + d["order"] = self.order + return d + async def save(self, using=None): + es = get_connection(using or self._index._using) + ensure_async_connection(es, "IndexTemplate.save") -class AsyncIndex(Index): - async def analyze(self, using=None, **kwargs): + return await es.indices.put_template( + name=self._template_name, body=self.to_dict() + ) + + +class Index(object): + def __init__(self, name, using="default"): """ - Asynchronously perform the analysis process on a text and return the tokens - breakdown of the text. + :arg name: name of the index + :arg using: connection alias to use, defaults to ``'default'`` + """ + self._name = name + self._doc_types = [] + self._using = using + self._settings = {} + self._aliases = {} + self._analysis = {} + self._mapping = None - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.analyze`` unchanged. + def get_or_create_mapping(self): + if self._mapping is None: + self._mapping = Mapping() + return self._mapping + + def as_template(self, template_name, pattern=None, order=None): + # TODO: should we allow pattern to be a top-level arg? + # or maybe have an IndexPattern that allows for it and have + # Document._index be that? 
+ return IndexTemplate( + template_name, pattern or self._name, index=self, order=order + ) + + def resolve_nested(self, field_path): + for doc in self._doc_types: + nested, field = doc._doc_type.mapping.resolve_nested(field_path) + if field is not None: + return nested, field + if self._mapping: + return self._mapping.resolve_nested(field_path) + return (), None + + def resolve_field(self, field_path): + for doc in self._doc_types: + field = doc._doc_type.mapping.resolve_field(field_path) + if field is not None: + return field + if self._mapping: + return self._mapping.resolve_field(field_path) + return None + + async def load_mappings(self, using=None): + mapping = self.get_or_create_mapping() + + await mapping.update_from_es(self._name, using=using or self._using) + + def clone(self, name=None, using=None): """ - es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.analyze") + Create a copy of the instance with another name or connection alias. + Useful for creating multiple indices with shared configuration:: - return await es.indices.analyze(index=self._index, **kwargs) + i = Index('base-index') + i.settings(number_of_shards=1) + i.create() - async def clear_cache(self, using=None, **kwargs): + i2 = i.clone('other-index') + i2.create() + + :arg name: name of the index + :arg using: connection alias to use, defaults to ``'default'`` """ - Asynchronously clear all caches or specific cached associated with the index. + i = Index(name or self._name, using=using or self._using) + i._settings = self._settings.copy() + i._aliases = self._aliases.copy() + i._analysis = self._analysis.copy() + i._doc_types = self._doc_types[:] + if self._mapping is not None: + i._mapping = self._mapping._clone() + return i - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.clear_cache`` unchanged. + def _get_connection(self, using=None): + if self._name is None: + raise ValueError("You cannot perform API calls on the default index.") + return get_connection(using or self._using) + + connection = property(_get_connection) + + def mapping(self, mapping): """ - es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.clear_cache") + Associate a mapping (an instance of + :class:`~elasticsearch_dsl.Mapping`) with this index. + This means that, when this index is created, it will contain the + mappings for the document type defined by those mappings. + """ + self.get_or_create_mapping().update(mapping) - return await es.indices.clear_cache(index=self._index, **kwargs) + def document(self, document): + """ + Associate a :class:`~elasticsearch_dsl.Document` subclass with an index. + This means that, when this index is created, it will contain the + mappings for the ``Document``. If the ``Document`` class doesn't have a + default index yet (by defining ``class Index``), this instance will be + used. Can be used as a decorator:: - async def close(self, using=None, **kwargs): + i = Index('blog') + + @i.document + class Post(Document): + title = Text() + + # create the index, including Post mappings + i.create() + + # .search() will now return a Search object that will return + # properly deserialized Post instances + s = i.search() """ - Asynchronously closes the index in Elasticsearch. + self._doc_types.append(document) - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.close`` unchanged. + # If the document index does not have any name, that means the user + # did not set any index already to the document. 
+ # So set this index as document index + if document._index._name is None: + document._index = self + + return document + + def settings(self, **kwargs): """ - es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.close") + Add settings to the index:: - return await es.indices.close(index=self._index, **kwargs) + i = Index('i') + i.settings(number_of_shards=1, number_of_replicas=0) + + Multiple calls to ``settings`` will merge the keys, later overriding + the earlier. + """ + self._settings.update(kwargs) + return self + + def aliases(self, **kwargs): + """ + Add aliases to the index definition:: + + i = Index('blog-v2') + i.aliases(blog={}, published={'filter': Q('term', published=True)}) + """ + self._aliases.update(kwargs) + return self + + def analyzer(self, *args, **kwargs): + """ + Explicitly add an analyzer to an index. Note that all custom analyzers + defined in mappings will also be created. This is useful for search analyzers. + + Example:: + + from elasticsearch_dsl import analyzer, tokenizer + + my_analyzer = analyzer('my_analyzer', + tokenizer=tokenizer('trigram', 'nGram', min_gram=3, max_gram=3), + filter=['lowercase'] + ) + + i = Index('blog') + i.analyzer(my_analyzer) + + """ + analyzer = analysis.analyzer(*args, **kwargs) + d = analyzer.get_analysis_definition() + # empty custom analyzer, probably already defined out of our control + if not d: + return + + # merge the definition + merge(self._analysis, d, True) + + def to_dict(self): + out = {} + if self._settings: + out["settings"] = self._settings + if self._aliases: + out["aliases"] = self._aliases + mappings = self._mapping.to_dict() if self._mapping else {} + analysis = self._mapping._collect_analysis() if self._mapping else {} + for d in self._doc_types: + mapping = d._doc_type.mapping + merge(mappings, mapping.to_dict(), True) + merge(analysis, mapping._collect_analysis(), True) + if mappings: + out["mappings"] = mappings + if analysis or self._analysis: + merge(analysis, self._analysis) + out.setdefault("settings", {})["analysis"] = analysis + return out + + def search(self, using=None): + """ + Return a :class:`~elasticsearch_dsl.Search` object searching over the + index (or all the indices belonging to this template) and its + ``Document``\\s. + """ + return Search( + using=using or self._using, index=self._name, doc_type=self._doc_types + ) + + def updateByQuery(self, using=None): + """ + Return a :class:`~elasticsearch_dsl.UpdateByQuery` object searching over the index + (or all the indices belonging to this template) and updating Documents that match + the search criteria. + + For more information, see here: + https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-update-by-query.html + """ + return UpdateByQuery( + using=using or self._using, + index=self._name, + ) async def create(self, using=None, **kwargs): """ - Asynchronously creates the index in Elasticsearch. + Creates the index in elasticsearch. Any additional keyword arguments will be passed to ``Elasticsearch.indices.create`` unchanged. 
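A usage sketch for the coroutine methods on this class (the index name and settings are made up; the registered connection must be an async client)::

    blogs = Index("blogs-v1")
    blogs.settings(number_of_shards=1, number_of_replicas=0)

    if not await blogs.exists():
        await blogs.create()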
""" es = get_connection(using) - ensure_async_connection(es, "AsyncIndex.create") + ensure_async_connection(es, "Index.create") return await es.indices.create( index=self._name, @@ -75,415 +290,436 @@ async def create(self, using=None, **kwargs): **kwargs, ) - async def delete(self, using=None, **kwargs): + async def is_closed(self, using=None): + es = get_connection(using) + ensure_async_connection(es, "Index.is_closed") + + state = await es.cluster.state( + index=self._name, + metric="metadata", + ) + + return state["metadata"]["indices"][self._name]["state"] == "close" + + async def save(self, using=None): """ - Asynchronously deletes the index in Elasticsearch. + Sync the index definition with elasticsearch, creating the index if it + doesn't exist and updating its settings and mappings if it does. - Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.delete`` unchanged. + Note some settings and mapping changes cannot be done on an open + index (or at all on an existing index) and for those this method will + fail with the underlying exception. """ - es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.delete") + if not await self.exists(using=using): + return await self.create(using=using) - return await es.indices.delete(index=self._index, **kwargs) + body = self.to_dict() + settings = body.pop("settings", {}) + analysis = settings.pop("analysis", None) + current_settings = self.get_settings(using=using)[self._name]["settings"][ + "index" + ] + if analysis: + if await self.is_closed(using=using): + # closed index, update away + settings["analysis"] = analysis + else: + # compare analysis definition, if all analysis objects are + # already defined as requested, skip analysis update and + # proceed, otherwise raise IllegalOperation + existing_analysis = current_settings.get("analysis", {}) + if any( + existing_analysis.get(section, {}).get(k, None) + != analysis[section][k] + for section in analysis + for k in analysis[section] + ): + raise IllegalOperation( + "You cannot update analysis configuration on an open index, " + "you need to close index %s first." % self._name + ) - async def delete_alias(self, using=None, **kwargs): + # try and update the settings + if settings: + settings = settings.copy() + for k, v in list(settings.items()): + if k in current_settings and current_settings[k] == str(v): + del settings[k] + + if settings: + await self.put_settings(using=using, body=settings) + + # update the mappings, any conflict in the mappings will result in an + # exception + mappings = body.pop("mappings", {}) + if mappings: + await self.put_mapping(using=using, body=mappings) + + async def analyze(self, using=None, **kwargs): """ - Asynchronously deletes a specific alias. + Perform the analysis process on a text and return the tokens breakdown + of the text. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.delete_alias`` unchanged. + ``Elasticsearch.indices.analyze`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.delete_alias") + ensure_async_connection(es, "Index.analyze") - return await es.indices.delete_alias(index=self._index, **kwargs) + return await es.indices.analyze(index=self._index, **kwargs) - async def exists(self, using=None, **kwargs): + async def refresh(self, using=None, **kwargs): """ - Asynchronously queries Elasticsearch for whether this index exists. Returns - ``True`` if the index already exists in Elasticsearch, otherwise ``False``. 
+ Performs a refresh operation on the index. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.exists`` unchanged. + ``Elasticsearch.indices.refresh`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.exists") + ensure_async_connection(es, "Index.refresh") - return await es.indices.exists(index=self._index, **kwargs) + return await es.indices.refresh(index=self._index, **kwargs) - async def exists_type(self, using=None, **kwargs): + async def flush(self, using=None, **kwargs): """ - Asynchronously queries Elasticsearch for whether a type or set of types exists - in the index. Returns ``True`` if the type/types exist, otherwise ``False``. + Performs a flush operation on the index. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.exists_type`` unchanged. + ``Elasticsearch.indices.flush`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.exists_type") + ensure_async_connection(es, "Index.flush") - return await es.indices.exists_type(index=self._index, **kwargs) + return await es.indices.flush(index=self._index, **kwargs) - async def flush(self, using=None, **kwargs): + async def get(self, using=None, **kwargs): """ - Asynchronously performs a flush operation on the index. + The get index API allows to retrieve information about the index. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.flush`` unchanged. + ``Elasticsearch.indices.get`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.flush") + ensure_async_connection(es, "Index.get") - return await es.indices.flush(index=self._index, **kwargs) + return await es.indices.get(index=self._index, **kwargs) - async def flush_synced(self, using=None, **kwargs): + async def open(self, using=None, **kwargs): """ - Asynchronously performs a normal flush, then adds a unique marker (sync_id) to - all shards. + Opens the index in elasticsearch. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.flush_synced`` unchanged. + ``Elasticsearch.indices.open`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.flush_synced") + ensure_async_connection(es, "Index.open") - return await es.indices.flush_synced(index=self._index, **kwargs) + return await es.indices.open(index=self._index, **kwargs) - async def forcemerge(self, using=None, **kwargs): + async def close(self, using=None, **kwargs): + """ + Closes the index in elasticsearch. + + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.close`` unchanged. """ - Asynchronously calls the force merge API. + es = self._get_connection(using) + ensure_async_connection(es, "Index.close") - The force merge API allows to force merging of the index through an API. The - merge relates to the number of segments a Lucene index holds within each shard. - The force merge operation allows to reduce the number of segments by merging - them. + return await es.indices.close(index=self._index, **kwargs) - This call will block until the merge is complete. If the http connection is - lost, the request will continue in the background, and any new requests will - block until the previous force merge is complete. + async def delete(self, using=None, **kwargs): + """ + Deletes the index in elasticsearch. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.forcemerge`` unchanged. 
+ ``Elasticsearch.indices.delete`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.forcemerge") + ensure_async_connection(es, "Index.delete") - return await es.indices.forcemerge(index=self._index, **kwargs) + return await es.indices.delete(index=self._index, **kwargs) - async def get(self, using=None, **kwargs): + async def exists(self, using=None, **kwargs): """ - Asynchronously retrieves information about the index from Elasticsearch. + Returns ``True`` if the index already exists in elasticsearch. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.get`` unchanged. + ``Elasticsearch.indices.exists`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.get") + ensure_async_connection(es, "Index.exists") - return await es.indices.get(index=self._index, **kwargs) + return await es.indices.exists(index=self._index, **kwargs) - async def get_alias(self, using=None, **kwargs): + async def exists_type(self, using=None, **kwargs): """ - Asynchronously retrieves a specific alias. + Check if a type/types exists in the index. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.get_alias`` unchanged. + ``Elasticsearch.indices.exists_type`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.get_alias") + ensure_async_connection(es, "Index.exists_type") - return await es.indices.get_alias(index=self._index, **kwargs) + return await es.indices.exists_type(index=self._index, **kwargs) - async def get_field_mapping(self, using=None, **kwargs): + async def put_mapping(self, using=None, **kwargs): """ - Asynchronously retrieves a mapping definition for a specific field. + Register specific mapping definition for a specific type. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.get_field_mapping`` unchanged. + ``Elasticsearch.indices.put_mapping`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.get_field_mapping") + ensure_async_connection(es, "Index.put_mapping") - return await es.indices.get_field_mapping(index=self._index, **kwargs) + return await es.indices.put_mapping(index=self._index, **kwargs) async def get_mapping(self, using=None, **kwargs): """ - Asynchronously retrieves a specific mapping definition for a specific type. + Retrieve specific mapping definition for a specific type. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.get_mapping`` unchanged. + ``Elasticsearch.indices.get_mapping`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.get_mapping") + ensure_async_connection(es, "Index.get_mapping") return await es.indices.get_mapping(index=self._index, **kwargs) - async def get_settings(self, using=None, **kwargs): + async def get_field_mapping(self, using=None, **kwargs): """ - Asynchronously retrieves the settings for the index. + Retrieve mapping definition of a specific field. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.get_settings`` unchanged. + ``Elasticsearch.indices.get_field_mapping`` unchanged. 
""" es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.get_settings") + ensure_async_connection(es, "Index.get_field_mapping") - return await es.indices.get_settings(index=self._index, **kwargs) + return await es.indices.get_field_mapping(index=self._index, **kwargs) - async def get_upgrade(self, using=None, **kwargs): + async def put_alias(self, using=None, **kwargs): """ - Asynchronously monitors how much of an index is upgraded. + Create an alias for the index. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.get_upgrade`` unchanged. + ``Elasticsearch.indices.put_alias`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.get_upgrade") + ensure_async_connection(es, "Index.put_alias") - return await es.indices.get_upgrade(index=self._index, **kwargs) + return await es.indices.put_alias(index=self._index, **kwargs) - async def is_closed(self, using=None): - """ - Asynchronously queries Elasticsearch to determine whether this index - is closed. + def exists_alias(self, using=None, **kwargs): """ - es = get_connection(using) - ensure_async_connection(es, "AsyncIndex.is_closed") + Return a boolean indicating whether given alias exists for this index. - state = await es.cluster.state( - index=self._name, - metric="metadata", + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.exists_alias`` unchanged. + """ + return self._get_connection(using).indices.exists_alias( + index=self._name, **kwargs ) - return state["metadata"]["indices"][self._name]["state"] == "close" - - async def load_mappings(self, using=None): - mapping = self.get_or_create_mapping() - - await mapping.update_from_es(self._name, using=using or self._using) - - async def open(self, using=None, **kwargs): + async def get_alias(self, using=None, **kwargs): """ - Asynchronously opens the index in Elasticsearch. + Retrieve a specified alias. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.open`` unchanged. + ``Elasticsearch.indices.get_alias`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.open") + ensure_async_connection(es, "Index.get_alias") - return await es.indices.open(index=self._index, **kwargs) + return await es.indices.get_alias(index=self._index, **kwargs) - async def put_alias(self, using=None, **kwargs): + async def delete_alias(self, using=None, **kwargs): """ - Asynchronously creates an alias for the index. + Delete specific alias. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.put_alias`` unchanged. + ``Elasticsearch.indices.delete_alias`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.put_alias") + ensure_async_connection(es, "Index.delete_alias") - return await es.indices.put_alias(index=self._index, **kwargs) + return await es.indices.delete_alias(index=self._index, **kwargs) - async def put_mapping(self, using=None, **kwargs): + async def get_settings(self, using=None, **kwargs): """ - Asynchronously register a specific mapping definition for a specific type. + Retrieve settings for the index. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.put_mapping`` unchanged. + ``Elasticsearch.indices.get_settings`` unchanged. 
""" es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.put_mapping") + ensure_async_connection(es, "Index.get_settings") - return await es.indices.put_mapping(index=self._index, **kwargs) + return await es.indices.get_settings(index=self._index, **kwargs) async def put_settings(self, using=None, **kwargs): """ - Asynchronously changes specific index-level settings. + Change specific index level settings in real time. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.put_settings`` unchanged. + ``Elasticsearch.indices.put_settings`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.put_settings") + ensure_async_connection(es, "Index.put_settings") return await es.indices.put_settings(index=self._index, **kwargs) - async def recovery(self, using=None, **kwargs): + async def stats(self, using=None, **kwargs): """ - Asynchronously provides insight into ongoing shard recoveries for the index. + Retrieve statistics on different operations happening on the index. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.recovery`` unchanged. + ``Elasticsearch.indices.stats`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.recovery") + ensure_async_connection(es, "Index.stats") - return await es.indices.recovery(index=self._index, **kwargs) + return await es.indices.stats(index=self._index, **kwargs) - async def refresh(self, using=None, **kwargs): + async def segments(self, using=None, **kwargs): """ - Asynchronously performs a refresh operation on the index. + Provide low level segments information that a Lucene index (shard + level) is built with. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.refresh`` unchanged. + ``Elasticsearch.indices.segments`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.refresh") + ensure_async_connection(es, "Index.segments") - return await es.indices.refresh(index=self._index, **kwargs) + return await es.indices.segments(index=self._index, **kwargs) - async def save(self, using=None): + async def validate_query(self, using=None, **kwargs): """ - Asynchronously sync the index definition with Elasticsearch, creating the index - if it doesn't exist and updating its settings and mappings if it does. + Validate a potentially expensive query without executing it. - Note: Some settings and mapping changes cannot be done on an open index (or at - all on an existing index) and for those this method will fail with the - underlying exception. + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.validate_query`` unchanged. 
""" + es = self._get_connection(using) + ensure_async_connection(es, "Index.validate_query") - if not await self.exists(using=using): - return await self.create(using=using) - - body = self.to_dict() - settings = body.pop("settings", {}) - analysis = settings.pop("analysis", None) - current_settings = self.get_settings(using=using)[self._name]["settings"][ - "index" - ] - - if analysis: - if await self.is_closed(using=using): - # closed index, update away - settings["analysis"] = analysis - else: - # compare analysis definition, if all analysis objects are - # already defined as requested, skip analysis update and - # proceed, otherwise raise IllegalOperation - existing_analysis = current_settings.get("analysis", {}) - if any( - existing_analysis.get(section, {}).get(k, None) - != analysis[section][k] - for section in analysis - for k in analysis[section] - ): - raise IllegalOperation( - "You cannot update analysis configuration on an open index, " - "you need to close index %s first." % self._name - ) + return await es.indices.validate_query(index=self._index, **kwargs) - # try and update the settings - if settings: - settings = settings.copy() - for k, v in list(settings.items()): - if k in current_settings and current_settings[k] == str(v): - del settings[k] + async def clear_cache(self, using=None, **kwargs): + """ + Clear all caches or specific cached associated with the index. - if settings: - await self.put_settings(using=using, body=settings) + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.clear_cache`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.clear_cache") - # update the mappings, any conflict in the mappings will result in an - # exception - mappings = body.pop("mappings", {}) - if mappings: - await self.put_mapping(using=using, body=mappings) + return await es.indices.clear_cache(index=self._index, **kwargs) - async def segments(self, using=None, **kwargs): + async def recovery(self, using=None, **kwargs): """ - Asynchronously provides low-level segments information that a Lucene index - (shard level) is built with. + The indices recovery API provides insight into on-going shard + recoveries for the index. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.segments`` unchanged. + ``Elasticsearch.indices.recovery`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.segments") + ensure_async_connection(es, "Index.recovery") - return await es.indices.segments(index=self._index, **kwargs) + return await es.indices.recovery(index=self._index, **kwargs) - async def shard_stores(self, using=None, **kwargs): + async def upgrade(self, using=None, **kwargs): """ - Asynchronously provides store information for shard copies of the index. Store - information reports on which nodes shard copies exist, the shard copy version, - indicating how recent they are, and any exceptions encountered while opening - the shard index or from earlier engine failure. + Upgrade the index to the latest format. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.shard_stores`` unchanged. + ``Elasticsearch.indices.upgrade`` unchanged. 
""" es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.shard_stores") + ensure_async_connection(es, "Index.upgrade") - return await es.indices.shard_stores(index=self._index, **kwargs) + return await es.indices.upgrade(index=self._index, **kwargs) - async def shrink(self, using=None, **kwargs): + async def get_upgrade(self, using=None, **kwargs): """ - Asynchronously calls the shrink index API. - - The shrink index API allows you to shrink an existing index into a new - index with fewer primary shards. The number of primary shards in the - target index must be a factor of the shards in the source index. For - example an index with 8 primary shards can be shrunk into 4, 2 or 1 - primary shards or an index with 15 primary shards can be shrunk into 5, - 3 or 1. If the number of shards in the index is a prime number it can - only be shrunk into a single primary shard. Before shrinking, a - (primary or replica) copy of every shard in the index must be present - on the same node. + Monitor how much of the index is upgraded. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.shrink`` unchanged. + ``Elasticsearch.indices.get_upgrade`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.shrink") + ensure_async_connection(es, "Index.get_upgrade") - return await es.indices.shrink(index=self._index, **kwargs) + return await es.indices.get_upgrade(index=self._index, **kwargs) - async def stats(self, using=None, **kwargs): + async def flush_synced(self, using=None, **kwargs): """ - Asynchronously retrieves statistics on different operations happening on the - index. + Perform a normal flush, then add a generated unique marker (sync_id) to + all shards. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.stats`` unchanged. + ``Elasticsearch.indices.flush_synced`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.stats") + ensure_async_connection(es, "Index.flush_synced") - return await es.indices.stats(index=self._index, **kwargs) + return await es.indices.flush_synced(index=self._index, **kwargs) - async def upgrade(self, using=None, **kwargs): + async def shard_stores(self, using=None, **kwargs): """ - Asynchronously upgrades the index to the latest format. + Provides store information for shard copies of the index. Store + information reports on which nodes shard copies exist, the shard copy + version, indicating how recent they are, and any exceptions encountered + while opening the shard index or from earlier engine failure. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.upgrade`` unchanged. + ``Elasticsearch.indices.shard_stores`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.upgrade") + ensure_async_connection(es, "Index.shard_stores") - return await es.indices.upgrade(index=self._index, **kwargs) + return await es.indices.shard_stores(index=self._index, **kwargs) - async def validate_query(self, using=None, **kwargs): + async def forcemerge(self, using=None, **kwargs): """ - Asynchronously validates a potentially expensive query without executing it. + The force merge API allows to force merging of the index through an + API. The merge relates to the number of segments a Lucene index holds + within each shard. The force merge operation allows to reduce the + number of segments by merging them. + + This call will block until the merge is complete. 
If the http + connection is lost, the request will continue in the background, and + any new requests will block until the previous force merge is complete. Any additional keyword arguments will be passed to - ``AsyncElasticsearch.indices.validate_query`` unchanged. + ``Elasticsearch.indices.forcemerge`` unchanged. """ es = self._get_connection(using) - ensure_async_connection(es, "AsyncIndex.validate_query") + ensure_async_connection(es, "Index.forcemerge") - return await es.indices.validate_query(index=self._index, **kwargs) + return await es.indices.forcemerge(index=self._index, **kwargs) + async def shrink(self, using=None, **kwargs): + """ + The shrink index API allows you to shrink an existing index into a new + index with fewer primary shards. The number of primary shards in the + target index must be a factor of the shards in the source index. For + example an index with 8 primary shards can be shrunk into 4, 2 or 1 + primary shards or an index with 15 primary shards can be shrunk into 5, + 3 or 1. If the number of shards in the index is a prime number it can + only be shrunk into a single primary shard. Before shrinking, a + (primary or replica) copy of every shard in the index must be present + on the same node. -class AsyncIndexTemplate(IndexTemplate): - async def save(self, using=None): - es = get_connection(using or self._index._using) - ensure_async_connection(es, "AsyncIndexTemplate.save") + Any additional keyword arguments will be passed to + ``Elasticsearch.indices.shrink`` unchanged. + """ + es = self._get_connection(using) + ensure_async_connection(es, "Index.shrink") - return await es.indices.put_template( - name=self._template_name, body=self.to_dict() - ) + return await es.indices.shrink(index=self._index, **kwargs) diff --git a/elasticsearch_dsl/_async/mapping.py b/elasticsearch_dsl/_async/mapping.py index df9fc7818..ba099ec63 100644 --- a/elasticsearch_dsl/_async/mapping.py +++ b/elasticsearch_dsl/_async/mapping.py @@ -15,12 +15,97 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch_dsl._async.utils import ensure_async_connection +try: + import collections.abc as collections_abc # only works on python 3.3+ +except ImportError: + import collections as collections_abc + +from itertools import chain + +from six import iteritems, itervalues + from elasticsearch_dsl.connections import get_connection -from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.field import Nested, Text, construct_field +from elasticsearch_dsl.utils import DslBase + +from .utils import ensure_async_connection + +META_FIELDS = frozenset( + ( + "dynamic", + "transform", + "dynamic_date_formats", + "date_detection", + "numeric_detection", + "dynamic_templates", + "enabled", + ) +) + + +class Properties(DslBase): + name = "properties" + _param_defs = {"properties": {"type": "field", "hash": True}} + + def __init__(self): + super(Properties, self).__init__() + + def __repr__(self): + return "Properties()" + + def __getitem__(self, name): + return self.properties[name] + def __contains__(self, name): + return name in self.properties + + def to_dict(self): + return super(Properties, self).to_dict()["properties"] + + def field(self, name, *args, **kwargs): + self.properties[name] = construct_field(*args, **kwargs) + return self + + def _collect_fields(self): + """ Iterate over all Field objects within, including multi fields. 
""" + for f in itervalues(self.properties.to_dict()): + yield f + # multi fields + if hasattr(f, "fields"): + for inner_f in itervalues(f.fields.to_dict()): + yield inner_f + # nested and inner objects + if hasattr(f, "_collect_fields"): + for inner_f in f._collect_fields(): + yield inner_f + + def update(self, other_object): + if not hasattr(other_object, "properties"): + # not an inner/nested object, no merge possible + return + + our, other = self.properties, other_object.properties + for name in other: + if name in our: + if hasattr(our[name], "update"): + our[name].update(other[name]) + continue + our[name] = other[name] + + +class Mapping(object): + def __init__(self): + self.properties = Properties() + self._meta = {} + + def __repr__(self): + return "Mapping()" + + def _clone(self): + m = Mapping() + m.properties._params = self.properties._params.copy() + return m -class AsyncMapping(Mapping): @classmethod async def from_es(cls, index, using="default"): m = cls() @@ -28,17 +113,132 @@ async def from_es(cls, index, using="default"): return m + def resolve_nested(self, field_path): + field = self + nested = [] + parts = field_path.split(".") + for i, step in enumerate(parts): + try: + field = field[step] + except KeyError: + return (), None + if isinstance(field, Nested): + nested.append(".".join(parts[: i + 1])) + return nested, field + + def resolve_field(self, field_path): + field = self + for step in field_path.split("."): + try: + field = field[step] + except KeyError: + return + return field + + def _collect_analysis(self): + analysis = {} + fields = [] + if "_all" in self._meta: + fields.append(Text(**self._meta["_all"])) + + for f in chain(fields, self.properties._collect_fields()): + for analyzer_name in ( + "analyzer", + "normalizer", + "search_analyzer", + "search_quote_analyzer", + ): + if not hasattr(f, analyzer_name): + continue + analyzer = getattr(f, analyzer_name) + d = analyzer.get_analysis_definition() + # empty custom analyzer, probably already defined out of our control + if not d: + continue + + # merge the definition + # TODO: conflict detection/resolution + for key in d: + analysis.setdefault(key, {}).update(d[key]) + + return analysis + async def save(self, index, using="default"): - from elasticsearch_dsl._async.index import AsyncIndex + from .index import Index - index = AsyncIndex(index, using=using) + index = Index(index, using=using) index.mapping(self) return await index.save() async def update_from_es(self, index, using="default"): es = get_connection(using) - ensure_async_connection(es, "AsyncMapping.update_from_es") + ensure_async_connection(es, "Mapping.update_from_es") raw = await es.indices.get_mapping(index=index) _, raw = raw.popitem() self._update_from_dict(raw["mappings"]) + + def _update_from_dict(self, raw): + for name, definition in iteritems(raw.get("properties", {})): + self.field(name, definition) + + # metadata like _all etc + for name, value in iteritems(raw): + if name != "properties": + if isinstance(value, collections_abc.Mapping): + self.meta(name, **value) + else: + self.meta(name, value) + + def update(self, mapping, update_only=False): + for name in mapping: + if update_only and name in self: + # nested and inner objects, merge recursively + if hasattr(self[name], "update"): + # FIXME only merge subfields, not the settings + self[name].update(mapping[name], update_only) + continue + self.field(name, mapping[name]) + + if update_only: + for name in mapping._meta: + if name not in self._meta: + self._meta[name] = 
mapping._meta[name] + else: + self._meta.update(mapping._meta) + + def __contains__(self, name): + return name in self.properties.properties + + def __getitem__(self, name): + return self.properties.properties[name] + + def __iter__(self): + return iter(self.properties.properties) + + def field(self, *args, **kwargs): + self.properties.field(*args, **kwargs) + return self + + def meta(self, name, params=None, **kwargs): + if not name.startswith("_") and name not in META_FIELDS: + name = "_" + name + + if params and kwargs: + raise ValueError("Meta configs cannot have both value and a dictionary.") + + self._meta[name] = kwargs if params is None else params + return self + + def to_dict(self): + meta = self._meta + + # hard coded serialization of analyzers in _all + if "_all" in meta: + meta = meta.copy() + _all = meta["_all"] = meta["_all"].copy() + for f in ("analyzer", "search_analyzer", "search_quote_analyzer"): + if hasattr(_all.get(f, None), "to_dict"): + _all[f] = _all[f].to_dict() + meta.update(self.properties.to_dict()) + return meta diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py index b5c1cc170..e1a1136a2 100644 --- a/elasticsearch_dsl/_async/search.py +++ b/elasticsearch_dsl/_async/search.py @@ -15,47 +15,704 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch._async.helpers import async_scan +import copy -from elasticsearch_dsl._async.utils import ensure_async_connection +try: + import collections.abc as collections_abc # only works on python 3.3+ +except ImportError: + import collections as collections_abc + +from elasticsearch.exceptions import TransportError +from elasticsearch.helpers import async_scan +from six import iteritems, string_types + +from elasticsearch_dsl.aggs import A, AggBase from elasticsearch_dsl.connections import get_connection -from elasticsearch_dsl.search import MultiSearch, Search -from elasticsearch_dsl.utils import AttrDict +from elasticsearch_dsl.exceptions import IllegalOperation +from elasticsearch_dsl.query import Bool, Q +from elasticsearch_dsl.response import Hit, Response +from elasticsearch_dsl.utils import AttrDict, DslBase +from .utils import ensure_async_connection -class AsyncMultiSearch(MultiSearch): - async def execute(self, ignore_cache=False, raise_on_error=True): + +class QueryProxy(object): + """ + Simple proxy around DSL objects (queries) that can be called + (to add query/post_filter) and also allows attribute access which is proxied to + the wrapped query. 
+ """ + + def __init__(self, search, attr_name): + self._search = search + self._proxied = None + self._attr_name = attr_name + + def __nonzero__(self): + return self._proxied is not None + + __bool__ = __nonzero__ + + def __call__(self, *args, **kwargs): + s = self._search._clone() + + # we cannot use self._proxied since we just cloned self._search and + # need to access the new self on the clone + proxied = getattr(s, self._attr_name) + if proxied._proxied is None: + proxied._proxied = Q(*args, **kwargs) + else: + proxied._proxied &= Q(*args, **kwargs) + + # always return search to be chainable + return s + + def __getattr__(self, attr_name): + return getattr(self._proxied, attr_name) + + def __setattr__(self, attr_name, value): + if not attr_name.startswith("_"): + self._proxied = Q(self._proxied.to_dict()) + setattr(self._proxied, attr_name, value) + super(QueryProxy, self).__setattr__(attr_name, value) + + def __getstate__(self): + return self._search, self._proxied, self._attr_name + + def __setstate__(self, state): + self._search, self._proxied, self._attr_name = state + + +class ProxyDescriptor(object): + """ + Simple descriptor to enable setting of queries and filters as: + + s = Search() + s.query = Q(...) + + """ + + def __init__(self, name): + self._attr_name = "_%s_proxy" % name + + def __get__(self, instance, owner): + return getattr(instance, self._attr_name) + + def __set__(self, instance, value): + proxy = getattr(instance, self._attr_name) + proxy._proxied = Q(value) + + +class AggsProxy(AggBase, DslBase): + name = "aggs" + + def __init__(self, search): + self._base = self + self._search = search + self._params = {"aggs": {}} + + def to_dict(self): + return super(AggsProxy, self).to_dict().get("aggs", {}) + + +class Request(object): + def __init__(self, using="default", index=None, doc_type=None, extra=None): + self._using = using + + self._index = None + if isinstance(index, (tuple, list)): + self._index = list(index) + elif index: + self._index = [index] + + self._doc_type = [] + self._doc_type_map = {} + if isinstance(doc_type, (tuple, list)): + self._doc_type.extend(doc_type) + elif isinstance(doc_type, collections_abc.Mapping): + self._doc_type.extend(doc_type.keys()) + self._doc_type_map.update(doc_type) + elif doc_type: + self._doc_type.append(doc_type) + + self._params = {} + self._extra = extra or {} + + def __eq__(self, other): + return ( + isinstance(other, Request) + and other._params == self._params + and other._index == self._index + and other._doc_type == self._doc_type + and other.to_dict() == self.to_dict() + ) + + def __copy__(self): + return self._clone() + + def params(self, **kwargs): """ - Execute the multi search request and return a list of search results. + Specify query params to be used when executing the search. All the + keyword arguments will override the current values. See + https://elasticsearch-py.readthedocs.io/en/master/api.html#elasticsearch.Elasticsearch.search + for all available parameters. + + Example:: + + s = Search() + s = s.params(routing='user-1', preference='local') """ - if ignore_cache or not hasattr(self, "_response"): - es = get_connection(self._using) - ensure_async_connection(es, "AsyncMultiSearch.execute") + s = self._clone() + s._params.update(kwargs) + return s - responses = await es.msearch( - index=self._index, - body=self.to_dict(), - **self.params, - ) + def index(self, *index): + """ + Set the index for the search. If called empty it will remove all information. 
+ + Example: + + s = Search() + s = s.index('twitter-2015.01.01', 'twitter-2015.01.02') + s = s.index(['twitter-2015.01.01', 'twitter-2015.01.02']) + """ + # .index() resets + s = self._clone() + if not index: + s._index = None + else: + indexes = [] + for i in index: + if isinstance(i, string_types): + indexes.append(i) + elif isinstance(i, list): + indexes += i + elif isinstance(i, tuple): + indexes += list(i) + + s._index = (self._index or []) + indexes + + return s + + def _resolve_field(self, path): + for dt in self._doc_type: + if not hasattr(dt, "_index"): + continue + field = dt._index.resolve_field(path) + if field is not None: + return field + + def _resolve_nested(self, hit, parent_class=None): + doc_class = Hit + + nested_path = [] + nesting = hit["_nested"] + while nesting and "field" in nesting: + nested_path.append(nesting["field"]) + nesting = nesting.get("_nested") + nested_path = ".".join(nested_path) + + if hasattr(parent_class, "_index"): + nested_field = parent_class._index.resolve_field(nested_path) + else: + nested_field = self._resolve_field(nested_path) + + if nested_field is not None: + return nested_field._doc_class + + return doc_class - self._response = self._process_responses( - responses, raise_on_error=raise_on_error + def _get_result(self, hit, parent_class=None): + doc_class = Hit + dt = hit.get("_type") + + if "_nested" in hit: + doc_class = self._resolve_nested(hit, parent_class) + + elif dt in self._doc_type_map: + doc_class = self._doc_type_map[dt] + + else: + for doc_type in self._doc_type: + if hasattr(doc_type, "_matches") and doc_type._matches(hit): + doc_class = doc_type + break + + for t in hit.get("inner_hits", ()): + hit["inner_hits"][t] = Response( + self, hit["inner_hits"][t], doc_class=doc_class ) - return self._response + callback = getattr(doc_class, "from_es", doc_class) + return callback(hit) + + def doc_type(self, *doc_type, **kwargs): + """ + Set the type to search through. You can supply a single value or + multiple. Values can be strings or subclasses of ``Document``. + + You can also pass in any keyword arguments, mapping a doc_type to a + callback that should be used instead of the Hit class. + + If no doc_type is supplied any information stored on the instance will + be erased. + + Example: + + s = Search().doc_type('product', 'store', User, custom=my_callback) + """ + # .doc_type() resets + s = self._clone() + if not doc_type and not kwargs: + s._doc_type = [] + s._doc_type_map = {} + else: + s._doc_type.extend(doc_type) + s._doc_type.extend(kwargs.keys()) + s._doc_type_map.update(kwargs) + return s + + def using(self, client): + """ + Associate the search request with an elasticsearch client. A fresh copy + will be returned with current instance remaining unchanged. + :arg client: an instance of ``elasticsearch.Elasticsearch`` to use or + an alias to look up in ``elasticsearch_dsl.connections`` + + """ + s = self._clone() + s._using = client + return s + + def extra(self, **kwargs): + """ + Add extra keys to the request body. Mostly here for backwards + compatibility. 
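+
+        A short illustrative sketch (the keys shown are examples only)::
+
+            s = Search()
+            s = s.extra(explain=True)
+            # ``from_`` is translated to ``from`` in the request body
+            s = s.extra(from_=10, size=20)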
+ """ + s = self._clone() + if "from_" in kwargs: + kwargs["from"] = kwargs.pop("from_") + s._extra.update(kwargs) + return s + + def _clone(self): + s = self.__class__( + using=self._using, index=self._index, doc_type=self._doc_type + ) + s._doc_type_map = self._doc_type_map.copy() + s._extra = self._extra.copy() + s._params = self._params.copy() + return s + + +class Search(Request): + query = ProxyDescriptor("query") + post_filter = ProxyDescriptor("post_filter") + + def __init__(self, **kwargs): + """ + Search request to elasticsearch. + + :arg using: `Elasticsearch` instance to use + :arg index: limit the search to index + :arg doc_type: only query this type. + + All the parameters supplied (or omitted) at creation type can be later + overridden by methods (`using`, `index` and `doc_type` respectively). + """ + super(Search, self).__init__(**kwargs) + + self.aggs = AggsProxy(self) + self._sort = [] + self._source = None + self._highlight = {} + self._highlight_opts = {} + self._suggest = {} + self._script_fields = {} + self._response_class = Response + + self._query_proxy = QueryProxy(self, "query") + self._post_filter_proxy = QueryProxy(self, "post_filter") + + def filter(self, *args, **kwargs): + return self.query(Bool(filter=[Q(*args, **kwargs)])) + + def exclude(self, *args, **kwargs): + return self.query(Bool(filter=[~Q(*args, **kwargs)])) -class AsyncSearch(Search): async def __aiter__(self): """ - Asynchronously iterates over the hits. + Iterate over the hits. """ return iter(self.execute()) + def __getitem__(self, n): + """ + Support slicing the `Search` instance for pagination. + + Slicing equates to the from/size parameters. E.g.:: + + s = Search().query(...)[0:25] + + is equivalent to:: + + s = Search().query(...).extra(from_=0, size=25) + + """ + s = self._clone() + + if isinstance(n, slice): + # If negative slicing, abort. + if n.start and n.start < 0 or n.stop and n.stop < 0: + raise ValueError("Search does not support negative slicing.") + # Elasticsearch won't get all results so we default to size: 10 if + # stop not given. + s._extra["from"] = n.start or 0 + s._extra["size"] = max( + 0, n.stop - (n.start or 0) if n.stop is not None else 10 + ) + return s + else: # This is an index lookup, equivalent to slicing by [n:n+1]. + # If negative index, abort. + if n < 0: + raise ValueError("Search does not support negative indexing.") + s._extra["from"] = n + s._extra["size"] = 1 + return s + + @classmethod + def from_dict(cls, d): + """ + Construct a new `Search` instance from a raw dict containing the search + body. Useful when migrating from raw dictionaries. + + Example:: + + s = Search.from_dict({ + "query": { + "bool": { + "must": [...] + } + }, + "aggs": {...} + }) + s = s.filter('term', published=True) + """ + s = cls() + s.update_from_dict(d) + return s + + def _clone(self): + """ + Return a clone of the current search request. Performs a shallow copy + of all the underlying objects. Used internally by most state modifying + APIs. 
+ """ + s = super(Search, self)._clone() + + s._response_class = self._response_class + s._sort = self._sort[:] + s._source = copy.copy(self._source) if self._source is not None else None + s._highlight = self._highlight.copy() + s._highlight_opts = self._highlight_opts.copy() + s._suggest = self._suggest.copy() + s._script_fields = self._script_fields.copy() + for x in ("query", "post_filter"): + getattr(s, x)._proxied = getattr(self, x)._proxied + + # copy top-level bucket definitions + if self.aggs._params.get("aggs"): + s.aggs._params = {"aggs": self.aggs._params["aggs"].copy()} + return s + + def response_class(self, cls): + """ + Override the default wrapper used for the response. + """ + s = self._clone() + s._response_class = cls + return s + + def update_from_dict(self, d): + """ + Apply options from a serialized body to the current instance. Modifies + the object in-place. Used mostly by ``from_dict``. + """ + d = d.copy() + if "query" in d: + self.query._proxied = Q(d.pop("query")) + if "post_filter" in d: + self.post_filter._proxied = Q(d.pop("post_filter")) + + aggs = d.pop("aggs", d.pop("aggregations", {})) + if aggs: + self.aggs._params = { + "aggs": {name: A(value) for (name, value) in iteritems(aggs)} + } + if "sort" in d: + self._sort = d.pop("sort") + if "_source" in d: + self._source = d.pop("_source") + if "highlight" in d: + high = d.pop("highlight").copy() + self._highlight = high.pop("fields") + self._highlight_opts = high + if "suggest" in d: + self._suggest = d.pop("suggest") + if "text" in self._suggest: + text = self._suggest.pop("text") + for s in self._suggest.values(): + s.setdefault("text", text) + if "script_fields" in d: + self._script_fields = d.pop("script_fields") + self._extra.update(d) + return self + + def script_fields(self, **kwargs): + """ + Define script fields to be calculated on hits. See + https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-script-fields.html + for more details. + + Example:: + + s = Search() + s = s.script_fields(times_two="doc['field'].value * 2") + s = s.script_fields( + times_three={ + 'script': { + 'lang': 'painless', + 'source': "doc['field'].value * params.n", + 'params': {'n': 3} + } + } + ) + + """ + s = self._clone() + for name in kwargs: + if isinstance(kwargs[name], string_types): + kwargs[name] = {"script": kwargs[name]} + s._script_fields.update(kwargs) + return s + + def source(self, fields=None, **kwargs): + """ + Selectively control how the _source field is returned. + + :arg fields: wildcard string, array of wildcards, or dictionary of includes and excludes + + If ``fields`` is None, the entire document will be returned for + each hit. If fields is a dictionary with keys of 'includes' and/or + 'excludes' the fields will be either included or excluded appropriately. + + Calling this multiple times with the same named parameter will override the + previous values with the new ones. 
+ + Example:: + + s = Search() + s = s.source(includes=['obj1.*'], excludes=["*.description"]) + + s = Search() + s = s.source(includes=['obj1.*']).source(excludes=["*.description"]) + + """ + s = self._clone() + + if fields and kwargs: + raise ValueError("You cannot specify fields and kwargs at the same time.") + + if fields is not None: + s._source = fields + return s + + if kwargs and not isinstance(s._source, dict): + s._source = {} + + for key, value in kwargs.items(): + if value is None: + try: + del s._source[key] + except KeyError: + pass + else: + s._source[key] = value + + return s + + def sort(self, *keys): + """ + Add sorting information to the search request. If called without + arguments it will remove all sort requirements. Otherwise it will + replace them. Acceptable arguments are:: + + 'some.field' + '-some.other.field' + {'different.field': {'any': 'dict'}} + + so for example:: + + s = Search().sort( + 'category', + '-title', + {"price" : {"order" : "asc", "mode" : "avg"}} + ) + + will sort by ``category``, ``title`` (in descending order) and + ``price`` in ascending order using the ``avg`` mode. + + The API returns a copy of the Search object and can thus be chained. + """ + s = self._clone() + s._sort = [] + for k in keys: + if isinstance(k, string_types) and k.startswith("-"): + if k[1:] == "_score": + raise IllegalOperation("Sorting by `-_score` is not allowed.") + k = {k[1:]: {"order": "desc"}} + s._sort.append(k) + return s + + def highlight_options(self, **kwargs): + """ + Update the global highlighting options used for this request. For + example:: + + s = Search() + s = s.highlight_options(order='score') + """ + s = self._clone() + s._highlight_opts.update(kwargs) + return s + + def highlight(self, *fields, **kwargs): + """ + Request highlighting of some fields. All keyword arguments passed in will be + used as parameters for all the fields in the ``fields`` parameter. Example:: + + Search().highlight('title', 'body', fragment_size=50) + + will produce the equivalent of:: + + { + "highlight": { + "fields": { + "body": {"fragment_size": 50}, + "title": {"fragment_size": 50} + } + } + } + + If you want to have different options for different fields + you can call ``highlight`` twice:: + + Search().highlight('title', fragment_size=50).highlight('body', fragment_size=100) + + which will produce:: + + { + "highlight": { + "fields": { + "body": {"fragment_size": 100}, + "title": {"fragment_size": 50} + } + } + } + + """ + s = self._clone() + for f in fields: + s._highlight[f] = kwargs + return s + + def suggest(self, name, text, **kwargs): + """ + Add a suggestions request to the search. + + :arg name: name of the suggestion + :arg text: text to suggest on + + All keyword arguments will be added to the suggestions body. For example:: + + s = Search() + s = s.suggest('suggestion-1', 'Elasticsearch', term={'field': 'body'}) + """ + s = self._clone() + s._suggest[name] = {"text": text} + s._suggest[name].update(kwargs) + return s + + def to_dict(self, count=False, **kwargs): + """ + Serialize the search into the dictionary that will be sent over as the + request's body. + + :arg count: a flag to specify if we are interested in a body for count - + no aggregations, no pagination bounds etc. + + All additional keyword arguments will be included into the dictionary. 
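+
+        For illustration only; the serialized output shown is a sketch of the
+        expected shape, not captured from a running cluster::
+
+            s = Search().query('match', title='python')[0:5]
+            s.to_dict()
+            # {'query': {'match': {'title': 'python'}}, 'from': 0, 'size': 5}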
+ """ + d = {} + + if self.query: + d["query"] = self.query.to_dict() + + # count request doesn't care for sorting and other things + if not count: + if self.post_filter: + d["post_filter"] = self.post_filter.to_dict() + + if self.aggs.aggs: + d.update(self.aggs.to_dict()) + + if self._sort: + d["sort"] = self._sort + + d.update(self._extra) + + if self._source not in (None, {}): + d["_source"] = self._source + + if self._highlight: + d["highlight"] = {"fields": self._highlight} + d["highlight"].update(self._highlight_opts) + + if self._suggest: + d["suggest"] = self._suggest + + if self._script_fields: + d["script_fields"] = self._script_fields + + d.update(kwargs) + return d + + def count(self): + """ + Return the number of hits matching the query and filters. Note that + only the actual number is returned. + """ + if hasattr(self, "_response") and self._response.hits.total.relation == "eq": + return self._response.hits.total.value + + es = get_connection(self._using) + + d = self.to_dict(count=True) + # TODO: failed shards detection + return es.count(index=self._index, body=d, **self._params)["count"] + async def execute(self, ignore_cache=False): + """ + Execute the search and return an instance of ``Response`` wrapping all + the data. + + :arg ignore_cache: if set to ``True``, consecutive calls will hit + ES, while cached result will be ignored. Defaults to `False` + """ if ignore_cache or not hasattr(self, "_response"): es = get_connection(self._using) - ensure_async_connection(es, "AsyncSearch.execute") + ensure_async_connection(es, "Search.execute") self._response = self._response_class( self, @@ -75,7 +732,7 @@ async def scan(self): """ es = get_connection(self._using) - ensure_async_connection(es, "AsyncSearch.scan") + ensure_async_connection(es, "Search.scan") for hit in await async_scan( es, query=self.to_dict(), index=self._index, **self._params @@ -87,10 +744,85 @@ async def delete(self): delete() executes the query by delegating to delete_by_query() """ es = get_connection(self._using) - ensure_async_connection(es, "AsyncSearch.delete") + ensure_async_connection(es, "Search.delete") return AttrDict( await es.delete_by_query( index=self._index, body=self.to_dict(), **self._params ) ) + + +class MultiSearch(Request): + """ + Combine multiple :class:`~elasticsearch_dsl.Search` objects into a single + request. + """ + + def __init__(self, **kwargs): + super(MultiSearch, self).__init__(**kwargs) + self._searches = [] + + def __getitem__(self, key): + return self._searches[key] + + def __iter__(self): + return iter(self._searches) + + def _clone(self): + ms = super(MultiSearch, self)._clone() + ms._searches = self._searches[:] + return ms + + def add(self, search): + """ + Adds a new :class:`~elasticsearch_dsl.Search` object to the request:: + + ms = MultiSearch(index='my-index') + ms = ms.add(Search(doc_type=Category).filter('term', category='python')) + ms = ms.add(Search(doc_type=Blog)) + """ + ms = self._clone() + ms._searches.append(search) + return ms + + def to_dict(self): + out = [] + for s in self._searches: + meta = {} + if s._index: + meta["index"] = s._index + meta.update(s._params) + + out.append(meta) + out.append(s.to_dict()) + + return out + + async def execute(self, ignore_cache=False, raise_on_error=True): + """ + Execute the multi search request and return a list of search results. 
+ """ + if ignore_cache or not hasattr(self, "_response"): + es = get_connection(self._using) + ensure_async_connection(es, "MultiSearch.execute") + + responses = await es.msearch( + index=self._index, + body=self.to_dict(), + **self.params, + ) + + out = [] + for s, r in zip(self._searches, responses["responses"]): + if r.get("error", False): + if raise_on_error: + raise TransportError("N/A", r["error"]["type"], r["error"]) + r = None + else: + r = Response(s, r) + out.append(r) + + self._response = out + + return self._response diff --git a/elasticsearch_dsl/_async/update_by_query.py b/elasticsearch_dsl/_async/update_by_query.py index 7a035e75a..5f1bf873b 100644 --- a/elasticsearch_dsl/_async/update_by_query.py +++ b/elasticsearch_dsl/_async/update_by_query.py @@ -15,19 +15,146 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch_dsl._async.utils import ensure_async_connection from elasticsearch_dsl.connections import get_connection -from elasticsearch_dsl.update_by_query import UpdateByQuery +from elasticsearch_dsl.query import Bool, Q +from elasticsearch_dsl.response import UpdateByQueryResponse +from elasticsearch_dsl.search import ProxyDescriptor, QueryProxy, Request +from .utils import ensure_async_connection + + +class UpdateByQuery(Request): + + query = ProxyDescriptor("query") + + def __init__(self, **kwargs): + """ + Update by query request to elasticsearch. + + :arg using: `Elasticsearch` instance to use + :arg index: limit the search to index + :arg doc_type: only query this type. + + All the parameters supplied (or omitted) at creation type can be later + overriden by methods (`using`, `index` and `doc_type` respectively). + + """ + super(UpdateByQuery, self).__init__(**kwargs) + self._response_class = UpdateByQueryResponse + self._script = {} + self._query_proxy = QueryProxy(self, "query") + + def filter(self, *args, **kwargs): + return self.query(Bool(filter=[Q(*args, **kwargs)])) + + def exclude(self, *args, **kwargs): + return self.query(Bool(filter=[~Q(*args, **kwargs)])) + + @classmethod + def from_dict(cls, d): + """ + Construct a new `UpdateByQuery` instance from a raw dict containing the search + body. Useful when migrating from raw dictionaries. + + Example:: + + ubq = UpdateByQuery.from_dict({ + "query": { + "bool": { + "must": [...] + } + }, + "script": {...} + }) + ubq = ubq.filter('term', published=True) + """ + u = cls() + u.update_from_dict(d) + return u + + def _clone(self): + """ + Return a clone of the current search request. Performs a shallow copy + of all the underlying objects. Used internally by most state modifying + APIs. + """ + ubq = super(UpdateByQuery, self)._clone() + + ubq._response_class = self._response_class + ubq._script = self._script.copy() + ubq.query._proxied = self.query._proxied + return ubq + + def response_class(self, cls): + """ + Override the default wrapper used for the response. + """ + ubq = self._clone() + ubq._response_class = cls + return ubq + + def update_from_dict(self, d): + """ + Apply options from a serialized body to the current instance. Modifies + the object in-place. Used mostly by ``from_dict``. + """ + d = d.copy() + if "query" in d: + self.query._proxied = Q(d.pop("query")) + if "script" in d: + self._script = d.pop("script") + self._extra.update(d) + return self + + def script(self, **kwargs): + """ + Define update action to take: + https://www.elastic.co/guide/en/elasticsearch/reference/current/modules-scripting-using.html + for more details. 
+ + Note: the API only accepts a single script, so + calling the script multiple times will overwrite. + + Example:: + + ubq = Search() + ubq = ubq.script(source="ctx._source.likes++"") + ubq = ubq.script(source="ctx._source.likes += params.f"", + lang="expression", + params={'f': 3}) + """ + ubq = self._clone() + if ubq._script: + ubq._script = {} + ubq._script.update(kwargs) + return ubq + + def to_dict(self, **kwargs): + """ + Serialize the search into the dictionary that will be sent over as the + request'ubq body. + + All additional keyword arguments will be included into the dictionary. + """ + d = {} + if self.query: + d["query"] = self.query.to_dict() + + if self._script: + d["script"] = self._script + + d.update(self._extra) + + d.update(kwargs) + return d -class AsyncUpdateByQuery(UpdateByQuery): async def execute(self): """ Execute the search and return an instance of ``Response`` wrapping all the data. """ es = get_connection(self._using) - ensure_async_connection(es, "AsyncMultiSearch.execute") + ensure_async_connection(es, "MultiSearch.execute") self._response = self._response_class( self, diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document.py index aa7aba1f4..248c393b8 100644 --- a/elasticsearch_dsl/document.py +++ b/elasticsearch_dsl/document.py @@ -25,26 +25,15 @@ from elasticsearch.exceptions import NotFoundError, RequestError from six import add_metaclass, iteritems -from .connections import get_connection -from .exceptions import IllegalOperation, ValidationException -from .field import Field -from .index import Index -from .mapping import Mapping -from .search import Search -from .utils import DOC_META_FIELDS, META_FIELDS, ObjectBase, merge +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException +from elasticsearch_dsl.field import Field +from elasticsearch_dsl.index import Index +from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.utils import DOC_META_FIELDS, META_FIELDS, ObjectBase, merge -try: - from elasticsearch import AsyncElasticsearch - - from elasticsearch_dsl._async.search import AsyncSearch -except ImportError: - # Async is not support for one of two reasons: - # - # 1. The Python version is less than 3.6, so elasticsearch-py doesn't expose - # it's async features. - # 2. The aiohttp package isn't installed, so elasticsearch-py doesn't expose - # it's async features. - pass +from .utils import ensure_sync_connection class MetaField(object): @@ -195,19 +184,9 @@ def search(cls, using=None, index=None): Create an :class:`~elasticsearch_dsl.Search` instance that will search over this ``Document``. """ - es = cls._get_using(using) - kwargs = { - "doc_type": [cls], - "index": cls._default_index(index), - } - - try: - if isinstance(es, AsyncElasticsearch): - return AsyncSearch(using=es, **kwargs) - except NameError: - pass - - return Search(using=es, **kwargs) + return Search( + using=cls._get_using(using), index=cls._default_index(index), doc_type=[cls] + ) @classmethod def get(cls, id, using=None, index=None, **kwargs): @@ -223,6 +202,8 @@ def get(cls, id, using=None, index=None, **kwargs): ``Elasticsearch.get`` unchanged. 
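+
+        Illustrative only; ``Post`` is a hypothetical ``Document`` subclass,
+        not something introduced by this change::
+
+            post = Post.get(id=42)
+            if post is None:
+                pass  # no document with that id was found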
""" es = cls._get_connection(using) + ensure_sync_connection(es, "Document.get") + doc = es.get(index=cls._default_index(index), id=id, **kwargs) if not doc.get("found", False): return None @@ -251,18 +232,50 @@ def mget( """ if missing not in ("raise", "skip", "none"): raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") - + es = cls._get_connection(using) + ensure_sync_connection(es, "Document.mget") - results = es.mget( - cls._build_mget_body(docs), index=cls._default_index(index), **kwargs - ) + body = { + "docs": [ + doc if isinstance(doc, collections_abc.Mapping) else {"_id": doc} + for doc in docs + ] + } + results = es.mget(body, index=cls._default_index(index), **kwargs) - return cls._parse_mget_results( - results, - missing=missing, - raise_on_error=raise_on_error, - ) + objs, error_docs, missing_docs = [], [], [] + for doc in results["docs"]: + if doc.get("found"): + if error_docs or missing_docs: + # We're going to raise an exception anyway, so avoid an + # expensive call to cls.from_es(). + continue + + objs.append(cls.from_es(doc)) + + elif doc.get("error"): + if raise_on_error: + error_docs.append(doc) + if missing == "none": + objs.append(None) + + # The doc didn't cause an error, but the doc also wasn't found. + elif missing == "raise": + missing_docs.append(doc) + elif missing == "none": + objs.append(None) + + if error_docs: + error_ids = [doc["_id"] for doc in error_docs] + message = "Required routing not provided for documents %s." + message %= ", ".join(error_ids) + raise RequestError(400, message, error_docs) + if missing_docs: + missing_ids = [doc["_id"] for doc in missing_docs] + message = "Documents %s not found." % ", ".join(missing_ids) + raise NotFoundError(404, message, {"docs": missing_docs}) + return objs def delete(self, using=None, index=None, **kwargs): """ @@ -276,7 +289,16 @@ def delete(self, using=None, index=None, **kwargs): ``Elasticsearch.delete`` unchanged. """ es = self._get_connection(using) - doc_meta = self._build_delete_doc_meta(**kwargs) + ensure_sync_connection(es, "Document.delete") + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} + + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) es.delete(index=self._get_index(index), **doc_meta) def to_dict(self, include_meta=False, skip_empty=True): @@ -345,94 +367,6 @@ def update( :return operation result noop/updated """ - body, doc_meta = self._build_update_body_and_meta( - detect_noop=detect_noop, - doc_as_upsert=doc_as_upsert, - retry_on_conflict=retry_on_conflict, - script=script, - script_id=script_id, - scripted_upsert=scripted_upsert, - upsert=upsert, - **fields - ) - - meta = self._get_connection(using).update( - index=self._get_index(index), body=body, refresh=refresh, **doc_meta - ) - self._update_doc_meta(meta) - - return meta["result"] - - def save(self, using=None, index=None, validate=True, skip_empty=True, **kwargs): - """ - Save the document into elasticsearch. If the document doesn't exist it - is created, it is overwritten otherwise. Returns ``True`` if this - operations resulted in new document being created. - - :arg index: elasticsearch index to use, if the ``Document`` is - associated with an index this can be omitted. 
- :arg using: connection alias to use, defaults to ``'default'`` - :arg validate: set to ``False`` to skip validating the document - :arg skip_empty: if set to ``False`` will cause empty values (``None``, - ``[]``, ``{}``) to be left on the document. Those values will be - stripped out otherwise as they make no difference in elasticsearch. - - Any additional keyword arguments will be passed to - ``Elasticsearch.index`` unchanged. - - :return operation result created/updated - """ - if validate: - self.full_clean() - - es = self._get_connection(using) - doc_meta = self._build_save_doc_meta(**kwargs) - meta = es.index( - index=self._get_index(index), - body=self.to_dict(skip_empty=skip_empty), - **doc_meta - ) - self._update_doc_meta(meta) - - return meta["result"] - - def _build_delete_doc_meta(self, **kwargs): - # extract routing etc from meta - doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - - # Optimistic concurrency control - if "seq_no" in self.meta and "primary_term" in self.meta: - doc_meta["if_seq_no"] = self.meta["seq_no"] - doc_meta["if_primary_term"] = self.meta["primary_term"] - - doc_meta.update(kwargs) - - return doc_meta - - def _build_save_doc_meta(self, **kwargs): - # extract routing etc from meta - doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - - # Optimistic concurrency control - if "seq_no" in self.meta and "primary_term" in self.meta: - doc_meta["if_seq_no"] = self.meta["seq_no"] - doc_meta["if_primary_term"] = self.meta["primary_term"] - - doc_meta.update(kwargs) - - return doc_meta - - def _build_update_body_and_meta( - self, - detect_noop=True, - doc_as_upsert=False, - retry_on_conflict=None, - script=None, - script_id=None, - scripted_upsert=False, - upsert=None, - **fields - ): body = { "doc_as_upsert": doc_as_upsert, "detect_noop": detect_noop, @@ -481,55 +415,61 @@ def _build_update_body_and_meta( doc_meta["if_seq_no"] = self.meta["seq_no"] doc_meta["if_primary_term"] = self.meta["primary_term"] - return body, doc_meta + es = self._get_connection(using) + ensure_sync_connection(es, "Document.update") - @classmethod - def _build_mget_body(cls, docs): - return { - "docs": [ - doc if isinstance(doc, collections_abc.Mapping) else {"_id": doc} - for doc in docs - ] - } + meta = es.update( + index=self._get_index(index), body=body, refresh=refresh, **doc_meta + ) + # update meta information from ES + for k in META_FIELDS: + if "_" + k in meta: + setattr(self.meta, k, meta["_" + k]) - @classmethod - def _parse_mget_results(cls, results, missing="none", raise_on_error=True): - objs, error_docs, missing_docs = [], [], [] - for doc in results["docs"]: - if doc.get("found"): - if error_docs or missing_docs: - # We're going to raise an exception anyway, so avoid an - # expensive call to cls.from_es(). - continue + return meta["result"] - objs.append(cls.from_es(doc)) + def save(self, using=None, index=None, validate=True, skip_empty=True, **kwargs): + """ + Save the document into elasticsearch. If the document doesn't exist it + is created, it is overwritten otherwise. Returns ``True`` if this + operations resulted in new document being created. - elif doc.get("error"): - if raise_on_error: - error_docs.append(doc) - if missing == "none": - objs.append(None) + :arg index: elasticsearch index to use, if the ``Document`` is + associated with an index this can be omitted. 
+ :arg using: connection alias to use, defaults to ``'default'`` + :arg validate: set to ``False`` to skip validating the document + :arg skip_empty: if set to ``False`` will cause empty values (``None``, + ``[]``, ``{}``) to be left on the document. Those values will be + stripped out otherwise as they make no difference in elasticsearch. - # The doc didn't cause an error, but the doc also wasn't found. - elif missing == "raise": - missing_docs.append(doc) - elif missing == "none": - objs.append(None) + Any additional keyword arguments will be passed to + ``Elasticsearch.index`` unchanged. - if error_docs: - error_ids = [doc["_id"] for doc in error_docs] - message = "Required routing not provided for documents %s." - message %= ", ".join(error_ids) - raise RequestError(400, message, error_docs) - if missing_docs: - missing_ids = [doc["_id"] for doc in missing_docs] - message = "Documents %s not found." % ", ".join(missing_ids) - raise NotFoundError(404, message, {"docs": missing_docs}) + :return operation result created/updated + """ + if validate: + self.full_clean() - return objs + es = self._get_connection(using) + ensure_sync_connection(es, "Document.save") + + # extract routing etc from meta + doc_meta = {k: self.meta[k] for k in DOC_META_FIELDS if k in self.meta} - def _update_doc_meta(self, meta): + # Optimistic concurrency control + if "seq_no" in self.meta and "primary_term" in self.meta: + doc_meta["if_seq_no"] = self.meta["seq_no"] + doc_meta["if_primary_term"] = self.meta["primary_term"] + + doc_meta.update(kwargs) + meta = es.index( + index=self._get_index(index), + body=self.to_dict(skip_empty=skip_empty), + **doc_meta + ) # update meta information from ES for k in META_FIELDS: if "_" + k in meta: setattr(self.meta, k, meta["_" + k]) + + return meta["result"] diff --git a/elasticsearch_dsl/faceted_search.py b/elasticsearch_dsl/faceted_search.py index 9c653a85c..39d1b00ec 100644 --- a/elasticsearch_dsl/faceted_search.py +++ b/elasticsearch_dsl/faceted_search.py @@ -19,11 +19,11 @@ from six import iteritems, itervalues -from .aggs import A -from .query import MatchAll, Nested, Range, Terms -from .response import Response -from .search import Search -from .utils import AttrDict +from elasticsearch_dsl.aggs import A +from elasticsearch_dsl.query import MatchAll, Nested, Range, Terms +from elasticsearch_dsl.response import Response +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.utils import AttrDict __all__ = [ "FacetedSearch", diff --git a/elasticsearch_dsl/index.py b/elasticsearch_dsl/index.py index 17dd93f45..0becbcea4 100644 --- a/elasticsearch_dsl/index.py +++ b/elasticsearch_dsl/index.py @@ -15,13 +15,15 @@ # specific language governing permissions and limitations # under the License. -from . 
import analysis -from .connections import get_connection -from .exceptions import IllegalOperation -from .mapping import Mapping -from .search import Search -from .update_by_query import UpdateByQuery -from .utils import merge +from elasticsearch_dsl import analysis +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation +from elasticsearch_dsl.mapping import Mapping +from elasticsearch_dsl.search import Search +from elasticsearch_dsl.update_by_query import UpdateByQuery +from elasticsearch_dsl.utils import merge + +from .utils import ensure_sync_connection class IndexTemplate(object): @@ -50,9 +52,12 @@ def to_dict(self): return d def save(self, using=None): - es = get_connection(using or self._index._using) - return es.indices.put_template(name=self._template_name, body=self.to_dict()) + ensure_sync_connection(es, "IndexTemplate.save") + + return es.indices.put_template( + name=self._template_name, body=self.to_dict() + ) class Index(object): @@ -101,9 +106,9 @@ def resolve_field(self, field_path): return None def load_mappings(self, using=None): - self.get_or_create_mapping().update_from_es( - self._name, using=using or self._using - ) + mapping = self.get_or_create_mapping() + + mapping.update_from_es(self._name, using=using or self._using) def clone(self, name=None, using=None): """ @@ -276,14 +281,24 @@ def create(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.create`` unchanged. """ - return self._get_connection(using).indices.create( - index=self._name, body=self.to_dict(), **kwargs + es = get_connection(using) + ensure_sync_connection(es, "Index.create") + + return es.indices.create( + index=self._name, + body=self.to_dict(), + **kwargs, ) def is_closed(self, using=None): - state = self._get_connection(using).cluster.state( - index=self._name, metric="metadata" + es = get_connection(using) + ensure_sync_connection(es, "Index.is_closed") + + state = es.cluster.state( + index=self._name, + metric="metadata", ) + return state["metadata"]["indices"][self._name]["state"] == "close" def save(self, using=None): @@ -348,7 +363,10 @@ def analyze(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.analyze`` unchanged. """ - return self._get_connection(using).indices.analyze(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.analyze") + + return es.indices.analyze(index=self._index, **kwargs) def refresh(self, using=None, **kwargs): """ @@ -357,7 +375,10 @@ def refresh(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.refresh`` unchanged. """ - return self._get_connection(using).indices.refresh(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.refresh") + + return es.indices.refresh(index=self._index, **kwargs) def flush(self, using=None, **kwargs): """ @@ -366,7 +387,10 @@ def flush(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.flush`` unchanged. 
""" - return self._get_connection(using).indices.flush(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.flush") + + return es.indices.flush(index=self._index, **kwargs) def get(self, using=None, **kwargs): """ @@ -375,7 +399,10 @@ def get(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get`` unchanged. """ - return self._get_connection(using).indices.get(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get") + + return es.indices.get(index=self._index, **kwargs) def open(self, using=None, **kwargs): """ @@ -384,7 +411,10 @@ def open(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.open`` unchanged. """ - return self._get_connection(using).indices.open(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.open") + + return es.indices.open(index=self._index, **kwargs) def close(self, using=None, **kwargs): """ @@ -393,7 +423,10 @@ def close(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.close`` unchanged. """ - return self._get_connection(using).indices.close(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.close") + + return es.indices.close(index=self._index, **kwargs) def delete(self, using=None, **kwargs): """ @@ -402,7 +435,10 @@ def delete(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.delete`` unchanged. """ - return self._get_connection(using).indices.delete(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.delete") + + return es.indices.delete(index=self._index, **kwargs) def exists(self, using=None, **kwargs): """ @@ -411,7 +447,10 @@ def exists(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.exists`` unchanged. """ - return self._get_connection(using).indices.exists(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.exists") + + return es.indices.exists(index=self._index, **kwargs) def exists_type(self, using=None, **kwargs): """ @@ -420,9 +459,10 @@ def exists_type(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.exists_type`` unchanged. """ - return self._get_connection(using).indices.exists_type( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.exists_type") + + return es.indices.exists_type(index=self._index, **kwargs) def put_mapping(self, using=None, **kwargs): """ @@ -431,9 +471,10 @@ def put_mapping(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.put_mapping`` unchanged. """ - return self._get_connection(using).indices.put_mapping( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.put_mapping") + + return es.indices.put_mapping(index=self._index, **kwargs) def get_mapping(self, using=None, **kwargs): """ @@ -442,9 +483,10 @@ def get_mapping(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_mapping`` unchanged. 
""" - return self._get_connection(using).indices.get_mapping( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_mapping") + + return es.indices.get_mapping(index=self._index, **kwargs) def get_field_mapping(self, using=None, **kwargs): """ @@ -453,9 +495,10 @@ def get_field_mapping(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_field_mapping`` unchanged. """ - return self._get_connection(using).indices.get_field_mapping( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_field_mapping") + + return es.indices.get_field_mapping(index=self._index, **kwargs) def put_alias(self, using=None, **kwargs): """ @@ -464,7 +507,10 @@ def put_alias(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.put_alias`` unchanged. """ - return self._get_connection(using).indices.put_alias(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.put_alias") + + return es.indices.put_alias(index=self._index, **kwargs) def exists_alias(self, using=None, **kwargs): """ @@ -484,7 +530,10 @@ def get_alias(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_alias`` unchanged. """ - return self._get_connection(using).indices.get_alias(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_alias") + + return es.indices.get_alias(index=self._index, **kwargs) def delete_alias(self, using=None, **kwargs): """ @@ -493,9 +542,10 @@ def delete_alias(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.delete_alias`` unchanged. """ - return self._get_connection(using).indices.delete_alias( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.delete_alias") + + return es.indices.delete_alias(index=self._index, **kwargs) def get_settings(self, using=None, **kwargs): """ @@ -504,9 +554,10 @@ def get_settings(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_settings`` unchanged. """ - return self._get_connection(using).indices.get_settings( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_settings") + + return es.indices.get_settings(index=self._index, **kwargs) def put_settings(self, using=None, **kwargs): """ @@ -515,9 +566,10 @@ def put_settings(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.put_settings`` unchanged. """ - return self._get_connection(using).indices.put_settings( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.put_settings") + + return es.indices.put_settings(index=self._index, **kwargs) def stats(self, using=None, **kwargs): """ @@ -526,7 +578,10 @@ def stats(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.stats`` unchanged. 
""" - return self._get_connection(using).indices.stats(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.stats") + + return es.indices.stats(index=self._index, **kwargs) def segments(self, using=None, **kwargs): """ @@ -536,7 +591,10 @@ def segments(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.segments`` unchanged. """ - return self._get_connection(using).indices.segments(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.segments") + + return es.indices.segments(index=self._index, **kwargs) def validate_query(self, using=None, **kwargs): """ @@ -545,9 +603,10 @@ def validate_query(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.validate_query`` unchanged. """ - return self._get_connection(using).indices.validate_query( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.validate_query") + + return es.indices.validate_query(index=self._index, **kwargs) def clear_cache(self, using=None, **kwargs): """ @@ -556,9 +615,10 @@ def clear_cache(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.clear_cache`` unchanged. """ - return self._get_connection(using).indices.clear_cache( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.clear_cache") + + return es.indices.clear_cache(index=self._index, **kwargs) def recovery(self, using=None, **kwargs): """ @@ -568,7 +628,10 @@ def recovery(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.recovery`` unchanged. """ - return self._get_connection(using).indices.recovery(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.recovery") + + return es.indices.recovery(index=self._index, **kwargs) def upgrade(self, using=None, **kwargs): """ @@ -577,7 +640,10 @@ def upgrade(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.upgrade`` unchanged. """ - return self._get_connection(using).indices.upgrade(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.upgrade") + + return es.indices.upgrade(index=self._index, **kwargs) def get_upgrade(self, using=None, **kwargs): """ @@ -586,9 +652,10 @@ def get_upgrade(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.get_upgrade`` unchanged. """ - return self._get_connection(using).indices.get_upgrade( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.get_upgrade") + + return es.indices.get_upgrade(index=self._index, **kwargs) def flush_synced(self, using=None, **kwargs): """ @@ -598,9 +665,10 @@ def flush_synced(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.flush_synced`` unchanged. 
""" - return self._get_connection(using).indices.flush_synced( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.flush_synced") + + return es.indices.flush_synced(index=self._index, **kwargs) def shard_stores(self, using=None, **kwargs): """ @@ -612,9 +680,10 @@ def shard_stores(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.shard_stores`` unchanged. """ - return self._get_connection(using).indices.shard_stores( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.shard_stores") + + return es.indices.shard_stores(index=self._index, **kwargs) def forcemerge(self, using=None, **kwargs): """ @@ -630,9 +699,10 @@ def forcemerge(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.forcemerge`` unchanged. """ - return self._get_connection(using).indices.forcemerge( - index=self._name, **kwargs - ) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.forcemerge") + + return es.indices.forcemerge(index=self._index, **kwargs) def shrink(self, using=None, **kwargs): """ @@ -649,4 +719,7 @@ def shrink(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.shrink`` unchanged. """ - return self._get_connection(using).indices.shrink(index=self._name, **kwargs) + es = self._get_connection(using) + ensure_sync_connection(es, "Index.shrink") + + return es.indices.shrink(index=self._index, **kwargs) diff --git a/elasticsearch_dsl/mapping.py b/elasticsearch_dsl/mapping.py index 6d1bc8bfd..da8f7d710 100644 --- a/elasticsearch_dsl/mapping.py +++ b/elasticsearch_dsl/mapping.py @@ -24,9 +24,11 @@ from six import iteritems, itervalues -from .connections import get_connection -from .field import Nested, Text, construct_field -from .utils import DslBase +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.field import Nested, Text, construct_field +from elasticsearch_dsl.utils import DslBase + +from .utils import ensure_sync_connection META_FIELDS = frozenset( ( @@ -108,6 +110,7 @@ def _clone(self): def from_es(cls, index, using="default"): m = cls() m.update_from_es(index, using) + return m def resolve_nested(self, field_path): @@ -169,6 +172,8 @@ def save(self, index, using="default"): def update_from_es(self, index, using="default"): es = get_connection(using) + ensure_sync_connection(es, "Mapping.update_from_es") + raw = es.indices.get_mapping(index=index) _, raw = raw.popitem() self._update_from_dict(raw["mappings"]) diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 7591a1584..29e1f31fb 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -26,12 +26,14 @@ from elasticsearch.helpers import scan from six import iteritems, string_types -from .aggs import A, AggBase -from .connections import get_connection -from .exceptions import IllegalOperation -from .query import Bool, Q -from .response import Hit, Response -from .utils import AttrDict, DslBase +from elasticsearch_dsl.aggs import A, AggBase +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.exceptions import IllegalOperation +from elasticsearch_dsl.query import Bool, Q +from elasticsearch_dsl.response import Hit, Response +from elasticsearch_dsl.utils import AttrDict, DslBase + +from .utils import ensure_sync_connection class QueryProxy(object): @@ -710,10 +712,13 @@ def 
execute(self, ignore_cache=False): """ if ignore_cache or not hasattr(self, "_response"): es = get_connection(self._using) + ensure_sync_connection(es, "Search.execute") self._response = self._response_class( - self, es.search(index=self._index, body=self.to_dict(), **self._params) + self, + es.search(index=self._index, body=self.to_dict(), **self._params), ) + return self._response def scan(self): @@ -727,19 +732,24 @@ def scan(self): """ es = get_connection(self._using) + ensure_sync_connection(es, "Search.scan") - for hit in scan(es, query=self.to_dict(), index=self._index, **self._params): + for hit in scan( + es, query=self.to_dict(), index=self._index, **self._params + ): yield self._get_result(hit) def delete(self): """ delete() executes the query by delegating to delete_by_query() """ - es = get_connection(self._using) + ensure_sync_connection(es, "Search.delete") return AttrDict( - es.delete_by_query(index=self._index, body=self.to_dict(), **self._params) + es.delete_by_query( + index=self._index, body=self.to_dict(), **self._params + ) ) @@ -795,28 +805,24 @@ def execute(self, ignore_cache=False, raise_on_error=True): """ if ignore_cache or not hasattr(self, "_response"): es = get_connection(self._using) + ensure_sync_connection(es, "MultiSearch.execute") responses = es.msearch( - index=self._index, body=self.to_dict(), **self._params + index=self._index, + body=self.to_dict(), + **self.params, ) - self._response = self._process_responses( - responses, raise_on_error=raise_on_error - ) + out = [] + for s, r in zip(self._searches, responses["responses"]): + if r.get("error", False): + if raise_on_error: + raise TransportError("N/A", r["error"]["type"], r["error"]) + r = None + else: + r = Response(s, r) + out.append(r) - return self._response + self._response = out - def _process_responses(self, responses, raise_on_error=True): - out = [] - - for s, r in zip(self._searches, responses["responses"]): - if r.get("error", False): - if raise_on_error: - raise TransportError("N/A", r["error"]["type"], r["error"]) - r = None - else: - r = Response(s, r) - - out.append(r) - - return out + return self._response diff --git a/elasticsearch_dsl/update_by_query.py b/elasticsearch_dsl/update_by_query.py index 1d257b92f..3c5a4f943 100644 --- a/elasticsearch_dsl/update_by_query.py +++ b/elasticsearch_dsl/update_by_query.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. -from .connections import get_connection -from .query import Bool, Q -from .response import UpdateByQueryResponse -from .search import ProxyDescriptor, QueryProxy, Request +from elasticsearch_dsl.connections import get_connection +from elasticsearch_dsl.query import Bool, Q +from elasticsearch_dsl.response import UpdateByQueryResponse +from elasticsearch_dsl.search import ProxyDescriptor, QueryProxy, Request + +from .utils import ensure_sync_connection class UpdateByQuery(Request): @@ -152,9 +154,12 @@ def execute(self): the data. 
""" es = get_connection(self._using) + ensure_sync_connection(es, "SyncMultiSearch.execute") self._response = self._response_class( self, - es.update_by_query(index=self._index, body=self.to_dict(), **self._params), + es.update_by_query( + index=self._index, body=self.to_dict(), **self._params + ), ) return self._response diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 50849773a..43f33d597 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -24,6 +24,7 @@ from copy import copy +from elasticsearch import Elasticsearch from six import add_metaclass, iteritems from six.moves import map @@ -544,6 +545,14 @@ def full_clean(self): self.clean() +def ensure_sync_connection(es, fn_label): + if not isinstance(es, Elasticsearch): + raise TypeError( + f"{fn_label} can only be used with the elasticsearch.Elasticsearch " + "client" + ) + + def merge(data, new_data, raise_on_conflict=False): if not ( isinstance(data, (AttrDict, collections_abc.Mapping)) diff --git a/setup.py b/setup.py index 92fc7e73e..797ebbf4e 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,8 @@ tests_require=tests_require + async_requires, extras_require={ "async": async_requires, - "develop": tests_require + async_requires + ["sphinx", "sphinx_rtd_theme"], + "develop": ( + tests_require + async_requires + ["sphinx", "sphinx_rtd_theme", "unasync"] + ), }, ) diff --git a/utils/generate-sync.py b/utils/generate-sync.py new file mode 100644 index 000000000..ff40b5efe --- /dev/null +++ b/utils/generate-sync.py @@ -0,0 +1,42 @@ +import os +from pathlib import Path + +import unasync + +CODE_ROOT = Path(__file__).absolute().parent.parent + + +def generate_sync(): + additional_replacements = { + "_async": "", + "async_scan": "scan", + "ensure_async_connection": "ensure_sync_connection", + } + + rules = [ + unasync.Rule( + fromdir="/_async/", + todir="/", + additional_replacements=additional_replacements, + ), + ] + + filepaths = [] + for root, _, filenames in os.walk(CODE_ROOT / "elasticsearch_dsl/_async"): + for filename in filenames: + if ( + filename.rpartition(".")[-1] + in ( + "py", + "pyi", + ) + and not filename.startswith("__init__.py") + and not filename.startswith("utils.py") + ): + filepaths.append(os.path.join(root, filename)) + + unasync.unasync_files(filepaths, rules) + + +if __name__ == '__main__': + generate_sync() From 9db58b37ab7a64d79c9ddddfe6ff2e1c27bf97d5 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Fri, 23 Oct 2020 10:41:00 -0600 Subject: [PATCH 10/11] Fix broken build --- elasticsearch_dsl/_async/__init__.py | 16 +++++++ elasticsearch_dsl/_async/document.py | 6 ++- elasticsearch_dsl/_async/index.py | 62 ++++++++++++-------------- elasticsearch_dsl/_async/search.py | 4 +- elasticsearch_dsl/document.py | 2 +- elasticsearch_dsl/index.py | 66 +++++++++++++--------------- elasticsearch_dsl/search.py | 12 ++--- elasticsearch_dsl/update_by_query.py | 6 +-- elasticsearch_dsl/utils.py | 9 +++- utils/generate-sync.py | 28 +++++++++++- 10 files changed, 120 insertions(+), 91 deletions(-) diff --git a/elasticsearch_dsl/_async/__init__.py b/elasticsearch_dsl/_async/__init__.py index e69de29bb..2a87d183f 100644 --- a/elasticsearch_dsl/_async/__init__.py +++ b/elasticsearch_dsl/_async/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/elasticsearch_dsl/_async/document.py b/elasticsearch_dsl/_async/document.py index 77cb0d68a..98dc56893 100644 --- a/elasticsearch_dsl/_async/document.py +++ b/elasticsearch_dsl/_async/document.py @@ -232,7 +232,7 @@ async def mget( """ if missing not in ("raise", "skip", "none"): raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") - + es = cls._get_connection(using) ensure_async_connection(es, "Document.mget") @@ -428,7 +428,9 @@ async def update( return meta["result"] - async def save(self, using=None, index=None, validate=True, skip_empty=True, **kwargs): + async def save( + self, using=None, index=None, validate=True, skip_empty=True, **kwargs + ): """ Save the document into elasticsearch. If the document doesn't exist it is created, it is overwritten otherwise. Returns ``True`` if this diff --git a/elasticsearch_dsl/_async/index.py b/elasticsearch_dsl/_async/index.py index 570dcdc6b..df5c93c9e 100644 --- a/elasticsearch_dsl/_async/index.py +++ b/elasticsearch_dsl/_async/index.py @@ -284,11 +284,7 @@ async def create(self, using=None, **kwargs): es = get_connection(using) ensure_async_connection(es, "Index.create") - return await es.indices.create( - index=self._name, - body=self.to_dict(), - **kwargs, - ) + return await es.indices.create(index=self._name, body=self.to_dict(), **kwargs) async def is_closed(self, using=None): es = get_connection(using) @@ -366,7 +362,7 @@ async def analyze(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.analyze") - return await es.indices.analyze(index=self._index, **kwargs) + return await es.indices.analyze(index=self._name, **kwargs) async def refresh(self, using=None, **kwargs): """ @@ -378,7 +374,7 @@ async def refresh(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.refresh") - return await es.indices.refresh(index=self._index, **kwargs) + return await es.indices.refresh(index=self._name, **kwargs) async def flush(self, using=None, **kwargs): """ @@ -390,7 +386,7 @@ async def flush(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.flush") - return await es.indices.flush(index=self._index, **kwargs) + return await es.indices.flush(index=self._name, **kwargs) async def get(self, using=None, **kwargs): """ @@ -402,7 +398,7 @@ async def get(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.get") - return await es.indices.get(index=self._index, **kwargs) + return await es.indices.get(index=self._name, **kwargs) async def open(self, using=None, **kwargs): """ @@ -414,7 +410,7 @@ async def open(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.open") - return await es.indices.open(index=self._index, **kwargs) + return await es.indices.open(index=self._name, **kwargs) async def close(self, using=None, **kwargs): """ @@ -426,7 +422,7 @@ 
async def close(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.close") - return await es.indices.close(index=self._index, **kwargs) + return await es.indices.close(index=self._name, **kwargs) async def delete(self, using=None, **kwargs): """ @@ -438,7 +434,7 @@ async def delete(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.delete") - return await es.indices.delete(index=self._index, **kwargs) + return await es.indices.delete(index=self._name, **kwargs) async def exists(self, using=None, **kwargs): """ @@ -450,7 +446,7 @@ async def exists(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.exists") - return await es.indices.exists(index=self._index, **kwargs) + return await es.indices.exists(index=self._name, **kwargs) async def exists_type(self, using=None, **kwargs): """ @@ -462,7 +458,7 @@ async def exists_type(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.exists_type") - return await es.indices.exists_type(index=self._index, **kwargs) + return await es.indices.exists_type(index=self._name, **kwargs) async def put_mapping(self, using=None, **kwargs): """ @@ -474,7 +470,7 @@ async def put_mapping(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.put_mapping") - return await es.indices.put_mapping(index=self._index, **kwargs) + return await es.indices.put_mapping(index=self._name, **kwargs) async def get_mapping(self, using=None, **kwargs): """ @@ -486,7 +482,7 @@ async def get_mapping(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.get_mapping") - return await es.indices.get_mapping(index=self._index, **kwargs) + return await es.indices.get_mapping(index=self._name, **kwargs) async def get_field_mapping(self, using=None, **kwargs): """ @@ -498,7 +494,7 @@ async def get_field_mapping(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.get_field_mapping") - return await es.indices.get_field_mapping(index=self._index, **kwargs) + return await es.indices.get_field_mapping(index=self._name, **kwargs) async def put_alias(self, using=None, **kwargs): """ @@ -510,7 +506,7 @@ async def put_alias(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.put_alias") - return await es.indices.put_alias(index=self._index, **kwargs) + return await es.indices.put_alias(index=self._name, **kwargs) def exists_alias(self, using=None, **kwargs): """ @@ -533,7 +529,7 @@ async def get_alias(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.get_alias") - return await es.indices.get_alias(index=self._index, **kwargs) + return await es.indices.get_alias(index=self._name, **kwargs) async def delete_alias(self, using=None, **kwargs): """ @@ -545,7 +541,7 @@ async def delete_alias(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.delete_alias") - return await es.indices.delete_alias(index=self._index, **kwargs) + return await es.indices.delete_alias(index=self._name, **kwargs) async def get_settings(self, using=None, **kwargs): """ @@ -557,7 +553,7 @@ async def get_settings(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.get_settings") - return await es.indices.get_settings(index=self._index, **kwargs) + 
return await es.indices.get_settings(index=self._name, **kwargs) async def put_settings(self, using=None, **kwargs): """ @@ -569,7 +565,7 @@ async def put_settings(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.put_settings") - return await es.indices.put_settings(index=self._index, **kwargs) + return await es.indices.put_settings(index=self._name, **kwargs) async def stats(self, using=None, **kwargs): """ @@ -581,7 +577,7 @@ async def stats(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.stats") - return await es.indices.stats(index=self._index, **kwargs) + return await es.indices.stats(index=self._name, **kwargs) async def segments(self, using=None, **kwargs): """ @@ -594,7 +590,7 @@ async def segments(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.segments") - return await es.indices.segments(index=self._index, **kwargs) + return await es.indices.segments(index=self._name, **kwargs) async def validate_query(self, using=None, **kwargs): """ @@ -606,7 +602,7 @@ async def validate_query(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.validate_query") - return await es.indices.validate_query(index=self._index, **kwargs) + return await es.indices.validate_query(index=self._name, **kwargs) async def clear_cache(self, using=None, **kwargs): """ @@ -618,7 +614,7 @@ async def clear_cache(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.clear_cache") - return await es.indices.clear_cache(index=self._index, **kwargs) + return await es.indices.clear_cache(index=self._name, **kwargs) async def recovery(self, using=None, **kwargs): """ @@ -631,7 +627,7 @@ async def recovery(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.recovery") - return await es.indices.recovery(index=self._index, **kwargs) + return await es.indices.recovery(index=self._name, **kwargs) async def upgrade(self, using=None, **kwargs): """ @@ -643,7 +639,7 @@ async def upgrade(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.upgrade") - return await es.indices.upgrade(index=self._index, **kwargs) + return await es.indices.upgrade(index=self._name, **kwargs) async def get_upgrade(self, using=None, **kwargs): """ @@ -655,7 +651,7 @@ async def get_upgrade(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.get_upgrade") - return await es.indices.get_upgrade(index=self._index, **kwargs) + return await es.indices.get_upgrade(index=self._name, **kwargs) async def flush_synced(self, using=None, **kwargs): """ @@ -668,7 +664,7 @@ async def flush_synced(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.flush_synced") - return await es.indices.flush_synced(index=self._index, **kwargs) + return await es.indices.flush_synced(index=self._name, **kwargs) async def shard_stores(self, using=None, **kwargs): """ @@ -683,7 +679,7 @@ async def shard_stores(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.shard_stores") - return await es.indices.shard_stores(index=self._index, **kwargs) + return await es.indices.shard_stores(index=self._name, **kwargs) async def forcemerge(self, using=None, **kwargs): """ @@ -702,7 +698,7 @@ async def forcemerge(self, using=None, **kwargs): es = 
self._get_connection(using) ensure_async_connection(es, "Index.forcemerge") - return await es.indices.forcemerge(index=self._index, **kwargs) + return await es.indices.forcemerge(index=self._name, **kwargs) async def shrink(self, using=None, **kwargs): """ @@ -722,4 +718,4 @@ async def shrink(self, using=None, **kwargs): es = self._get_connection(using) ensure_async_connection(es, "Index.shrink") - return await es.indices.shrink(index=self._index, **kwargs) + return await es.indices.shrink(index=self._name, **kwargs) diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py index e1a1136a2..ca7073973 100644 --- a/elasticsearch_dsl/_async/search.py +++ b/elasticsearch_dsl/_async/search.py @@ -808,9 +808,7 @@ async def execute(self, ignore_cache=False, raise_on_error=True): ensure_async_connection(es, "MultiSearch.execute") responses = await es.msearch( - index=self._index, - body=self.to_dict(), - **self.params, + index=self._index, body=self.to_dict(), **self.params ) out = [] diff --git a/elasticsearch_dsl/document.py b/elasticsearch_dsl/document.py index 248c393b8..6c250e9bf 100644 --- a/elasticsearch_dsl/document.py +++ b/elasticsearch_dsl/document.py @@ -232,7 +232,7 @@ def mget( """ if missing not in ("raise", "skip", "none"): raise ValueError("'missing' must be 'raise', 'skip', or 'none'.") - + es = cls._get_connection(using) ensure_sync_connection(es, "Document.mget") diff --git a/elasticsearch_dsl/index.py b/elasticsearch_dsl/index.py index 0becbcea4..04bf77478 100644 --- a/elasticsearch_dsl/index.py +++ b/elasticsearch_dsl/index.py @@ -55,9 +55,7 @@ def save(self, using=None): es = get_connection(using or self._index._using) ensure_sync_connection(es, "IndexTemplate.save") - return es.indices.put_template( - name=self._template_name, body=self.to_dict() - ) + return es.indices.put_template(name=self._template_name, body=self.to_dict()) class Index(object): @@ -284,11 +282,7 @@ def create(self, using=None, **kwargs): es = get_connection(using) ensure_sync_connection(es, "Index.create") - return es.indices.create( - index=self._name, - body=self.to_dict(), - **kwargs, - ) + return es.indices.create(index=self._name, body=self.to_dict(), **kwargs) def is_closed(self, using=None): es = get_connection(using) @@ -366,7 +360,7 @@ def analyze(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.analyze") - return es.indices.analyze(index=self._index, **kwargs) + return es.indices.analyze(index=self._name, **kwargs) def refresh(self, using=None, **kwargs): """ @@ -378,7 +372,7 @@ def refresh(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.refresh") - return es.indices.refresh(index=self._index, **kwargs) + return es.indices.refresh(index=self._name, **kwargs) def flush(self, using=None, **kwargs): """ @@ -390,7 +384,7 @@ def flush(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.flush") - return es.indices.flush(index=self._index, **kwargs) + return es.indices.flush(index=self._name, **kwargs) def get(self, using=None, **kwargs): """ @@ -402,7 +396,7 @@ def get(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.get") - return es.indices.get(index=self._index, **kwargs) + return es.indices.get(index=self._name, **kwargs) def open(self, using=None, **kwargs): """ @@ -414,7 +408,7 @@ def open(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, 
"Index.open") - return es.indices.open(index=self._index, **kwargs) + return es.indices.open(index=self._name, **kwargs) def close(self, using=None, **kwargs): """ @@ -426,7 +420,7 @@ def close(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.close") - return es.indices.close(index=self._index, **kwargs) + return es.indices.close(index=self._name, **kwargs) def delete(self, using=None, **kwargs): """ @@ -438,7 +432,7 @@ def delete(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.delete") - return es.indices.delete(index=self._index, **kwargs) + return es.indices.delete(index=self._name, **kwargs) def exists(self, using=None, **kwargs): """ @@ -450,7 +444,7 @@ def exists(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.exists") - return es.indices.exists(index=self._index, **kwargs) + return es.indices.exists(index=self._name, **kwargs) def exists_type(self, using=None, **kwargs): """ @@ -462,7 +456,7 @@ def exists_type(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.exists_type") - return es.indices.exists_type(index=self._index, **kwargs) + return es.indices.exists_type(index=self._name, **kwargs) def put_mapping(self, using=None, **kwargs): """ @@ -474,7 +468,7 @@ def put_mapping(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.put_mapping") - return es.indices.put_mapping(index=self._index, **kwargs) + return es.indices.put_mapping(index=self._name, **kwargs) def get_mapping(self, using=None, **kwargs): """ @@ -486,7 +480,7 @@ def get_mapping(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.get_mapping") - return es.indices.get_mapping(index=self._index, **kwargs) + return es.indices.get_mapping(index=self._name, **kwargs) def get_field_mapping(self, using=None, **kwargs): """ @@ -498,7 +492,7 @@ def get_field_mapping(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.get_field_mapping") - return es.indices.get_field_mapping(index=self._index, **kwargs) + return es.indices.get_field_mapping(index=self._name, **kwargs) def put_alias(self, using=None, **kwargs): """ @@ -510,7 +504,7 @@ def put_alias(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.put_alias") - return es.indices.put_alias(index=self._index, **kwargs) + return es.indices.put_alias(index=self._name, **kwargs) def exists_alias(self, using=None, **kwargs): """ @@ -533,7 +527,7 @@ def get_alias(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.get_alias") - return es.indices.get_alias(index=self._index, **kwargs) + return es.indices.get_alias(index=self._name, **kwargs) def delete_alias(self, using=None, **kwargs): """ @@ -545,7 +539,7 @@ def delete_alias(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.delete_alias") - return es.indices.delete_alias(index=self._index, **kwargs) + return es.indices.delete_alias(index=self._name, **kwargs) def get_settings(self, using=None, **kwargs): """ @@ -557,7 +551,7 @@ def get_settings(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.get_settings") - return es.indices.get_settings(index=self._index, **kwargs) + return es.indices.get_settings(index=self._name, **kwargs) def 
put_settings(self, using=None, **kwargs): """ @@ -569,7 +563,7 @@ def put_settings(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.put_settings") - return es.indices.put_settings(index=self._index, **kwargs) + return es.indices.put_settings(index=self._name, **kwargs) def stats(self, using=None, **kwargs): """ @@ -581,7 +575,7 @@ def stats(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.stats") - return es.indices.stats(index=self._index, **kwargs) + return es.indices.stats(index=self._name, **kwargs) def segments(self, using=None, **kwargs): """ @@ -594,7 +588,7 @@ def segments(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.segments") - return es.indices.segments(index=self._index, **kwargs) + return es.indices.segments(index=self._name, **kwargs) def validate_query(self, using=None, **kwargs): """ @@ -606,7 +600,7 @@ def validate_query(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.validate_query") - return es.indices.validate_query(index=self._index, **kwargs) + return es.indices.validate_query(index=self._name, **kwargs) def clear_cache(self, using=None, **kwargs): """ @@ -618,7 +612,7 @@ def clear_cache(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.clear_cache") - return es.indices.clear_cache(index=self._index, **kwargs) + return es.indices.clear_cache(index=self._name, **kwargs) def recovery(self, using=None, **kwargs): """ @@ -631,7 +625,7 @@ def recovery(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.recovery") - return es.indices.recovery(index=self._index, **kwargs) + return es.indices.recovery(index=self._name, **kwargs) def upgrade(self, using=None, **kwargs): """ @@ -643,7 +637,7 @@ def upgrade(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.upgrade") - return es.indices.upgrade(index=self._index, **kwargs) + return es.indices.upgrade(index=self._name, **kwargs) def get_upgrade(self, using=None, **kwargs): """ @@ -655,7 +649,7 @@ def get_upgrade(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.get_upgrade") - return es.indices.get_upgrade(index=self._index, **kwargs) + return es.indices.get_upgrade(index=self._name, **kwargs) def flush_synced(self, using=None, **kwargs): """ @@ -668,7 +662,7 @@ def flush_synced(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.flush_synced") - return es.indices.flush_synced(index=self._index, **kwargs) + return es.indices.flush_synced(index=self._name, **kwargs) def shard_stores(self, using=None, **kwargs): """ @@ -683,7 +677,7 @@ def shard_stores(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.shard_stores") - return es.indices.shard_stores(index=self._index, **kwargs) + return es.indices.shard_stores(index=self._name, **kwargs) def forcemerge(self, using=None, **kwargs): """ @@ -702,7 +696,7 @@ def forcemerge(self, using=None, **kwargs): es = self._get_connection(using) ensure_sync_connection(es, "Index.forcemerge") - return es.indices.forcemerge(index=self._index, **kwargs) + return es.indices.forcemerge(index=self._name, **kwargs) def shrink(self, using=None, **kwargs): """ @@ -722,4 +716,4 @@ def shrink(self, using=None, **kwargs): es = self._get_connection(using) 
ensure_sync_connection(es, "Index.shrink") - return es.indices.shrink(index=self._index, **kwargs) + return es.indices.shrink(index=self._name, **kwargs) diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 29e1f31fb..ab785d02d 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -734,9 +734,7 @@ def scan(self): es = get_connection(self._using) ensure_sync_connection(es, "Search.scan") - for hit in scan( - es, query=self.to_dict(), index=self._index, **self._params - ): + for hit in scan(es, query=self.to_dict(), index=self._index, **self._params): yield self._get_result(hit) def delete(self): @@ -747,9 +745,7 @@ def delete(self): ensure_sync_connection(es, "Search.delete") return AttrDict( - es.delete_by_query( - index=self._index, body=self.to_dict(), **self._params - ) + es.delete_by_query(index=self._index, body=self.to_dict(), **self._params) ) @@ -808,9 +804,7 @@ def execute(self, ignore_cache=False, raise_on_error=True): ensure_sync_connection(es, "MultiSearch.execute") responses = es.msearch( - index=self._index, - body=self.to_dict(), - **self.params, + index=self._index, body=self.to_dict(), **self.params ) out = [] diff --git a/elasticsearch_dsl/update_by_query.py b/elasticsearch_dsl/update_by_query.py index 3c5a4f943..e0de1f13a 100644 --- a/elasticsearch_dsl/update_by_query.py +++ b/elasticsearch_dsl/update_by_query.py @@ -154,12 +154,10 @@ def execute(self): the data. """ es = get_connection(self._using) - ensure_sync_connection(es, "SyncMultiSearch.execute") + ensure_sync_connection(es, "MultiSearch.execute") self._response = self._response_class( self, - es.update_by_query( - index=self._index, body=self.to_dict(), **self._params - ), + es.update_by_query(index=self._index, body=self.to_dict(), **self._params), ) return self._response diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 43f33d597..f353f05f5 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -546,10 +546,15 @@ def full_clean(self): def ensure_sync_connection(es, fn_label): + # Allow "Mock" objects to be passed during testing. + if es.__class__.__name__ == "Mock": + return + if not isinstance(es, Elasticsearch): raise TypeError( - f"{fn_label} can only be used with the elasticsearch.Elasticsearch " - "client" + "{} can only be used with the elasticsearch.Elasticsearch client".format( + fn_label, + ) ) diff --git a/utils/generate-sync.py b/utils/generate-sync.py index ff40b5efe..df8bf2061 100644 --- a/utils/generate-sync.py +++ b/utils/generate-sync.py @@ -1,11 +1,36 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ import os from pathlib import Path +import black import unasync +from click.testing import CliRunner CODE_ROOT = Path(__file__).absolute().parent.parent +def _blacken(filename): + runner = CliRunner() + result = runner.invoke(black.main, [str(filename)]) + assert result.exit_code == 0, result.output + + def generate_sync(): additional_replacements = { "_async": "", @@ -36,7 +61,8 @@ def generate_sync(): filepaths.append(os.path.join(root, filename)) unasync.unasync_files(filepaths, rules) + _blacken(CODE_ROOT / "elasticsearch_dsl") -if __name__ == '__main__': +if __name__ == "__main__": generate_sync() From 22351eee83056699cd058278c89128cf06ec1397 Mon Sep 17 00:00:00 2001 From: James Brewer Date: Fri, 23 Oct 2020 10:58:17 -0600 Subject: [PATCH 11/11] Fix broken build ... but with integration tests this time --- elasticsearch_dsl/_async/index.py | 4 ++-- elasticsearch_dsl/_async/search.py | 2 +- elasticsearch_dsl/index.py | 4 ++-- elasticsearch_dsl/search.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/elasticsearch_dsl/_async/index.py b/elasticsearch_dsl/_async/index.py index df5c93c9e..0aa61701a 100644 --- a/elasticsearch_dsl/_async/index.py +++ b/elasticsearch_dsl/_async/index.py @@ -281,13 +281,13 @@ async def create(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.create`` unchanged. """ - es = get_connection(using) + es = self._get_connection(using) ensure_async_connection(es, "Index.create") return await es.indices.create(index=self._name, body=self.to_dict(), **kwargs) async def is_closed(self, using=None): - es = get_connection(using) + es = self._get_connection(using) ensure_async_connection(es, "Index.is_closed") state = await es.cluster.state( diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py index ca7073973..87f04ea06 100644 --- a/elasticsearch_dsl/_async/search.py +++ b/elasticsearch_dsl/_async/search.py @@ -808,7 +808,7 @@ async def execute(self, ignore_cache=False, raise_on_error=True): ensure_async_connection(es, "MultiSearch.execute") responses = await es.msearch( - index=self._index, body=self.to_dict(), **self.params + index=self._index, body=self.to_dict(), **self._params ) out = [] diff --git a/elasticsearch_dsl/index.py b/elasticsearch_dsl/index.py index 04bf77478..6fde05247 100644 --- a/elasticsearch_dsl/index.py +++ b/elasticsearch_dsl/index.py @@ -279,13 +279,13 @@ def create(self, using=None, **kwargs): Any additional keyword arguments will be passed to ``Elasticsearch.indices.create`` unchanged. """ - es = get_connection(using) + es = self._get_connection(using) ensure_sync_connection(es, "Index.create") return es.indices.create(index=self._name, body=self.to_dict(), **kwargs) def is_closed(self, using=None): - es = get_connection(using) + es = self._get_connection(using) ensure_sync_connection(es, "Index.is_closed") state = es.cluster.state( diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index ab785d02d..b73ceb629 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -804,7 +804,7 @@ def execute(self, ignore_cache=False, raise_on_error=True): ensure_sync_connection(es, "MultiSearch.execute") responses = es.msearch( - index=self._index, body=self.to_dict(), **self.params + index=self._index, body=self.to_dict(), **self._params ) out = []
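As an illustration of how the connection guards introduced in this series behave from the caller's side, here is a minimal sketch. It is not part of the patch itself: it assumes a reachable cluster at http://localhost:9200, the "async" extra installed, and a made-up "blog-posts" index.

    from elasticsearch import AsyncElasticsearch, Elasticsearch

    from elasticsearch_dsl import Search, connections

    # Register one sync and one async client under separate aliases.
    # The URL and the index name below are placeholders for illustration only.
    connections.add_connection("sync", Elasticsearch("http://localhost:9200"))
    connections.add_connection("async", AsyncElasticsearch("http://localhost:9200"))

    # The sync Search resolves its connection via get_connection() and then
    # checks it with ensure_sync_connection() before issuing the request.
    s = Search(using="sync", index="blog-posts").query("match", title="python")
    response = s.execute()
    print(response.hits.total)

    # Handing the sync Search an AsyncElasticsearch connection fails fast with
    # a TypeError raised by ensure_sync_connection(), instead of erroring
    # somewhere inside the transport layer.
    try:
        Search(using="async", index="blog-posts").execute()
    except TypeError as exc:
        print(exc)

The guard runs before any network call is made, so the mismatched-client case above fails immediately even without a running cluster; the async classes apply the same check through ensure_async_connection before awaiting the client.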