Skip to content

Commit e63d0ce

Browse files
type hits for tests and examples
1 parent d045a99 commit e63d0ce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+1471
-1005
lines changed

elasticsearch_dsl/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717

1818
from . import connections
19-
from .aggs import A
19+
from .aggs import A, Agg
2020
from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer
2121
from .document import AsyncDocument, Document
2222
from .document_base import InnerDoc, M, MetaField, mapped_field
@@ -81,7 +81,8 @@
8181
from .function import SF
8282
from .index import AsyncIndex, AsyncIndexTemplate, Index, IndexTemplate
8383
from .mapping import AsyncMapping, Mapping
84-
from .query import Q
84+
from .query import Q, Query
85+
from .response import AggResponse, Response, UpdateByQueryResponse
8586
from .search import (
8687
AsyncEmptySearch,
8788
AsyncMultiSearch,
@@ -99,6 +100,8 @@
99100
__versionstr__ = ".".join(map(str, VERSION))
100101
__all__ = [
101102
"A",
103+
"Agg",
104+
"AggResponse",
102105
"AsyncDocument",
103106
"AsyncEmptySearch",
104107
"AsyncFacetedSearch",
@@ -158,11 +161,13 @@
158161
"Object",
159162
"Percolator",
160163
"Q",
164+
"Query",
161165
"Range",
162166
"RangeFacet",
163167
"RangeField",
164168
"RankFeature",
165169
"RankFeatures",
170+
"Response",
166171
"SF",
167172
"ScaledFloat",
168173
"Search",
@@ -174,6 +179,7 @@
174179
"TokenCount",
175180
"UnknownDslObject",
176181
"UpdateByQuery",
182+
"UpdateByQueryResponse",
177183
"ValidationException",
178184
"analyzer",
179185
"char_filter",

elasticsearch_dsl/_async/index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
name: str,
3939
template: str,
4040
index: Optional["AsyncIndex"] = None,
41-
order: Optional[str] = None,
41+
order: Optional[int] = None,
4242
**kwargs: Any,
4343
):
4444
if index is None:
@@ -100,7 +100,7 @@ def as_template(
100100
self,
101101
template_name: str,
102102
pattern: Optional[str] = None,
103-
order: Optional[str] = None,
103+
order: Optional[int] = None,
104104
) -> AsyncIndexTemplate:
105105
# TODO: should we allow pattern to be a top-level arg?
106106
# or maybe have an IndexPattern that allows for it and have

elasticsearch_dsl/_async/search.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@
1616
# under the License.
1717

1818
import contextlib
19-
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
19+
from typing import (
20+
TYPE_CHECKING,
21+
Any,
22+
AsyncIterator,
23+
Dict,
24+
Iterator,
25+
List,
26+
Optional,
27+
cast,
28+
)
2029

2130
from elasticsearch.exceptions import ApiError
2231
from elasticsearch.helpers import async_scan
@@ -68,6 +77,7 @@ async def count(self) -> int:
6877
query=cast(Optional[Dict[str, Any]], d.get("query", None)),
6978
**self._params,
7079
)
80+
7181
return cast(int, resp["count"])
7282

7383
async def execute(self, ignore_cache: bool = False) -> Response[_R]:
@@ -175,6 +185,10 @@ class AsyncMultiSearch(MultiSearchBase[_R]):
175185

176186
_using: AsyncUsingType
177187

188+
if TYPE_CHECKING:
189+
190+
def add(self, search: AsyncSearch[_R]) -> Self: ... # type: ignore[override]
191+
178192
async def execute(
179193
self, ignore_cache: bool = False, raise_on_error: bool = True
180194
) -> List[Response[_R]]:

elasticsearch_dsl/_sync/index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
name: str,
3939
template: str,
4040
index: Optional["Index"] = None,
41-
order: Optional[str] = None,
41+
order: Optional[int] = None,
4242
**kwargs: Any,
4343
):
4444
if index is None:
@@ -94,7 +94,7 @@ def as_template(
9494
self,
9595
template_name: str,
9696
pattern: Optional[str] = None,
97-
order: Optional[str] = None,
97+
order: Optional[int] = None,
9898
) -> IndexTemplate:
9999
# TODO: should we allow pattern to be a top-level arg?
100100
# or maybe have an IndexPattern that allows for it and have

elasticsearch_dsl/_sync/search.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717

