Skip to content

Commit 8e0d5f2

Browse files
committed
filter-rewriting for dates
1 parent 9d599c7 commit 8e0d5f2

File tree

4 files changed

+60
-5
lines changed

4 files changed

+60
-5
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: patch
2+
3+
This patch improves the performance of :func:`~hypothesis.strategies.from_type` with
4+
`pydantic.types.condate <https://docs.pydantic.dev/latest/api/types/#pydantic.types.condate>`__
5+
(:issue:`4000`).

hypothesis-python/src/hypothesis/internal/filtering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,9 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
236236
options = {
237237
# We're talking about op(arg, x) - the reverse of our usual intuition!
238238
operator.lt: {"min_value": arg, "exclude_min": True}, # lambda x: arg < x
239-
operator.le: {"min_value": arg}, # lambda x: arg <= x
240-
operator.eq: {"min_value": arg, "max_value": arg}, # lambda x: arg == x
241-
operator.ge: {"max_value": arg}, # lambda x: arg >= x
239+
operator.le: {"min_value": arg}, # lambda x: arg <= x
240+
operator.eq: {"min_value": arg, "max_value": arg}, # lambda x: arg == x
241+
operator.ge: {"max_value": arg}, # lambda x: arg >= x
242242
operator.gt: {"max_value": arg, "exclude_max": True}, # lambda x: arg > x
243243
# Special-case our default predicates for length bounds
244244
min_len: {"min_value": arg, "len": True},

hypothesis-python/src/hypothesis/strategies/_internal/datetime.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

1111
import datetime as dt
12+
import operator as op
1213
import zoneinfo
1314
from calendar import monthrange
14-
from functools import lru_cache
15+
from functools import lru_cache, partial
1516
from importlib import resources
1617
from pathlib import Path
1718
from typing import Optional
1819

1920
from hypothesis.errors import InvalidArgument
2021
from hypothesis.internal.validation import check_type, check_valid_interval
2122
from hypothesis.strategies._internal.core import sampled_from
22-
from hypothesis.strategies._internal.misc import just, none
23+
from hypothesis.strategies._internal.misc import just, none, nothing
2324
from hypothesis.strategies._internal.strategies import SearchStrategy
2425
from hypothesis.strategies._internal.utils import defines_strategy
2526

@@ -267,6 +268,37 @@ def do_draw(self, data):
267268
**draw_capped_multipart(data, self.min_value, self.max_value, DATENAMES)
268269
)
269270

271+
def filter(self, condition):
272+
if (
273+
isinstance(condition, partial)
274+
and len(args := condition.args) == 1
275+
and not condition.keywords
276+
and isinstance(arg := args[0], dt.date)
277+
and condition.func in (op.lt, op.le, op.eq, op.ge, op.gt)
278+
):
279+
try:
280+
arg += dt.timedelta(days={op.lt: 1, op.gt: -1}.get(condition.func, 0))
281+
except OverflowError: # gt date.max, or lt date.min
282+
return nothing()
283+
lo, hi = {
284+
# We're talking about op(arg, x) - the reverse of our usual intuition!
285+
op.lt: (arg, self.max_value), # lambda x: arg < x
286+
op.le: (arg, self.max_value), # lambda x: arg <= x
287+
op.eq: (arg, arg), # lambda x: arg == x
288+
op.ge: (self.min_value, arg), # lambda x: arg >= x
289+
op.gt: (self.min_value, arg), # lambda x: arg > x
290+
}[condition.func]
291+
lo = max(lo, self.min_value)
292+
hi = min(hi, self.max_value)
293+
print(lo, hi)
294+
if hi < lo:
295+
return nothing()
296+
if lo <= self.min_value and self.max_value <= hi:
297+
return self
298+
return dates(lo, hi)
299+
300+
return super().filter(condition)
301+
270302

271303
@defines_strategy(force_reusable_values=True)
272304
def dates(

hypothesis-python/tests/cover/test_filter_rewriting.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

11+
import datetime as dt
1112
import decimal
1213
import math
1314
import operator
@@ -624,3 +625,20 @@ def test_regex_filter_rewriting(data, strategy, pattern, method):
624625
@given(st.text().filter(re.compile("abc").sub))
625626
def test_error_on_method_which_requires_multiple_args(_):
626627
pass
628+
629+
630+
def test_dates_filter_rewriting():
631+
today = dt.date.today()
632+
633+
assert st.dates().filter(partial(operator.lt, dt.date.max)).is_empty
634+
assert st.dates().filter(partial(operator.gt, dt.date.min)).is_empty
635+
assert st.dates(min_value=today).filter(partial(operator.gt, today)).is_empty
636+
assert st.dates(max_value=today).filter(partial(operator.lt, today)).is_empty
637+
638+
bare = unwrap_strategies(st.dates())
639+
assert bare.filter(partial(operator.ge, dt.date.max)) is bare
640+
assert bare.filter(partial(operator.le, dt.date.min)) is bare
641+
642+
new = bare.filter(partial(operator.le, today))
643+
assert not new.is_empty
644+
assert new is not bare

0 commit comments

Comments
 (0)