Skip to content

Commit 0035fe5

Browse files
author
Caio Fontes
committed
refactor: add type hints to query.py + type_checking in CI
1 parent 3e880fd commit 0035fe5

File tree

3 files changed

+104
-25
lines changed

3 files changed

+104
-25
lines changed

elasticsearch_dsl/query.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,69 @@
1616
# under the License.
1717

1818
import collections.abc
19+
from copy import deepcopy
1920
from itertools import chain
21+
from typing import (
22+
Optional,
23+
Any,
24+
overload,
25+
TypeVar,
26+
Protocol,
27+
Callable,
28+
ClassVar,
29+
Union,
30+
)
2031

2132
# 'SF' looks unused but the test suite assumes it's available
2233
# from this module so others are liable to do so as well.
2334
from .function import SF # noqa: F401
2435
from .function import ScoreFunction
2536
from .utils import DslBase
2637

38+
_T = TypeVar("_T")
39+
_M = TypeVar("_M", bound=collections.abc.Mapping[str, Any])
2740

28-
def Q(name_or_query="match_all", **params):
41+
42+
class QProxiedProtocol(Protocol[_T]):
43+
_proxied: _T
44+
45+
46+
@overload
47+
def Q(name_or_query: collections.abc.MutableMapping[str, _M]) -> "Query": ...
48+
49+
50+
@overload
51+
def Q(name_or_query: "Query") -> "Query": ...
52+
53+
54+
@overload
55+
def Q(name_or_query: QProxiedProtocol[_T]) -> _T: ...
56+
57+
58+
@overload
59+
def Q(name_or_query: str, **params: Any) -> "Query": ...
60+
61+
62+
def Q(
63+
name_or_query: Union[
64+
str,
65+
"Query",
66+
QProxiedProtocol[_T],
67+
collections.abc.MutableMapping[str, _M],
68+
] = "match_all",
69+
**params: Any
70+
) -> Union["Query", _T]:
2971
# {"match": {"title": "python"}}
30-
if isinstance(name_or_query, collections.abc.Mapping):
72+
if isinstance(name_or_query, collections.abc.MutableMapping):
3173
if params:
3274
raise ValueError("Q() cannot accept parameters when passing in a dict.")
3375
if len(name_or_query) != 1:
3476
raise ValueError(
3577
'Q() can only accept dict with a single query ({"match": {...}}). '
3678
"Instead it got (%r)" % name_or_query
3779
)
38-
name, params = name_or_query.copy().popitem()
39-
return Query.get_dsl_class(name)(_expand__to_dot=False, **params)
80+
name, q_params = deepcopy(name_or_query).popitem()
81+
return Query.get_dsl_class(name)(_expand__to_dot=False, **q_params)
4082

4183
# MatchAll()
4284
if isinstance(name_or_query, Query):
@@ -57,26 +99,31 @@ def Q(name_or_query="match_all", **params):
5799
class Query(DslBase):
58100
_type_name = "query"
59101
_type_shortcut = staticmethod(Q)
60-
name = None
102+
name: ClassVar[Optional[str]] = None
103+
104+
# Add type annotations for methods not defined in every subclass
105+
__ror__: ClassVar[Callable[["Query", "Query"], "Query"]]
106+
__radd__: ClassVar[Callable[["Query", "Query"], "Query"]]
107+
__rand__: ClassVar[Callable[["Query", "Query"], "Query"]]
61108

62-
def __add__(self, other):
109+
def __add__(self, other: "Query") -> "Query":
63110
# make sure we give queries that know how to combine themselves
64111
# preference
65112
if hasattr(other, "__radd__"):
66113
return other.__radd__(self)
67114
return Bool(must=[self, other])
68115

69-
def __invert__(self):
116+
def __invert__(self) -> "Query":
70117
return Bool(must_not=[self])
71118

72-
def __or__(self, other):
119+
def __or__(self, other: "Query") -> "Query":
73120
# make sure we give queries that know how to combine themselves
74121
# preference
75122
if hasattr(other, "__ror__"):
76123
return other.__ror__(self)
77124
return Bool(should=[self, other])
78125

79-
def __and__(self, other):
126+
def __and__(self, other: "Query") -> "Query":
80127
# make sure we give queries that know how to combine themselves
81128
# preference
82129
if hasattr(other, "__rand__"):
@@ -87,17 +134,17 @@ def __and__(self, other):
87134
class MatchAll(Query):
88135
name = "match_all"
89136

90-
def __add__(self, other):
137+
def __add__(self, other: "Query") -> "Query":
91138
return other._clone()
92139

93140
__and__ = __rand__ = __radd__ = __add__
94141

95-
def __or__(self, other):
142+
def __or__(self, other: "Query") -> "MatchAll":
96143
return self
97144

98145
__ror__ = __or__
99146

100-
def __invert__(self):
147+
def __invert__(self) -> "MatchNone":
101148
return MatchNone()
102149

103150

@@ -107,17 +154,17 @@ def __invert__(self):
107154
class MatchNone(Query):
108155
name = "match_none"
109156

110-
def __add__(self, other):
157+
def __add__(self, other: "Query") -> "MatchNone":
111158
return self
112159

113160
__and__ = __rand__ = __radd__ = __add__
114161