1818
import contextlib
19-
from typing import Any, Dict, Iterator, List, Optional, cast
19+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast
2020

2121
from elasticsearch.exceptions import ApiError
2222
from elasticsearch.helpers import scan
@@ -68,6 +68,7 @@ def count(self) -> int:
6868
query=cast(Optional[Dict[str, Any]], d.get("query", None)),
6969
**self._params,
7070
)
71+
7172
return cast(int, resp["count"])
7273

7374
def execute(self, ignore_cache: bool = False) -> Response[_R]:
@@ -169,6 +170,10 @@ class MultiSearch(MultiSearchBase[_R]):
169170

170171
_using: UsingType
171172

173+
if TYPE_CHECKING:
174+
175+
def add(self, search: Search[_R]) -> Self: ... # type: ignore[override]
176+
172177
def execute(
173178
self, ignore_cache: bool = False, raise_on_error: bool = True
174179
) -> List[Response[_R]]:

elasticsearch_dsl/document_base.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -166,24 +166,11 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
166166
field_defaults = {}
167167
for name in fields:
168168
value = None
169-
if name in attrs:
170-
# this field has a right-side value, which can be field
171-
# instance on its own or wrapped with mapped_field()
172-
value = attrs[name]
173-
if isinstance(value, dict):
174-
# the mapped_field() wrapper function was used so we need
175-
# to look for the field instance and also record any
176-
# dataclass-style defaults
177-
value = attrs[name].get("_field")
178-
default_value = attrs[name].get("default") or attrs[name].get(
179-
"default_factory"
180-
)
181-
if default_value:
182-
field_defaults[name] = default_value
183-
if value is None:
184-
# the field does not have an explicit field instance given in
185-
# a right-side assignment, so we need to figure out what field
186-
# type to use from the annotation
169+
required = None
170+
multi = None
171+
if name in annotations:
172+
# the field has a type annotation, so next we try to figure out
173+
# what field type we can use
187174
type_ = annotations[name]
188175
required = True
189176
multi = False
@@ -201,24 +188,53 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
201188
elif type_.__origin__ in [list, List]:
202189
# List[type] -> mark instance as multi
203190
multi = True
191+
required = False
204192
type_ = type_.__args__[0]
205193
else:
206194
break
195+
field = None
207196
field_args: List[Any] = []
208197
field_kwargs: Dict[str, Any] = {}
209-
if not isinstance(type_, type):
210-
raise TypeError(f"Cannot map type {type_}")
211-
elif issubclass(type_, InnerDoc):
198+
if isinstance(type_, type) and issubclass(type_, InnerDoc):
212199
# object or nested field
213200
field = Nested if multi else Object
214201
field_args = [type_]
215202
elif type_ in self.type_annotation_map:
216203
# use best field type for the type hint provided
217204
field, field_kwargs = self.type_annotation_map[type_]
218-
else:
219-
raise TypeError(f"Cannot map type {type_}")
220-
field_kwargs = {"multi": multi, "required": required, **field_kwargs}
221-
value = field(*field_args, **field_kwargs)
205+
206+
if field:
207+
field_kwargs = {
208+
"multi": multi,
209+
"required": required,
210+
**field_kwargs,
211+
}
212+
value = field(*field_args, **field_kwargs)
213+
214+
if name in attrs:
215+
# this field has a right-side value, which can be field
216+
# instance on its own or wrapped with mapped_field()
217+
attr_value = attrs[name]
218+
if isinstance(attr_value, dict):
219+
# the mapped_field() wrapper function was used so we need
220+
# to look for the field instance and also record any
221+
# dataclass-style defaults
222+
attr_value = attrs[name].get("_field")
223+
default_value = attrs[name].get("default") or attrs[name].get(
224+
"default_factory"
225+
)
226+
if default_value:
227+
field_defaults[name] = default_value
228+
if attr_value:
229+
value = attr_value
230+
if required is not None:
231+
value._required = required
232+
if multi is not None:
233+
value._multi = multi
234+
235+
if value is None:
236+
raise TypeError(f"Cannot map field {name}")
237+
222238
self.mapping.field(name, value)
223239
if name in attrs:
224240
del attrs[name]

elasticsearch_dsl/faceted_search_base.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,19 @@
1616
# under the License.
1717

