From f6df2fd2a3dbfa6e08b41e68790ab330a8baf046 Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Thu, 9 May 2024 12:11:54 +0200 Subject: [PATCH 1/6] refactor: add type hints to query.py + type_checking in CI chore: fix linting --- elasticsearch_dsl/query.py | 82 ++++++++++++++++++++++++++++---------- elasticsearch_dsl/utils.py | 9 +++-- noxfile.py | 31 +++++++++++++- 3 files changed, 96 insertions(+), 26 deletions(-) diff --git a/elasticsearch_dsl/query.py b/elasticsearch_dsl/query.py index 0b5be820..3e85d76e 100644 --- a/elasticsearch_dsl/query.py +++ b/elasticsearch_dsl/query.py @@ -16,7 +16,9 @@ # under the License. import collections.abc +from copy import deepcopy from itertools import chain +from typing import Any, Callable, ClassVar, Optional, Protocol, TypeVar, Union, overload # 'SF' looks unused but the test suite assumes it's available # from this module so others are liable to do so as well. @@ -24,10 +26,41 @@ from .function import ScoreFunction from .utils import DslBase +_T = TypeVar("_T") +_M = TypeVar("_M", bound=collections.abc.Mapping[str, Any]) -def Q(name_or_query="match_all", **params): + +class QProxiedProtocol(Protocol[_T]): + _proxied: _T + + +@overload +def Q(name_or_query: collections.abc.MutableMapping[str, _M]) -> "Query": ... + + +@overload +def Q(name_or_query: "Query") -> "Query": ... + + +@overload +def Q(name_or_query: QProxiedProtocol[_T]) -> _T: ... + + +@overload +def Q(name_or_query: str, **params: Any) -> "Query": ... + + +def Q( + name_or_query: Union[ + str, + "Query", + QProxiedProtocol[_T], + collections.abc.MutableMapping[str, _M], + ] = "match_all", + **params: Any, +) -> Union["Query", _T]: # {"match": {"title": "python"}} - if isinstance(name_or_query, collections.abc.Mapping): + if isinstance(name_or_query, collections.abc.MutableMapping): if params: raise ValueError("Q() cannot accept parameters when passing in a dict.") if len(name_or_query) != 1: @@ -35,8 +68,8 @@ def Q(name_or_query="match_all", **params): 'Q() can only accept dict with a single query ({"match": {...}}). ' "Instead it got (%r)" % name_or_query ) - name, params = name_or_query.copy().popitem() - return Query.get_dsl_class(name)(_expand__to_dot=False, **params) + name, q_params = deepcopy(name_or_query).popitem() + return Query.get_dsl_class(name)(_expand__to_dot=False, **q_params) # MatchAll() if isinstance(name_or_query, Query): @@ -57,26 +90,31 @@ def Q(name_or_query="match_all", **params): class Query(DslBase): _type_name = "query" _type_shortcut = staticmethod(Q) - name = None + name: ClassVar[Optional[str]] = None + + # Add type annotations for methods not defined in every subclass + __ror__: ClassVar[Callable[["Query", "Query"], "Query"]] + __radd__: ClassVar[Callable[["Query", "Query"], "Query"]] + __rand__: ClassVar[Callable[["Query", "Query"], "Query"]] - def __add__(self, other): + def __add__(self, other: "Query") -> "Query": # make sure we give queries that know how to combine themselves # preference if hasattr(other, "__radd__"): return other.__radd__(self) return Bool(must=[self, other]) - def __invert__(self): + def __invert__(self) -> "Query": return Bool(must_not=[self]) - def __or__(self, other): + def __or__(self, other: "Query") -> "Query": # make sure we give queries that know how to combine themselves # preference if hasattr(other, "__ror__"): return other.__ror__(self) return Bool(should=[self, other]) - def __and__(self, other): + def __and__(self, other: "Query") -> "Query": # make sure we give queries that know how to combine themselves # preference if hasattr(other, "__rand__"): @@ -87,17 +125,17 @@ def __and__(self, other): class MatchAll(Query): name = "match_all" - def __add__(self, other): + def __add__(self, other: "Query") -> "Query": return other._clone() __and__ = __rand__ = __radd__ = __add__ - def __or__(self, other): + def __or__(self, other: "Query") -> "MatchAll": return self __ror__ = __or__ - def __invert__(self): + def __invert__(self) -> "MatchNone": return MatchNone() @@ -107,17 +145,17 @@ def __invert__(self): class MatchNone(Query): name = "match_none" - def __add__(self, other): + def __add__(self, other: "Query") -> "MatchNone": return self __and__ = __rand__ = __radd__ = __add__ - def __or__(self, other): + def __or__(self, other: "Query") -> "Query": return other._clone() __ror__ = __or__ - def __invert__(self): + def __invert__(self) -> MatchAll: return MatchAll() @@ -130,7 +168,7 @@ class Bool(Query): "filter": {"type": "query", "multi": True}, } - def __add__(self, other): + def __add__(self, other: Query) -> "Bool": q = self._clone() if isinstance(other, Bool): q.must += other.must @@ -143,7 +181,7 @@ def __add__(self, other): __radd__ = __add__ - def __or__(self, other): + def __or__(self, other: Query) -> Query: for q in (self, other): if isinstance(q, Bool) and not any( (q.must, q.must_not, q.filter, getattr(q, "minimum_should_match", None)) @@ -168,20 +206,20 @@ def __or__(self, other): __ror__ = __or__ @property - def _min_should_match(self): + def _min_should_match(self) -> int: return getattr( self, "minimum_should_match", 0 if not self.should or (self.must or self.filter) else 1, ) - def __invert__(self): + def __invert__(self) -> Query: # Because an empty Bool query is treated like # MatchAll the inverse should be MatchNone if not any(chain(self.must, self.filter, self.should, self.must_not)): return MatchNone() - negations = [] + negations: list[Query] = [] for q in chain(self.must, self.filter): negations.append(~q) @@ -195,7 +233,7 @@ def __invert__(self): return negations[0] return Bool(should=negations) - def __and__(self, other): + def __and__(self, other: Query) -> Query: q = self._clone() if isinstance(other, Bool): q.must += other.must @@ -247,7 +285,7 @@ class FunctionScore(Query): "functions": {"type": "score_function", "multi": True}, } - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): if "functions" in kwargs: pass else: diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 51dbe27d..41d17482 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,6 +18,7 @@ import collections.abc from copy import copy +from typing import Any, Optional, Self from .exceptions import UnknownDslObject, ValidationException @@ -251,7 +252,9 @@ class DslBase(metaclass=DslMeta): _param_defs = {} @classmethod - def get_dsl_class(cls, name, default=None): + def get_dsl_class( + cls: type[Self], name: str, default: Optional[str] = None + ) -> type[Self]: try: return cls._classes[name] except KeyError: @@ -261,7 +264,7 @@ def get_dsl_class(cls, name, default=None): f"DSL class `{name}` does not exist in {cls._type_name}." ) - def __init__(self, _expand__to_dot=None, **params): + def __init__(self, _expand__to_dot: Optional[bool] = None, **params: Any) -> None: if _expand__to_dot is None: _expand__to_dot = EXPAND__TO_DOT self._params = {} @@ -390,7 +393,7 @@ def to_dict(self): d[pname] = value return {self.name: d} - def _clone(self): + def _clone(self) -> Self: c = self.__class__() for attr in self._params: c._params[attr] = copy(self._params[attr]) diff --git a/noxfile.py b/noxfile.py index 42548199..f800ca27 100644 --- a/noxfile.py +++ b/noxfile.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import subprocess + import nox SOURCE_FILES = ( @@ -27,6 +29,8 @@ "utils/", ) +TYPED_FILES = ("elasticsearch_dsl/query.py",) + @nox.session( python=[ @@ -72,10 +76,35 @@ def lint(session): session.run("black", "--check", "--target-version=py38", *SOURCE_FILES) session.run("isort", "--check", *SOURCE_FILES) session.run("python", "utils/run-unasync.py", "--check") - session.run("flake8", "--ignore=E501,E741,W503", *SOURCE_FILES) + session.run("flake8", "--ignore=E501,E741,W503,E704", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) +@nox.session(python="3.12") +def type_check(session): + session.install("mypy", ".[develop]") + errors = [] + popen = subprocess.Popen( + "mypy --strict elasticsearch_dsl", + env=session.env, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + + mypy_output = "" + while popen.poll() is None: + mypy_output += popen.stdout.read(8192).decode() + mypy_output += popen.stdout.read().decode() + + for line in mypy_output.split("\n"): + filepath = line.partition(":")[0] + if filepath in TYPED_FILES: + errors.append(line) + if errors: + session.error("\n" + "\n".join(sorted(set(errors)))) + + @nox.session() def docs(session): session.install(".[develop]") From 50162021458026540de0fa7e48a5bc6ef7d8ca0b Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Thu, 9 May 2024 13:06:53 +0200 Subject: [PATCH 2/6] fix: fix typing for older versions of python --- elasticsearch_dsl/query.py | 19 +++++++++++++++---- elasticsearch_dsl/utils.py | 8 +++++--- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/elasticsearch_dsl/query.py b/elasticsearch_dsl/query.py index 3e85d76e..a2b6643b 100644 --- a/elasticsearch_dsl/query.py +++ b/elasticsearch_dsl/query.py @@ -18,7 +18,18 @@ import collections.abc from copy import deepcopy from itertools import chain -from typing import Any, Callable, ClassVar, Optional, Protocol, TypeVar, Union, overload +from typing import ( + Any, + Callable, + ClassVar, + Mapping, + MutableMapping, + Optional, + Protocol, + TypeVar, + Union, + overload, +) # 'SF' looks unused but the test suite assumes it's available # from this module so others are liable to do so as well. @@ -27,7 +38,7 @@ from .utils import DslBase _T = TypeVar("_T") -_M = TypeVar("_M", bound=collections.abc.Mapping[str, Any]) +_M = TypeVar("_M", bound=Mapping[str, Any]) class QProxiedProtocol(Protocol[_T]): @@ -35,7 +46,7 @@ class QProxiedProtocol(Protocol[_T]): @overload -def Q(name_or_query: collections.abc.MutableMapping[str, _M]) -> "Query": ... +def Q(name_or_query: MutableMapping[str, _M]) -> "Query": ... @overload @@ -55,7 +66,7 @@ def Q( str, "Query", QProxiedProtocol[_T], - collections.abc.MutableMapping[str, _M], + MutableMapping[str, _M], ] = "match_all", **params: Any, ) -> Union["Query", _T]: diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 41d17482..e0ba0a1f 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,7 +18,9 @@ import collections.abc from copy import copy -from typing import Any, Optional, Self +from typing import Any, Optional, Type + +from typing_extensions import Self from .exceptions import UnknownDslObject, ValidationException @@ -253,8 +255,8 @@ class DslBase(metaclass=DslMeta): @classmethod def get_dsl_class( - cls: type[Self], name: str, default: Optional[str] = None - ) -> type[Self]: + cls: Type[Self], name: str, default: Optional[str] = None + ) -> Type[Self]: try: return cls._classes[name] except KeyError: From 3c6b95bcc1d4fba719a2fc746593666262ed5aef Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Sat, 11 May 2024 16:49:04 +0200 Subject: [PATCH 3/6] refactor: add typing to query tests --- elasticsearch_dsl/function.py | 6 +- elasticsearch_dsl/query.py | 5 +- elasticsearch_dsl/utils.py | 5 +- mypy.ini | 3 + noxfile.py | 7 +- tests/test_query.py | 132 +++++++++++++++++----------------- 6 files changed, 85 insertions(+), 73 deletions(-) create mode 100644 mypy.ini diff --git a/elasticsearch_dsl/function.py b/elasticsearch_dsl/function.py index 3be9bd81..ef77ce8e 100644 --- a/elasticsearch_dsl/function.py +++ b/elasticsearch_dsl/function.py @@ -16,11 +16,13 @@ # under the License. import collections.abc +from typing import Dict from .utils import DslBase -def SF(name_or_sf, **params): +# Incomplete annotation to not break query.py tests +def SF(name_or_sf, **params) -> "ScoreFunction": # {"script_score": {"script": "_score"}, "filter": {}} if isinstance(name_or_sf, collections.abc.Mapping): if params: @@ -86,7 +88,7 @@ class ScriptScore(ScoreFunction): class BoostFactor(ScoreFunction): name = "boost_factor" - def to_dict(self): + def to_dict(self) -> Dict[str, int]: d = super().to_dict() if "value" in d[self.name]: d[self.name] = d[self.name].pop("value") diff --git a/elasticsearch_dsl/query.py b/elasticsearch_dsl/query.py index a2b6643b..d014a224 100644 --- a/elasticsearch_dsl/query.py +++ b/elasticsearch_dsl/query.py @@ -28,6 +28,7 @@ Protocol, TypeVar, Union, + cast, overload, ) @@ -58,7 +59,7 @@ def Q(name_or_query: QProxiedProtocol[_T]) -> _T: ... @overload -def Q(name_or_query: str, **params: Any) -> "Query": ... +def Q(name_or_query: str = "match_all", **params: Any) -> "Query": ... def Q( @@ -92,7 +93,7 @@ def Q( # s.query = Q('filtered', query=s.query) if hasattr(name_or_query, "_proxied"): - return name_or_query._proxied + return cast(QProxiedProtocol[_T], name_or_query)._proxied # "match", title="python" return Query.get_dsl_class(name_or_query)(**params) diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index e0ba0a1f..da6d4fa7 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,7 +18,7 @@ import collections.abc from copy import copy -from typing import Any, Optional, Type +from typing import Any, Dict, Optional, Type from typing_extensions import Self @@ -356,7 +356,8 @@ def __getattr__(self, name): return AttrDict(value) return value - def to_dict(self): + # TODO: This type annotation can probably be made tighter + def to_dict(self) -> Dict[str, Dict[str, Any]]: """ Serialize the DSL object to plain dict """ diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..0c795321 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,3 @@ +[mypy-elasticsearch_dsl.query] +# Allow reexport of SF for tests +implicit_reexport = True \ No newline at end of file diff --git a/noxfile.py b/noxfile.py index f800ca27..fb6d8c4c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -29,7 +29,10 @@ "utils/", ) -TYPED_FILES = ("elasticsearch_dsl/query.py",) +TYPED_FILES = ( + "elasticsearch_dsl/query.py", + "tests/test_query.py", +) @nox.session( @@ -85,7 +88,7 @@ def type_check(session): session.install("mypy", ".[develop]") errors = [] popen = subprocess.Popen( - "mypy --strict elasticsearch_dsl", + "mypy --strict elasticsearch_dsl tests", env=session.env, shell=True, stdout=subprocess.PIPE, diff --git a/tests/test_query.py b/tests/test_query.py index fc4e430b..601f5102 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -20,14 +20,14 @@ from elasticsearch_dsl import function, query, utils -def test_empty_Q_is_match_all(): +def test_empty_Q_is_match_all() -> None: q = query.Q() assert isinstance(q, query.MatchAll) assert query.MatchAll() == q -def test_combined_fields_to_dict(): +def test_combined_fields_to_dict() -> None: assert { "combined_fields": { "query": "this is a test", @@ -41,7 +41,7 @@ def test_combined_fields_to_dict(): ).to_dict() -def test_combined_fields_to_dict_extra(): +def test_combined_fields_to_dict_extra() -> None: assert { "combined_fields": { "query": "this is a test", @@ -55,54 +55,54 @@ def test_combined_fields_to_dict_extra(): ).to_dict() -def test_match_to_dict(): +def test_match_to_dict() -> None: assert {"match": {"f": "value"}} == query.Match(f="value").to_dict() -def test_match_to_dict_extra(): +def test_match_to_dict_extra() -> None: assert {"match": {"f": "value", "boost": 2}} == query.Match( f="value", boost=2 ).to_dict() -def test_fuzzy_to_dict(): +def test_fuzzy_to_dict() -> None: assert {"fuzzy": {"f": "value"}} == query.Fuzzy(f="value").to_dict() -def test_prefix_to_dict(): +def test_prefix_to_dict() -> None: assert {"prefix": {"f": "value"}} == query.Prefix(f="value").to_dict() -def test_term_to_dict(): +def test_term_to_dict() -> None: assert {"term": {"_type": "article"}} == query.Term(_type="article").to_dict() -def test_bool_to_dict(): +def test_bool_to_dict() -> None: bool = query.Bool(must=[query.Match(f="value")], should=[]) assert {"bool": {"must": [{"match": {"f": "value"}}]}} == bool.to_dict() -def test_dismax_to_dict(): +def test_dismax_to_dict() -> None: assert {"dis_max": {"queries": [{"term": {"_type": "article"}}]}} == query.DisMax( queries=[query.Term(_type="article")] ).to_dict() -def test_bool_from_dict_issue_318(): +def test_bool_from_dict_issue_318() -> None: d = {"bool": {"must_not": {"match": {"field": "value"}}}} q = query.Q(d) assert q == ~query.Match(field="value") -def test_repr(): +def test_repr() -> None: bool = query.Bool(must=[query.Match(f="value")], should=[]) assert "Bool(must=[Match(f='value')])" == repr(bool) -def test_query_clone(): +def test_query_clone() -> None: bool = query.Bool( must=[query.Match(x=42)], should=[query.Match(g="v2")], @@ -114,14 +114,14 @@ def test_query_clone(): assert bool is not bool_clone -def test_bool_converts_its_init_args_to_queries(): +def test_bool_converts_its_init_args_to_queries() -> None: q = query.Bool(must=[{"match": {"f": "value"}}]) assert len(q.must) == 1 assert q.must[0] == query.Match(f="value") -def test_two_queries_make_a_bool(): +def test_two_queries_make_a_bool() -> None: q1 = query.Match(f="value1") q2 = query.Match(message={"query": "this is a test", "opeartor": "and"}) q = q1 & q2 @@ -130,7 +130,7 @@ def test_two_queries_make_a_bool(): assert [q1, q2] == q.must -def test_other_and_bool_appends_other_to_must(): +def test_other_and_bool_appends_other_to_must() -> None: q1 = query.Match(f="value1") qb = query.Bool() @@ -139,7 +139,7 @@ def test_other_and_bool_appends_other_to_must(): assert q.must[0] == q1 -def test_bool_and_other_appends_other_to_must(): +def test_bool_and_other_appends_other_to_must() -> None: q1 = query.Match(f="value1") qb = query.Bool() @@ -148,7 +148,7 @@ def test_bool_and_other_appends_other_to_must(): assert q.must[0] == q1 -def test_bool_and_other_sets_min_should_match_if_needed(): +def test_bool_and_other_sets_min_should_match_if_needed() -> None: q1 = query.Q("term", category=1) q2 = query.Q( "bool", should=[query.Q("term", name="aaa"), query.Q("term", name="bbb")] @@ -162,7 +162,7 @@ def test_bool_and_other_sets_min_should_match_if_needed(): ) -def test_bool_with_different_minimum_should_match_should_not_be_combined(): +def test_bool_with_different_minimum_should_match_should_not_be_combined() -> None: q1 = query.Q( "bool", minimum_should_match=2, @@ -201,11 +201,11 @@ def test_bool_with_different_minimum_should_match_should_not_be_combined(): assert q5 == query.Bool(should=[q1, q2, q3]) -def test_empty_bool_has_min_should_match_0(): +def test_empty_bool_has_min_should_match_0() -> None: assert 0 == query.Bool()._min_should_match -def test_query_and_query_creates_bool(): +def test_query_and_query_creates_bool() -> None: q1 = query.Match(f=42) q2 = query.Match(g=47) @@ -214,7 +214,7 @@ def test_query_and_query_creates_bool(): assert q.must == [q1, q2] -def test_match_all_and_query_equals_other(): +def test_match_all_and_query_equals_other() -> None: q1 = query.Match(f=42) q2 = query.MatchAll() @@ -222,39 +222,39 @@ def test_match_all_and_query_equals_other(): assert q1 == q -def test_not_match_all_is_match_none(): +def test_not_match_all_is_match_none() -> None: q = query.MatchAll() assert ~q == query.MatchNone() -def test_not_match_none_is_match_all(): +def test_not_match_none_is_match_all() -> None: q = query.MatchNone() assert ~q == query.MatchAll() -def test_invert_empty_bool_is_match_none(): +def test_invert_empty_bool_is_match_none() -> None: q = query.Bool() assert ~q == query.MatchNone() -def test_match_none_or_query_equals_query(): +def test_match_none_or_query_equals_query() -> None: q1 = query.Match(f=42) q2 = query.MatchNone() assert q1 | q2 == query.Match(f=42) -def test_match_none_and_query_equals_match_none(): +def test_match_none_and_query_equals_match_none() -> None: q1 = query.Match(f=42) q2 = query.MatchNone() assert q1 & q2 == query.MatchNone() -def test_bool_and_bool(): +def test_bool_and_bool() -> None: qt1, qt2, qt3 = query.Match(f=1), query.Match(f=2), query.Match(f=3) q1 = query.Bool(must=[qt1], should=[qt2]) @@ -270,7 +270,7 @@ def test_bool_and_bool(): ) -def test_bool_and_bool_with_min_should_match(): +def test_bool_and_bool_with_min_should_match() -> None: qt1, qt2 = query.Match(f=1), query.Match(f=2) q1 = query.Q("bool", minimum_should_match=1, should=[qt1]) q2 = query.Q("bool", minimum_should_match=1, should=[qt2]) @@ -278,7 +278,7 @@ def test_bool_and_bool_with_min_should_match(): assert query.Q("bool", must=[qt1, qt2]) == q1 & q2 -def test_negative_min_should_match(): +def test_negative_min_should_match() -> None: qt1, qt2 = query.Match(f=1), query.Match(f=2) q1 = query.Q("bool", minimum_should_match=-2, should=[qt1]) q2 = query.Q("bool", minimum_should_match=1, should=[qt2]) @@ -289,7 +289,7 @@ def test_negative_min_should_match(): q2 & q1 -def test_percentage_min_should_match(): +def test_percentage_min_should_match() -> None: qt1, qt2 = query.Match(f=1), query.Match(f=2) q1 = query.Q("bool", minimum_should_match="50%", should=[qt1]) q2 = query.Q("bool", minimum_should_match=1, should=[qt2]) @@ -300,19 +300,19 @@ def test_percentage_min_should_match(): q2 & q1 -def test_inverted_query_becomes_bool_with_must_not(): +def test_inverted_query_becomes_bool_with_must_not() -> None: q = query.Match(f=42) assert ~q == query.Bool(must_not=[query.Match(f=42)]) -def test_inverted_query_with_must_not_become_should(): +def test_inverted_query_with_must_not_become_should() -> None: q = query.Q("bool", must_not=[query.Q("match", f=1), query.Q("match", f=2)]) assert ~q == query.Q("bool", should=[query.Q("match", f=1), query.Q("match", f=2)]) -def test_inverted_query_with_must_and_must_not(): +def test_inverted_query_with_must_and_must_not() -> None: q = query.Q( "bool", must=[query.Q("match", f=3), query.Q("match", f=4)], @@ -332,13 +332,13 @@ def test_inverted_query_with_must_and_must_not(): ) -def test_double_invert_returns_original_query(): +def test_double_invert_returns_original_query() -> None: q = query.Match(f=42) assert q == ~~q -def test_bool_query_gets_inverted_internally(): +def test_bool_query_gets_inverted_internally() -> None: q = query.Bool(must_not=[query.Match(f=42)], must=[query.Match(g="v")]) assert ~q == query.Bool( @@ -351,7 +351,7 @@ def test_bool_query_gets_inverted_internally(): ) -def test_match_all_or_something_is_match_all(): +def test_match_all_or_something_is_match_all() -> None: q1 = query.MatchAll() q2 = query.Match(f=42) @@ -359,7 +359,7 @@ def test_match_all_or_something_is_match_all(): assert (q2 | q1) == query.MatchAll() -def test_or_produces_bool_with_should(): +def test_or_produces_bool_with_should() -> None: q1 = query.Match(f=42) q2 = query.Match(g="v") @@ -367,7 +367,7 @@ def test_or_produces_bool_with_should(): assert q == query.Bool(should=[q1, q2]) -def test_or_bool_doesnt_loop_infinitely_issue_37(): +def test_or_bool_doesnt_loop_infinitely_issue_37() -> None: q = query.Match(f=42) | ~query.Match(f=47) assert q == query.Bool( @@ -375,7 +375,7 @@ def test_or_bool_doesnt_loop_infinitely_issue_37(): ) -def test_or_bool_doesnt_loop_infinitely_issue_96(): +def test_or_bool_doesnt_loop_infinitely_issue_96() -> None: q = ~query.Match(f=42) | ~query.Match(f=47) assert q == query.Bool( @@ -386,14 +386,14 @@ def test_or_bool_doesnt_loop_infinitely_issue_96(): ) -def test_bool_will_append_another_query_with_or(): +def test_bool_will_append_another_query_with_or() -> None: qb = query.Bool(should=[query.Match(f="v"), query.Match(f="v2")]) q = query.Match(g=42) assert (q | qb) == query.Bool(should=[query.Match(f="v"), query.Match(f="v2"), q]) -def test_bool_queries_with_only_should_get_concatenated(): +def test_bool_queries_with_only_should_get_concatenated() -> None: q1 = query.Bool(should=[query.Match(f=1), query.Match(f=2)]) q2 = query.Bool(should=[query.Match(f=3), query.Match(f=4)]) @@ -402,7 +402,7 @@ def test_bool_queries_with_only_should_get_concatenated(): ) -def test_two_bool_queries_append_one_to_should_if_possible(): +def test_two_bool_queries_append_one_to_should_if_possible() -> None: q1 = query.Bool(should=[query.Match(f="v")]) q2 = query.Bool(must=[query.Match(f="v")]) @@ -414,12 +414,12 @@ def test_two_bool_queries_append_one_to_should_if_possible(): ) -def test_queries_are_registered(): +def test_queries_are_registered() -> None: assert "match" in query.Query._classes assert query.Query._classes["match"] is query.Match -def test_defining_query_registers_it(): +def test_defining_query_registers_it() -> None: class MyQuery(query.Query): name = "my_query" @@ -427,62 +427,64 @@ class MyQuery(query.Query): assert query.Query._classes["my_query"] is MyQuery -def test_Q_passes_query_through(): +def test_Q_passes_query_through() -> None: q = query.Match(f="value1") assert query.Q(q) is q -def test_Q_constructs_query_by_name(): +def test_Q_constructs_query_by_name() -> None: q = query.Q("match", f="value") assert isinstance(q, query.Match) assert {"f": "value"} == q._params -def test_Q_translates_double_underscore_to_dots_in_param_names(): +def test_Q_translates_double_underscore_to_dots_in_param_names() -> None: q = query.Q("match", comment__author="honza") assert {"comment.author": "honza"} == q._params -def test_Q_doesn_translate_double_underscore_to_dots_in_param_names(): +def test_Q_doesn_translate_double_underscore_to_dots_in_param_names() -> None: q = query.Q("match", comment__author="honza", _expand__to_dot=False) assert {"comment__author": "honza"} == q._params -def test_Q_constructs_simple_query_from_dict(): +def test_Q_constructs_simple_query_from_dict() -> None: q = query.Q({"match": {"f": "value"}}) assert isinstance(q, query.Match) assert {"f": "value"} == q._params -def test_Q_constructs_compound_query_from_dict(): +def test_Q_constructs_compound_query_from_dict() -> None: q = query.Q({"bool": {"must": [{"match": {"f": "value"}}]}}) assert q == query.Bool(must=[query.Match(f="value")]) -def test_Q_raises_error_when_passed_in_dict_and_params(): +def test_Q_raises_error_when_passed_in_dict_and_params() -> None: with raises(Exception): - query.Q({"match": {"f": "value"}}, f="value") + # Ignore types as it's not a valid call + query.Q({"match": {"f": "value"}}, f="value") # type: ignore[call-overload] -def test_Q_raises_error_when_passed_in_query_and_params(): +def test_Q_raises_error_when_passed_in_query_and_params() -> None: q = query.Match(f="value1") with raises(Exception): - query.Q(q, f="value") + # Ignore types as it's not a valid call signature + query.Q(q, f="value") # type: ignore[call-overload] -def test_Q_raises_error_on_unknown_query(): +def test_Q_raises_error_on_unknown_query() -> None: with raises(Exception): query.Q("not a query", f="value") -def test_match_all_and_anything_is_anything(): +def test_match_all_and_anything_is_anything() -> None: q = query.MatchAll() s = query.Match(f=42) @@ -490,7 +492,7 @@ def test_match_all_and_anything_is_anything(): assert s & q == s -def test_function_score_with_functions(): +def test_function_score_with_functions() -> None: q = query.Q( "function_score", functions=[query.SF("script_score", script="doc['comment_count'] * _score")], @@ -503,7 +505,7 @@ def test_function_score_with_functions(): } == q.to_dict() -def test_function_score_with_no_function_is_boost_factor(): +def test_function_score_with_no_function_is_boost_factor() -> None: q = query.Q( "function_score", functions=[query.SF({"weight": 20, "filter": query.Q("term", f=42)})], @@ -514,7 +516,7 @@ def test_function_score_with_no_function_is_boost_factor(): } == q.to_dict() -def test_function_score_to_dict(): +def test_function_score_to_dict() -> None: q = query.Q( "function_score", query=query.Q("match", title="python"), @@ -543,7 +545,7 @@ def test_function_score_to_dict(): assert d == q.to_dict() -def test_function_score_with_single_function(): +def test_function_score_with_single_function() -> None: d = { "function_score": { "filter": {"term": {"tags": "python"}}, @@ -561,7 +563,7 @@ def test_function_score_with_single_function(): assert "doc['comment_count'] * _score" == sf.script -def test_function_score_from_dict(): +def test_function_score_from_dict() -> None: d = { "function_score": { "filter": {"term": {"tags": "python"}}, @@ -590,7 +592,7 @@ def test_function_score_from_dict(): assert {"boost_factor": 6} == sf.to_dict() -def test_script_score(): +def test_script_score() -> None: d = { "script_score": { "query": {"match_all": {}}, @@ -605,7 +607,7 @@ def test_script_score(): assert q.to_dict() == d -def test_expand_double_underscore_to_dot_setting(): +def test_expand_double_underscore_to_dot_setting() -> None: q = query.Term(comment__count=2) assert q.to_dict() == {"term": {"comment.count": 2}} utils.EXPAND__TO_DOT = False @@ -614,7 +616,7 @@ def test_expand_double_underscore_to_dot_setting(): utils.EXPAND__TO_DOT = True -def test_knn_query(): +def test_knn_query() -> None: q = query.Knn(field="image-vector", query_vector=[-5, 9, -12], num_candidates=10) assert q.to_dict() == { "knn": { From 65a5785506b823d909fa2c6f92ad1090bc1a50ce Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Sat, 11 May 2024 16:50:20 +0200 Subject: [PATCH 4/6] chore: add type_check to CI --- .github/workflows/ci.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 88edc74c..2a0cd450 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,6 +39,21 @@ jobs: - name: Lint the code run: nox -s lint + type_check: + runs-on: ubuntu-latest + steps: + - name: Checkout Repository + uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + - name: Install dependencies + run: | + python3 -m pip install nox + - name: Lint the code + run: nox -s type_check + docs: runs-on: ubuntu-latest steps: From 824400e9e00ef46f7e69a90bf46b63c7b932ce90 Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Thu, 16 May 2024 22:24:48 +0200 Subject: [PATCH 5/6] fix: fix typing for older python versions --- elasticsearch_dsl/query.py | 3 ++- noxfile.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/elasticsearch_dsl/query.py b/elasticsearch_dsl/query.py index d014a224..8d72ec53 100644 --- a/elasticsearch_dsl/query.py +++ b/elasticsearch_dsl/query.py @@ -22,6 +22,7 @@ Any, Callable, ClassVar, + List, Mapping, MutableMapping, Optional, @@ -231,7 +232,7 @@ def __invert__(self) -> Query: if not any(chain(self.must, self.filter, self.should, self.must_not)): return MatchNone() - negations: list[Query] = [] + negations: List[Query] = [] for q in chain(self.must, self.filter): negations.append(~q) diff --git a/noxfile.py b/noxfile.py index fb6d8c4c..4ebbe717 100644 --- a/noxfile.py +++ b/noxfile.py @@ -83,7 +83,7 @@ def lint(session): session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) -@nox.session(python="3.12") +@nox.session(python="3.8") def type_check(session): session.install("mypy", ".[develop]") errors = [] From 285d9a3e9c0b7d0c40b816c26600915172f88384 Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Thu, 16 May 2024 22:25:52 +0200 Subject: [PATCH 6/6] fix: fix python version for ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a0cd450..00ecf40a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.12" + python-version: "3.8" - name: Install dependencies run: | python3 -m pip install nox