115-
def __or__(self, other):
162+
def __or__(self, other: "Query") -> "Query":
116163
return other._clone()
117164

118165
__ror__ = __or__
119166

120-
def __invert__(self):
167+
def __invert__(self) -> MatchAll:
121168
return MatchAll()
122169

123170

@@ -130,7 +177,7 @@ class Bool(Query):
130177
"filter": {"type": "query", "multi": True},
131178
}
132179

133-
def __add__(self, other):
180+
def __add__(self, other: Query) -> "Bool":
134181
q = self._clone()
135182
if isinstance(other, Bool):
136183
q.must += other.must
@@ -143,7 +190,7 @@ def __add__(self, other):
143190

144191
__radd__ = __add__
145192

146-
def __or__(self, other):
193+
def __or__(self, other: Query) -> Query:
147194
for q in (self, other):
148195
if isinstance(q, Bool) and not any(
149196
(q.must, q.must_not, q.filter, getattr(q, "minimum_should_match", None))
@@ -168,20 +215,20 @@ def __or__(self, other):
168215
__ror__ = __or__
169216

170217
@property
171-
def _min_should_match(self):
218+
def _min_should_match(self) -> int:
172219
return getattr(
173220
self,
174221
"minimum_should_match",
175222
0 if not self.should or (self.must or self.filter) else 1,
176223
)
177224

178-
def __invert__(self):
225+
def __invert__(self) -> Query:
179226
# Because an empty Bool query is treated like
180227
# MatchAll the inverse should be MatchNone
181228
if not any(chain(self.must, self.filter, self.should, self.must_not)):
182229
return MatchNone()
183230

184-
negations = []
231+
negations: list[Query] = []
185232
for q in chain(self.must, self.filter):
186233
negations.append(~q)
187234

@@ -195,7 +242,7 @@ def __invert__(self):
195242
return negations[0]
196243
return Bool(should=negations)
197244

198-
def __and__(self, other):
245+
def __and__(self, other: Query) -> Query:
199246
q = self._clone()
200247
if isinstance(other, Bool):
201248
q.must += other.must
@@ -247,7 +294,7 @@ class FunctionScore(Query):
247294
"functions": {"type": "score_function", "multi": True},
248295
}
249296

250-
def __init__(self, **kwargs):
297+
def __init__(self, **kwargs: Any):
251298
if "functions" in kwargs:
252299
pass
253300
else:

elasticsearch_dsl/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import collections.abc
2020
from copy import copy
21+
from typing import Optional, Any, Self
2122

2223
from .exceptions import UnknownDslObject, ValidationException
2324

@@ -251,7 +252,7 @@ class DslBase(metaclass=DslMeta):
251252
_param_defs = {}
252253

253254
@classmethod
254-
def get_dsl_class(cls, name, default=None):
255+
def get_dsl_class(cls: type[Self], name: str, default:Optional[str]=None) -> type[Self]:
255256
try:
256257
return cls._classes[name]
257258
except KeyError:
@@ -261,7 +262,7 @@ def get_dsl_class(cls, name, default=None):
261262
f"DSL class `{name}` does not exist in {cls._type_name}."
262263
)
263264

264-
def __init__(self, _expand__to_dot=None, **params):
265+
def __init__(self, _expand__to_dot: Optional[bool]=None, **params: Any) -> None:
265266
if _expand__to_dot is None:
266267
_expand__to_dot = EXPAND__TO_DOT
267268
self._params = {}
@@ -390,7 +391,7 @@ def to_dict(self):
390391
d[pname] = value
391392
return {self.name: d}
392393

393-
def _clone(self):
394+
def _clone(self) -> Self:
394395
c = self.__class__()
395396
for attr in self._params:
396397
c._params[attr] = copy(self._params[attr])

noxfile.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
# under the License.
1717

1818
import nox
19+
import subprocess
20+
1921

2022
SOURCE_FILES = (
2123
"setup.py",
@@ -27,6 +29,10 @@
2729
"utils/",
2830
)
2931

32+
TYPED_FILES = (
33+
"elasticsearch_dsl/query.py",
34+
)
35+
3036

3137
@nox.session(
3238
python=[
@@ -76,6 +82,31 @@ def lint(session):
7682
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
7783

7884

85+
@nox.session(python="3.12")
86+
def type_check(session):
87+
session.install("mypy", ".[develop]")
88+
errors = []
89+
popen = subprocess.Popen(
90+
"mypy --strict elasticsearch_dsl",
91+
env=session.env,
92+
shell=True,
93+
stdout=subprocess.PIPE,
94+
stderr=subprocess.STDOUT,
95+
)
96+
97+
mypy_output = ""
98+
while popen.poll() is None:
99+
mypy_output += popen.stdout.read(8192).decode()
100+
mypy_output += popen.stdout.read().decode()
101+
102+
for line in mypy_output.split("\n"):
103+
filepath = line.partition(":")[0]
104+
if filepath in TYPED_FILES:
105+
errors.append(line)
106+
if errors:
107+
session.error("\n" + "\n".join(sorted(set(errors))))
108+
109+
79110
@nox.session()
80111
def docs(session):
81112
session.install(".[develop]")

0 commit comments

Comments
 (0)