1818
from datetime import datetime, timedelta
19-
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, Union, cast
19+
from typing import (
20+
TYPE_CHECKING,
21+
Any,
22+
Dict,
23+
Generic,
24+
List,
25+
Optional,
26+
Sequence,
27+
Tuple,
28+
Type,
29+
Union,
30+
cast,
31+
)
2032

2133
from typing_extensions import Self
2234

@@ -26,10 +38,11 @@
2638
from .utils import _R, AttrDict
2739

2840
if TYPE_CHECKING:
41+
from .document_base import DocumentBase
2942
from .response.aggs import BucketData
3043
from .search_base import SearchBase
3144

32-
FilterValueType = Union[str, datetime]
45+
FilterValueType = Union[str, datetime, Sequence[str]]
3346

3447
__all__ = [
3548
"FacetedSearchBase",
@@ -51,7 +64,7 @@ class Facet(Generic[_R]):
5164
agg_type: str = ""
5265

5366
def __init__(
54-
self, metric: Optional[str] = None, metric_sort: str = "desc", **kwargs: Any
67+
self, metric: Optional[Agg[_R]] = None, metric_sort: str = "desc", **kwargs: Any
5568
):
5669
self.filter_values = ()
5770
self._params = kwargs
@@ -137,7 +150,9 @@ def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
137150
class RangeFacet(Facet[_R]):
138151
agg_type = "range"
139152

140-
def _range_to_dict(self, range: Tuple[Any, Tuple[int, int]]) -> Dict[str, Any]:
153+
def _range_to_dict(
154+
self, range: Tuple[Any, Tuple[Optional[int], Optional[int]]]
155+
) -> Dict[str, Any]:
141156
key, _range = range
142157
out: Dict[str, Any] = {"key": key}
143158
if _range[0] is not None:
@@ -146,7 +161,11 @@ def _range_to_dict(self, range: Tuple[Any, Tuple[int, int]]) -> Dict[str, Any]:
146161
out["to"] = _range[1]
147162
return out
148163

149-
def __init__(self, ranges: List[Tuple[Any, Tuple[int, int]]], **kwargs: Any):
164+
def __init__(
165+
self,
166+
ranges: Sequence[Tuple[Any, Tuple[Optional[int], Optional[int]]]],
167+
**kwargs: Any,
168+
):
150169
super().__init__(**kwargs)
151170
self._params["ranges"] = list(map(self._range_to_dict, ranges))
152171
self._params["keyed"] = False
@@ -277,7 +296,7 @@ class FacetedResponse(Response[_R]):
277296
_facets: Dict[str, List[Tuple[Any, int, bool]]]
278297

279298
@property
280-
def query_string(self) -> Optional[Query]:
299+
def query_string(self) -> Optional[Union[str, Query]]:
281300
return self._faceted_search._query
282301

283302
@property
@@ -334,9 +353,9 @@ def search(self):
334353
335354
"""
336355

337-
index = None
338-
doc_types = None
339-
fields: List[str] = []
356+
index: Optional[str] = None
357+
doc_types: Optional[List[Union[str, Type["DocumentBase"]]]] = None
358+
fields: Sequence[str] = []
340359
facets: Dict[str, Facet[_R]] = {}
341360
using = "default"
342361

@@ -346,9 +365,9 @@ def search(self) -> "SearchBase[_R]": ...
346365

347366
def __init__(
348367
self,
349-
query: Optional[Query] = None,
368+
query: Optional[Union[str, Query]] = None,
350369
filters: Dict[str, FilterValueType] = {},
351-
sort: List[str] = [],
370+
sort: Sequence[str] = [],
352371
):
353372
"""
354373
:arg query: the text to search for
@@ -383,16 +402,18 @@ def add_filter(
383402
]
384403

385404
# remember the filter values for use in FacetedResponse
386-
self.filter_values[name] = filter_values
405+
self.filter_values[name] = filter_values # type: ignore[assignment]
387406

388407
# get the filter from the facet
389-
f = self.facets[name].add_filter(filter_values)
408+
f = self.facets[name].add_filter(filter_values) # type: ignore[arg-type]
390409
if f is None:
391410
return
392411

393412
self._filters[name] = f
394413

395-
def query(self, search: "SearchBase[_R]", query: Query) -> "SearchBase[_R]":
414+
def query(
415+
self, search: "SearchBase[_R]", query: Union[str, Query]
416+
) -> "SearchBase[_R]":
396417
"""
397418
Add query part to ``search``.
398419

0 commit comments

Comments
 (0)