From 0fb40d026b374eef877adfdb51050d84593d5bd1 Mon Sep 17 00:00:00 2001
From: Miguel Grinberg
Date: Mon, 20 May 2024 17:02:29 +0100
Subject: [PATCH 1/2] Added Response.search_after() method

---
 elasticsearch_dsl/response/__init__.py       | 32 ++++++++++++++
 elasticsearch_dsl/search_base.py             | 30 +++++++++++++
 tests/test_integration/_async/test_search.py | 46 ++++++++++++++++++++
 tests/test_integration/_sync/test_search.py  | 46 ++++++++++++++++++++
 4 files changed, 154 insertions(+)

diff --git a/elasticsearch_dsl/response/__init__.py b/elasticsearch_dsl/response/__init__.py
index 7af054b5b..1fd64fb9a 100644
--- a/elasticsearch_dsl/response/__init__.py
+++ b/elasticsearch_dsl/response/__init__.py
@@ -90,6 +90,38 @@ def aggs(self):
             super(AttrDict, self).__setattr__("_aggs", aggs)
         return self._aggs
 
+    def search_after(self):
+        """
+        Return a ``Search`` instance that retrieves the next page of results.
+
+        This method provides an easy way to paginate a long list of results using
+        the ``search_after`` option. For example::
+
+            page_size = 20
+            s = Search()[:page_size].sort("date")
+
+            while True:
+                # get a page of results
+                r = await s.execute()
+
+                # do something with this page of results
+
+                # exit the loop if we reached the end
+                if len(r.hits) < page_size:
+                    break
+
+                # get a search object with the next page of results
+                s = r.search_after()
+
+        Note that the ``search_after`` option requires the search to have an
+        explicit ``sort`` order.
+        """
+        if len(self.hits) == 0:
+            raise ValueError("Cannot use search_after when there are no search_results")
+        if not hasattr(self.hits[-1].meta, "sort"):
+            raise ValueError("Cannot use search_after when results are not sorted")
+        return self._search.extra(search_after=self.hits[-1].meta.sort)
+
 
 class AggResponse(AttrDict):
     def __init__(self, aggs, search, data):
diff --git a/elasticsearch_dsl/search_base.py b/elasticsearch_dsl/search_base.py
index d54b6b925..5680778cb 100644
--- a/elasticsearch_dsl/search_base.py
+++ b/elasticsearch_dsl/search_base.py
@@ -760,6 +760,36 @@ def suggest(self, name, text, **kwargs):
         s._suggest[name].update(kwargs)
         return s
 
+    def search_after(self):
+        """
+        Return a ``Search`` instance that retrieves the next page of results.
+
+        This method provides an easy way to paginate a long list of results using
+        the ``search_after`` option. For example::
+
+            page_size = 20
+            s = Search()[:page_size].sort("date")
+
+            while True:
+                # get a page of results
+                r = await s.execute()
+
+                # do something with this page of results
+
+                # exit the loop if we reached the end
+                if len(r.hits) < page_size:
+                    break
+
+                # get a search object with the next page of results
+                s = s.search_after()
+
+        Note that the ``search_after`` option requires the search to have an
+        explicit ``sort`` order.
+        """
+        if not hasattr(self, "_response"):
+            raise ValueError("A search must be executed before using search_after")
+        return self._response.search_after()
+
     def to_dict(self, count=False, **kwargs):
         """
         Serialize the search into the dictionary that will be sent over as the
diff --git a/tests/test_integration/_async/test_search.py b/tests/test_integration/_async/test_search.py
index 7fc56c870..f12cf6794 100644
--- a/tests/test_integration/_async/test_search.py
+++ b/tests/test_integration/_async/test_search.py
@@ -125,6 +125,52 @@ async def test_scan_iterates_through_all_docs(async_data_client):
     assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits}
 
 
+@pytest.mark.asyncio
+async def test_search_after(async_data_client):
+    page_size = 7
+    s = AsyncSearch(index="flat-git")[:page_size].sort("authored_date")
+    commits = []
+    while True:
+        r = await s.execute()
+        commits += r.hits
+        if len(r.hits) < page_size:
+            break
+        s = r.search_after()
+
+    assert 52 == len(commits)
+    assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits}
+
+
+@pytest.mark.asyncio
+async def test_search_after_no_search(async_data_client):
+    s = AsyncSearch(index="flat-git")
+    with raises(ValueError):
+        await s.search_after()
+    await s.count()
+    with raises(ValueError):
+        await s.search_after()
+
+
+@pytest.mark.asyncio
+async def test_search_after_no_sort(async_data_client):
+    s = AsyncSearch(index="flat-git")
+    r = await s.execute()
+    with raises(ValueError):
+        await r.search_after()
+
+
+@pytest.mark.asyncio
+async def test_search_after_no_results(async_data_client):
+    s = AsyncSearch(index="flat-git")[:100].sort("authored_date")
+    r = await s.execute()
+    assert 52 == len(r.hits)
+    s = r.search_after()
+    r = await s.execute()
+    assert 0 == len(r.hits)
+    with raises(ValueError):
+        await r.search_after()
+
+
 @pytest.mark.asyncio
 async def test_response_is_cached(async_data_client):
     s = Repository.search()
diff --git a/tests/test_integration/_sync/test_search.py b/tests/test_integration/_sync/test_search.py
index b31ef8b3d..1a3070c8e 100644
--- a/tests/test_integration/_sync/test_search.py
+++ b/tests/test_integration/_sync/test_search.py
@@ -117,6 +117,52 @@ def test_scan_iterates_through_all_docs(data_client):
     assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits}
 
 
+@pytest.mark.sync
+def test_search_after(data_client):
+    page_size = 7
+    s = Search(index="flat-git")[:page_size].sort("authored_date")
+    commits = []
+    while True:
+        r = s.execute()
+        commits += r.hits
+        if len(r.hits) < page_size:
+            break
+        s = r.search_after()
+
+    assert 52 == len(commits)
+    assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits}
+
+
+@pytest.mark.sync
+def test_search_after_no_search(data_client):
+    s = Search(index="flat-git")
+    with raises(ValueError):
+        s.search_after()
+    s.count()
+    with raises(ValueError):
+        s.search_after()
+
+
+@pytest.mark.sync
+def test_search_after_no_sort(data_client):
+    s = Search(index="flat-git")
+    r = s.execute()
+    with raises(ValueError):
+        r.search_after()
+
+
+@pytest.mark.sync
+def test_search_after_no_results(data_client):
+    s = Search(index="flat-git")[:100].sort("authored_date")
+    r = s.execute()
+    assert 52 == len(r.hits)
+    s = r.search_after()
+    r = s.execute()
+    assert 0 == len(r.hits)
+    with raises(ValueError):
+        r.search_after()
+
+
 @pytest.mark.sync
 def test_response_is_cached(data_client):
     s = Repository.search()

From 7d777e0df57049ac8ebd019163bda74821d05b8a Mon Sep 17 00:00:00 2001
From: Miguel Grinberg
Date: Tue, 21 May 2024 14:58:45 +0100
Subject: [PATCH 2/2] add match clause to pytest.raises

---
 elasticsearch_dsl/response/__init__.py       |  2 +-
 tests/test_integration/_async/test_search.py | 16 ++++++++++++----
 tests/test_integration/_sync/test_search.py  | 16 ++++++++++++----
 3 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/elasticsearch_dsl/response/__init__.py b/elasticsearch_dsl/response/__init__.py
index 1fd64fb9a..482400a4d 100644
--- a/elasticsearch_dsl/response/__init__.py
+++ b/elasticsearch_dsl/response/__init__.py
@@ -117,7 +117,7 @@ def search_after(self):
         explicit ``sort`` order.
         """
         if len(self.hits) == 0:
-            raise ValueError("Cannot use search_after when there are no search_results")
+            raise ValueError("Cannot use search_after when there are no search results")
         if not hasattr(self.hits[-1].meta, "sort"):
             raise ValueError("Cannot use search_after when results are not sorted")
         return self._search.extra(search_after=self.hits[-1].meta.sort)
diff --git a/tests/test_integration/_async/test_search.py b/tests/test_integration/_async/test_search.py
index f12cf6794..6d6a5ab98 100644
--- a/tests/test_integration/_async/test_search.py
+++ b/tests/test_integration/_async/test_search.py
@@ -144,10 +144,14 @@ async def test_search_after(async_data_client):
 @pytest.mark.asyncio
 async def test_search_after_no_search(async_data_client):
     s = AsyncSearch(index="flat-git")
-    with raises(ValueError):
+    with raises(
+        ValueError, match="A search must be executed before using search_after"
+    ):
         await s.search_after()
     await s.count()
-    with raises(ValueError):
+    with raises(
+        ValueError, match="A search must be executed before using search_after"
+    ):
         await s.search_after()
 
 
@@ -155,7 +159,9 @@ async def test_search_after_no_search(async_data_client):
 async def test_search_after_no_sort(async_data_client):
     s = AsyncSearch(index="flat-git")
     r = await s.execute()
-    with raises(ValueError):
+    with raises(
+        ValueError, match="Cannot use search_after when results are not sorted"
+    ):
         await r.search_after()
 
 
@@ -167,7 +173,9 @@ async def test_search_after_no_sort(async_data_client):
     s = r.search_after()
     r = await s.execute()
     assert 0 == len(r.hits)
-    with raises(ValueError):
+    with raises(
+        ValueError, match="Cannot use search_after when there are no search results"
+    ):
         await r.search_after()
 
 
diff --git a/tests/test_integration/_sync/test_search.py b/tests/test_integration/_sync/test_search.py
index 1a3070c8e..09c318369 100644
--- a/tests/test_integration/_sync/test_search.py
+++ b/tests/test_integration/_sync/test_search.py
@@ -136,10 +136,14 @@ def test_search_after(data_client):
 @pytest.mark.sync
 def test_search_after_no_search(data_client):
     s = Search(index="flat-git")
-    with raises(ValueError):
+    with raises(
+        ValueError, match="A search must be executed before using search_after"
+    ):
         s.search_after()
     s.count()
-    with raises(ValueError):
+    with raises(
+        ValueError, match="A search must be executed before using search_after"
+    ):
         s.search_after()
 
 
@@ -147,7 +151,9 @@ def test_search_after_no_search(data_client):
 def test_search_after_no_sort(data_client):
     s = Search(index="flat-git")
     r = s.execute()
-    with raises(ValueError):
+    with raises(
+        ValueError, match="Cannot use search_after when results are not sorted"
+    ):
         r.search_after()
 
 
@@ -159,7 +165,9 @@ def test_search_after_no_sort(data_client):
     s = r.search_after()
     r = s.execute()
     assert 0 == len(r.hits)
-    with raises(ValueError):
+    with raises(
+        ValueError, match="Cannot use search_after when there are no search results"
+    ):
         r.search_after()
 
 
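
Usage note (not part of the patch): the sketch below shows how the search_after()
methods added above are intended to be combined for pagination. It is an
illustrative example only; the index name ("my-index"), the "date" sort field and
the iterate_pages() helper are assumptions made for the sketch and do not come
from the diffs::

    # Illustrative sketch of paginating with the new search_after() API.
    # Assumes a reachable cluster, an existing "my-index" index and a
    # sortable "date" field; none of these are defined by the patch.
    from elasticsearch_dsl import Search, connections

    connections.create_connection(hosts=["http://localhost:9200"])


    def iterate_pages(page_size=20):
        # search_after pagination requires an explicit sort order
        s = Search(index="my-index")[:page_size].sort("date")
        while True:
            r = s.execute()
            yield r.hits
            if len(r.hits) < page_size:
                break  # this was the last page
            # Response.search_after() returns a new Search with the
            # search_after option taken from the last hit's sort values.
            s = r.search_after()


    for page in iterate_pages():
        for hit in page:
            print(hit.meta.id)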