Skip to content

Commit dcd39b2

Browse files
Add Type hints to function.py (#1827) (#1832)
* feat: add first type annotations * feat: add _JSONSafeTypes annotation to to_dict methods * feat: add typing for SF function * chore: fix linting * rename _JSONSafeTypes to JSONType * format code --------- Co-authored-by: Miguel Grinberg <[email protected]> (cherry picked from commit 76a57fd) Co-authored-by: Caio Fontes <[email protected]>
1 parent 9352f02 commit dcd39b2

File tree

3 files changed

+50
-25
lines changed

3 files changed

+50
-25
lines changed

elasticsearch_dsl/function.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,55 @@
1616
# under the License.
1717

1818
import collections.abc
19-
from typing import Dict
19+
from copy import deepcopy
20+
from typing import Any, ClassVar, Dict, MutableMapping, Optional, Union, overload
2021

21-
from .utils import DslBase
22+
from .utils import DslBase, JSONType
2223

2324

24-
# Incomplete annotation to not break query.py tests
25-
def SF(name_or_sf, **params) -> "ScoreFunction":
25+
@overload
26+
def SF(name_or_sf: MutableMapping[str, Any]) -> "ScoreFunction": ...
27+
28+
29+
@overload
30+
def SF(name_or_sf: "ScoreFunction") -> "ScoreFunction": ...
31+
32+
33+
@overload
34+
def SF(name_or_sf: str, **params: Any) -> "ScoreFunction": ...
35+
36+
37+
def SF(
38+
name_or_sf: Union[str, "ScoreFunction", MutableMapping[str, Any]],
39+
**params: Any,
40+
) -> "ScoreFunction":
2641
# {"script_score": {"script": "_score"}, "filter": {}}
27-
if isinstance(name_or_sf, collections.abc.Mapping):
42+
if isinstance(name_or_sf, collections.abc.MutableMapping):
2843
if params:
2944
raise ValueError("SF() cannot accept parameters when passing in a dict.")
30-
kwargs = {}
31-
sf = name_or_sf.copy()
45+
46+
kwargs: Dict[str, Any] = {}
47+
sf = deepcopy(name_or_sf)
3248
for k in ScoreFunction._param_defs:
3349
if k in name_or_sf:
3450
kwargs[k] = sf.pop(k)
3551

3652
# not sf, so just filter+weight, which used to be boost factor
53+
sf_params = params
3754
if not sf:
3855
name = "boost_factor"
3956
# {'FUNCTION': {...}}
4057
elif len(sf) == 1:
41-
name, params = sf.popitem()
58+
name, sf_params = sf.popitem()
4259
else:
4360
raise ValueError(f"SF() got an unexpected fields in the dictionary: {sf!r}")
4461

4562
# boost factor special case, see elasticsearch #6343
46-
if not isinstance(params, collections.abc.Mapping):
47-
params = {"value": params}
63+
if not isinstance(sf_params, collections.abc.Mapping):
64+
sf_params = {"value": sf_params}
4865

4966
# mix known params (from _param_defs) and from inside the function
50-
kwargs.update(params)
67+
kwargs.update(sf_params)
5168
return ScoreFunction.get_dsl_class(name)(**kwargs)
5269

5370
# ScriptScore(script="_score", filter=Q())
@@ -70,14 +87,16 @@ class ScoreFunction(DslBase):
7087
"filter": {"type": "query"},
7188
"weight": {},
7289
}
73-
name = None
90+
name: ClassVar[Optional[str]] = None
7491

75-
def to_dict(self):
92+
def to_dict(self) -> Dict[str, JSONType]:
7693
d = super().to_dict()
7794
# filter and query dicts should be at the same level as us
7895
for k in self._param_defs:
79-
if k in d[self.name]:
80-
d[k] = d[self.name].pop(k)
96+
if self.name is not None:
97+
val = d[self.name]
98+
if isinstance(val, dict) and k in val:
99+
d[k] = val.pop(k)
81100
return d
82101

83102

@@ -88,12 +107,15 @@ class ScriptScore(ScoreFunction):
88107
class BoostFactor(ScoreFunction):
89108
name = "boost_factor"
90109

91-
def to_dict(self) -> Dict[str, int]:
110+
def to_dict(self) -> Dict[str, JSONType]:
92111
d = super().to_dict()
93-
if "value" in d[self.name]:
94-
d[self.name] = d[self.name].pop("value")
95-
else:
96-
del d[self.name]
112+
if self.name is not None:
113+
val = d[self.name]
114+
if isinstance(val, dict):
115+
if "value" in val:
116+
d[self.name] = val.pop("value")
117+
else:
118+
del d[self.name]
97119
return d
98120

99121

elasticsearch_dsl/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818

1919
import collections.abc
2020
from copy import copy
21-
from typing import Any, Dict, Optional, Type
21+
from typing import Any, ClassVar, Dict, List, Optional, Type, Union
2222

2323
from typing_extensions import Self
2424

2525
from .exceptions import UnknownDslObject, ValidationException
2626

27+
JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]]
28+
2729
SKIP_VALUES = ("", None)
2830
EXPAND__TO_DOT = True
2931

@@ -210,7 +212,7 @@ class DslMeta(type):
210212
For typical use see `QueryMeta` and `Query` in `elasticsearch_dsl.query`.
211213
"""
212214

213-
_types = {}
215+
_types: ClassVar[Dict[str, Type["DslBase"]]] = {}
214216

215217
def __init__(cls, name, bases, attrs):
216218
super().__init__(name, bases, attrs)
@@ -251,7 +253,8 @@ class DslBase(metaclass=DslMeta):
251253
all values in the `must` attribute into Query objects)
252254
"""
253255

254-
_param_defs = {}
256+
_type_name: ClassVar[str]
257+
_param_defs: ClassVar[Dict[str, Dict[str, Union[str, bool]]]] = {}
255258

256259
@classmethod
257260
def get_dsl_class(
@@ -356,8 +359,7 @@ def __getattr__(self, name):
356359
return AttrDict(value)
357360
return value
358361

359-
# TODO: This type annotation can probably be made tighter
360-
def to_dict(self) -> Dict[str, Dict[str, Any]]:
362+
def to_dict(self) -> Dict[str, JSONType]:
361363
"""
362364
Serialize the DSL object to plain dict
363365
"""

noxfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131

3232
TYPED_FILES = (
33+
"elasticsearch_dsl/function.py",
3334
"elasticsearch_dsl/query.py",
3435
"tests/test_query.py",
3536
)

0 commit comments

Comments
 (0)