Skip to content

Commit 3c6b95b

Browse files
committed
refactor: add typing to query tests
1 parent a82a8c1 commit 3c6b95b

File tree

6 files changed

+85
-73
lines changed

6 files changed

+85
-73
lines changed

elasticsearch_dsl/function.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
# under the License.
1717

1818
import collections.abc
19+
from typing import Dict
1920

2021
from .utils import DslBase
2122

2223

23-
def SF(name_or_sf, **params):
24+
# Incomplete annotation to not break query.py tests
25+
def SF(name_or_sf, **params) -> "ScoreFunction":
2426
# {"script_score": {"script": "_score"}, "filter": {}}
2527
if isinstance(name_or_sf, collections.abc.Mapping):
2628
if params:
@@ -86,7 +88,7 @@ class ScriptScore(ScoreFunction):
8688
class BoostFactor(ScoreFunction):
8789
name = "boost_factor"
8890

89-
def to_dict(self):
91+
def to_dict(self) -> Dict[str, int]:
9092
d = super().to_dict()
9193
if "value" in d[self.name]:
9294
d[self.name] = d[self.name].pop("value")

elasticsearch_dsl/query.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Protocol,
2929
TypeVar,
3030
Union,
31+
cast,
3132
overload,
3233
)
3334

@@ -58,7 +59,7 @@ def Q(name_or_query: QProxiedProtocol[_T]) -> _T: ...
5859

5960

6061
@overload
61-
def Q(name_or_query: str, **params: Any) -> "Query": ...
62+
def Q(name_or_query: str = "match_all", **params: Any) -> "Query": ...
6263

6364

6465
def Q(
@@ -92,7 +93,7 @@ def Q(
9293

9394
# s.query = Q('filtered', query=s.query)
9495
if hasattr(name_or_query, "_proxied"):
95-
return name_or_query._proxied
96+
return cast(QProxiedProtocol[_T], name_or_query)._proxied
9697

9798
# "match", title="python"
9899
return Query.get_dsl_class(name_or_query)(**params)

elasticsearch_dsl/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import collections.abc
2020
from copy import copy
21-
from typing import Any, Optional, Type
21+
from typing import Any, Dict, Optional, Type
2222

2323
from typing_extensions import Self
2424

@@ -356,7 +356,8 @@ def __getattr__(self, name):
356356
return AttrDict(value)
357357
return value
358358

359-
def to_dict(self):
359+
# TODO: This type annotation can probably be made tighter
360+
def to_dict(self) -> Dict[str, Dict[str, Any]]:
360361
"""
361362
Serialize the DSL object to plain dict
362363
"""

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[mypy-elasticsearch_dsl.query]
2+
# Allow reexport of SF for tests
3+
implicit_reexport = True

noxfile.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
"utils/",
3030
)
3131

32-
TYPED_FILES = ("elasticsearch_dsl/query.py",)
32+
TYPED_FILES = (
33+
"elasticsearch_dsl/query.py",
34+
"tests/test_query.py",
35+
)
3336

3437

3538
@nox.session(
@@ -85,7 +88,7 @@ def type_check(session):
8588
session.install("mypy", ".[develop]")
8689
errors = []
8790
popen = subprocess.Popen(
88-
"mypy --strict elasticsearch_dsl",
91+
"mypy --strict elasticsearch_dsl tests",
8992
env=session.env,
9093
shell=True,
9194
stdout=subprocess.PIPE,

0 commit comments

Comments
 (0)