Skip to content

Commit 2acc0ee

Browse files
authored
Merge pull request #4137 from Zac-HD/filter-rewrites-3
Filter-rewriting for more efficient `from_type(pydantic.types.condate)`
2 parents 8a01483 + 8e0d5f2 commit 2acc0ee

File tree

5 files changed

+72
-14
lines changed

5 files changed

+72
-14
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: patch
2+
3+
This patch improves the performance of :func:`~hypothesis.strategies.from_type` with
4+
`pydantic.types.condate <https://docs.pydantic.dev/latest/api/types/#pydantic.types.condate>`__
5+
(:issue:`4000`).

hypothesis-python/src/hypothesis/internal/filtering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,9 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
236236
options = {
237237
# We're talking about op(arg, x) - the reverse of our usual intuition!
238238
operator.lt: {"min_value": arg, "exclude_min": True}, # lambda x: arg < x
239-
operator.le: {"min_value": arg}, # lambda x: arg <= x
240-
operator.eq: {"min_value": arg, "max_value": arg}, # lambda x: arg == x
241-
operator.ge: {"max_value": arg}, # lambda x: arg >= x
239+
operator.le: {"min_value": arg}, # lambda x: arg <= x
240+
operator.eq: {"min_value": arg, "max_value": arg}, # lambda x: arg == x
241+
operator.ge: {"max_value": arg}, # lambda x: arg >= x
242242
operator.gt: {"max_value": arg, "exclude_max": True}, # lambda x: arg > x
243243
# Special-case our default predicates for length bounds
244244
min_len: {"min_value": arg, "len": True},

hypothesis-python/src/hypothesis/provisional.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# https://tools.ietf.org/html/rfc3696
2020

2121
import string
22+
from functools import lru_cache
2223
from importlib import resources
2324
from typing import Optional
2425

@@ -31,15 +32,17 @@
3132
FRAGMENT_SAFE_CHARACTERS = URL_SAFE_CHARACTERS | {"?", "/"}
3233

3334

34-
# This file is sourced from http://data.iana.org/TLD/tlds-alpha-by-domain.txt
35-
# The file contains additional information about the date that it was last updated.
36-
traversable = resources.files("hypothesis.vendor") / "tlds-alpha-by-domain.txt"
37-
_comment, *_tlds = traversable.read_text(encoding="utf-8").splitlines()
38-
assert _comment.startswith("#")
35+
@lru_cache(maxsize=1)
36+
def get_top_level_domains() -> tuple[str, ...]:
37+
# This file is sourced from http://data.iana.org/TLD/tlds-alpha-by-domain.txt
38+
# The file contains additional information about the date that it was last updated.
39+
traversable = resources.files("hypothesis.vendor") / "tlds-alpha-by-domain.txt"
40+
_comment, *_tlds = traversable.read_text(encoding="utf-8").splitlines()
41+
assert _comment.startswith("#")
3942

40-
# Remove special-use domain names from the list. For more discussion
41-
# see https://github.com/HypothesisWorks/hypothesis/pull/3572
42-
TOP_LEVEL_DOMAINS = ["COM", *sorted((d for d in _tlds if d != "ARPA"), key=len)]
43+
# Remove special-use domain names from the list. For more discussion
44+
# see https://github.com/HypothesisWorks/hypothesis/pull/3572
45+
return ("COM", *sorted((d for d in _tlds if d != "ARPA"), key=len))
4346

4447

4548
class DomainNameStrategy(st.SearchStrategy):
@@ -101,7 +104,7 @@ def do_draw(self, data):
101104
# prevent us from generating at least a 1 character subdomain.
102105
# 3 - Randomize the TLD between upper and lower case characters.
103106
domain = data.draw(
104-
st.sampled_from(TOP_LEVEL_DOMAINS)
107+
st.sampled_from(get_top_level_domains())
105108
.filter(lambda tld: len(tld) + 2 <= self.max_length)
106109
.flatmap(
107110
lambda tld: st.tuples(

hypothesis-python/src/hypothesis/strategies/_internal/datetime.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

1111
import datetime as dt
12+
import operator as op
1213
import zoneinfo
1314
from calendar import monthrange
14-
from functools import lru_cache
15+
from functools import lru_cache, partial
1516
from importlib import resources
1617
from pathlib import Path
1718
from typing import Optional
1819

1920
from hypothesis.errors import InvalidArgument
2021
from hypothesis.internal.validation import check_type, check_valid_interval
2122
from hypothesis.strategies._internal.core import sampled_from
22-
from hypothesis.strategies._internal.misc import just, none
23+
from hypothesis.strategies._internal.misc import just, none, nothing
2324
from hypothesis.strategies._internal.strategies import SearchStrategy
2425
from hypothesis.strategies._internal.utils import defines_strategy
2526

@@ -267,6 +268,37 @@ def do_draw(self, data):
267268
**draw_capped_multipart(data, self.min_value, self.max_value, DATENAMES)
268269
)
269270

271+
def filter(self, condition):
272+
if (
273+
isinstance(condition, partial)
274+
and len(args := condition.args) == 1
275+
and not condition.keywords
276+
and isinstance(arg := args[0], dt.date)
277+
and condition.func in (op.lt, op.le, op.eq, op.ge, op.gt)
278+
):
279+
try:
280+
arg += dt.timedelta(days={op.lt: 1, op.gt: -1}.get(condition.func, 0))
281+
except OverflowError: # gt date.max, or lt date.min
282+
return nothing()
283+
lo, hi = {
284+
# We're talking about op(arg, x) - the reverse of our usual intuition!
285+
op.lt: (arg, self.max_value), # lambda x: arg < x
286+
op.le: (arg, self.max_value), # lambda x: arg <= x
287+
op.eq: (arg, arg), # lambda x: arg == x
288+
op.ge: (self.min_value, arg), # lambda x: arg >= x
289+
op.gt: (self.min_value, arg), # lambda x: arg > x
290+
}[condition.func]
291+
lo = max(lo, self.min_value)
292+
hi = min(hi, self.max_value)
293+
print(lo, hi)
294+
if hi < lo:
295+
return nothing()
296+
if lo <= self.min_value and self.max_value <= hi:
297+
return self
298+
return dates(lo, hi)
299+
300+
return super().filter(condition)
301+
270302

271303
@defines_strategy(force_reusable_values=True)
272304
def dates(

hypothesis-python/tests/cover/test_filter_rewriting.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

11+
import datetime as dt
1112
import decimal
1213
import math
1314
import operator
@@ -624,3 +625,20 @@ def test_regex_filter_rewriting(data, strategy, pattern, method):
624625
@given(st.text().filter(re.compile("abc").sub))
625626
def test_error_on_method_which_requires_multiple_args(_):
626627
pass
628+
629+
630+
def test_dates_filter_rewriting():
631+
today = dt.date.today()
632+
633+
assert st.dates().filter(partial(operator.lt, dt.date.max)).is_empty
634+
assert st.dates().filter(partial(operator.gt, dt.date.min)).is_empty
635+
assert st.dates(min_value=today).filter(partial(operator.gt, today)).is_empty
636+
assert st.dates(max_value=today).filter(partial(operator.lt, today)).is_empty
637+
638+
bare = unwrap_strategies(st.dates())
639+
assert bare.filter(partial(operator.ge, dt.date.max)) is bare
640+
assert bare.filter(partial(operator.le, dt.date.min)) is bare
641+
642+
new = bare.filter(partial(operator.le, today))
643+
assert not new.is_empty
644+
assert new is not bare

0 commit comments

Comments
 (0)