16
16
# under the License.
17
17
18
18
import collections .abc
19
- from typing import Dict
19
+ from copy import deepcopy
20
+ from typing import Any , ClassVar , Dict , MutableMapping , Optional , Union , overload
20
21
21
- from .utils import DslBase
22
+ from .utils import DslBase , JSONType
22
23
23
24
24
- # Incomplete annotation to not break query.py tests
25
- def SF (name_or_sf , ** params ) -> "ScoreFunction" :
25
+ @overload
26
+ def SF (name_or_sf : MutableMapping [str , Any ]) -> "ScoreFunction" : ...
27
+
28
+
29
+ @overload
30
+ def SF (name_or_sf : "ScoreFunction" ) -> "ScoreFunction" : ...
31
+
32
+
33
+ @overload
34
+ def SF (name_or_sf : str , ** params : Any ) -> "ScoreFunction" : ...
35
+
36
+
37
+ def SF (
38
+ name_or_sf : Union [str , "ScoreFunction" , MutableMapping [str , Any ]],
39
+ ** params : Any ,
40
+ ) -> "ScoreFunction" :
26
41
# {"script_score": {"script": "_score"}, "filter": {}}
27
- if isinstance (name_or_sf , collections .abc .Mapping ):
42
+ if isinstance (name_or_sf , collections .abc .MutableMapping ):
28
43
if params :
29
44
raise ValueError ("SF() cannot accept parameters when passing in a dict." )
30
- kwargs = {}
31
- sf = name_or_sf .copy ()
45
+
46
+ kwargs : Dict [str , Any ] = {}
47
+ sf = deepcopy (name_or_sf )
32
48
for k in ScoreFunction ._param_defs :
33
49
if k in name_or_sf :
34
50
kwargs [k ] = sf .pop (k )
35
51
36
52
# not sf, so just filter+weight, which used to be boost factor
53
+ sf_params = params
37
54
if not sf :
38
55
name = "boost_factor"
39
56
# {'FUNCTION': {...}}
40
57
elif len (sf ) == 1 :
41
- name , params = sf .popitem ()
58
+ name , sf_params = sf .popitem ()
42
59
else :
43
60
raise ValueError (f"SF() got an unexpected fields in the dictionary: { sf !r} " )
44
61
45
62
# boost factor special case, see elasticsearch #6343
46
- if not isinstance (params , collections .abc .Mapping ):
47
- params = {"value" : params }
63
+ if not isinstance (sf_params , collections .abc .Mapping ):
64
+ sf_params = {"value" : sf_params }
48
65
49
66
# mix known params (from _param_defs) and from inside the function
50
- kwargs .update (params )
67
+ kwargs .update (sf_params )
51
68
return ScoreFunction .get_dsl_class (name )(** kwargs )
52
69
53
70
# ScriptScore(script="_score", filter=Q())
@@ -70,14 +87,16 @@ class ScoreFunction(DslBase):
70
87
"filter" : {"type" : "query" },
71
88
"weight" : {},
72
89
}
73
- name = None
90
+ name : ClassVar [ Optional [ str ]] = None
74
91
75
- def to_dict (self ):
92
+ def to_dict (self ) -> Dict [ str , JSONType ] :
76
93
d = super ().to_dict ()
77
94
# filter and query dicts should be at the same level as us
78
95
for k in self ._param_defs :
79
- if k in d [self .name ]:
80
- d [k ] = d [self .name ].pop (k )
96
+ if self .name is not None :
97
+ val = d [self .name ]
98
+ if isinstance (val , dict ) and k in val :
99
+ d [k ] = val .pop (k )
81
100
return d
82
101
83
102
@@ -88,12 +107,15 @@ class ScriptScore(ScoreFunction):
88
107
class BoostFactor (ScoreFunction ):
89
108
name = "boost_factor"
90
109
91
- def to_dict (self ) -> Dict [str , int ]:
110
+ def to_dict (self ) -> Dict [str , JSONType ]:
92
111
d = super ().to_dict ()
93
- if "value" in d [self .name ]:
94
- d [self .name ] = d [self .name ].pop ("value" )
95
- else :
96
- del d [self .name ]
112
+ if self .name is not None :
113
+ val = d [self .name ]
114
+ if isinstance (val , dict ):
115
+ if "value" in val :
116
+ d [self .name ] = val .pop ("value" )
117
+ else :
118
+ del d [self .name ]
97
119
return d
98
120
99
121
0 commit comments