@@ -88,7 +88,7 @@ def wrapper(tp):
88
88
class IntegerKWargs (TypedDict ):
89
89
min_value : Optional [int ]
90
90
max_value : Optional [int ]
91
- weights : Optional [Sequence [ float ]]
91
+ weights : Optional [dict [ int , float ]]
92
92
shrink_towards : int
93
93
94
94
@@ -1287,7 +1287,7 @@ def draw_integer(
1287
1287
max_value : Optional [int ] = None ,
1288
1288
* ,
1289
1289
# weights are for choosing an element index from a bounded range
1290
- weights : Optional [Sequence [ float ]] = None ,
1290
+ weights : Optional [dict [ int , float ]] = None ,
1291
1291
shrink_towards : int = 0 ,
1292
1292
forced : Optional [int ] = None ,
1293
1293
fake_forced : bool = False ,
@@ -1456,8 +1456,7 @@ def draw_integer(
1456
1456
min_value : Optional [int ] = None ,
1457
1457
max_value : Optional [int ] = None ,
1458
1458
* ,
1459
- # weights are for choosing an element index from a bounded range
1460
- weights : Optional [Sequence [float ]] = None ,
1459
+ weights : Optional [dict [int , float ]] = None ,
1461
1460
shrink_towards : int = 0 ,
1462
1461
forced : Optional [int ] = None ,
1463
1462
fake_forced : bool = False ,
@@ -1475,22 +1474,31 @@ def draw_integer(
1475
1474
assert min_value is not None
1476
1475
assert max_value is not None
1477
1476
1478
- sampler = Sampler (weights , observe = False )
1479
- gap = max_value - shrink_towards
1480
-
1481
- forced_idx = None
1482
- if forced is not None :
1483
- if forced >= shrink_towards :
1484
- forced_idx = forced - shrink_towards
1485
- else :
1486
- forced_idx = shrink_towards + gap - forced
1487
- idx = sampler .sample (self ._cd , forced = forced_idx , fake_forced = fake_forced )
1477
+ # format of weights is a mapping of ints to p, where sum(p) < 1.
1478
+ # The remaining probability mass is uniformly distributed over
1479
+ # *all* ints (not just the unmapped ones; this is somewhat undesirable,
1480
+ # but simplifies things).
1481
+ #
1482
+ # We assert that sum(p) is strictly less than 1 because it simplifies
1483
+ # handling forced values when we can force into the unmapped probability
1484
+ # mass. We should eventually remove this restriction.
1485
+ sampler = Sampler (
1486
+ [1 - sum (weights .values ()), * weights .values ()], observe = False
1487
+ )
1488
+ # if we're forcing, it's easiest to force into the unmapped probability
1489
+ # mass and then force the drawn value after.
1490
+ idx = sampler .sample (
1491
+ self ._cd , forced = None if forced is None else 0 , fake_forced = fake_forced
1492
+ )
1488
1493
1489
- # For range -2..2, interpret idx = 0..4 as [0, 1, 2, -1, -2]
1490
- if idx <= gap :
1491
- return shrink_towards + idx
1492
- else :
1493
- return shrink_towards - (idx - gap )
1494
+ return self ._draw_bounded_integer (
1495
+ min_value ,
1496
+ max_value ,
1497
+ # implicit reliance on dicts being sorted for determinism
1498
+ forced = forced if idx == 0 else list (weights )[idx - 1 ],
1499
+ center = shrink_towards ,
1500
+ fake_forced = fake_forced ,
1501
+ )
1494
1502
1495
1503
if min_value is None and max_value is None :
1496
1504
return self ._draw_unbounded_integer (forced = forced , fake_forced = fake_forced )
@@ -2116,8 +2124,7 @@ def draw_integer(
2116
2124
min_value : Optional [int ] = None ,
2117
2125
max_value : Optional [int ] = None ,
2118
2126
* ,
2119
- # weights are for choosing an element index from a bounded range
2120
- weights : Optional [Sequence [float ]] = None ,
2127
+ weights : Optional [dict [int , float ]] = None ,
2121
2128
shrink_towards : int = 0 ,
2122
2129
forced : Optional [int ] = None ,
2123
2130
fake_forced : bool = False ,
@@ -2127,9 +2134,14 @@ def draw_integer(
2127
2134
if weights is not None :
2128
2135
assert min_value is not None
2129
2136
assert max_value is not None
2130
- width = max_value - min_value + 1
2131
- assert width <= 255 # arbitrary practical limit
2132
- assert len (weights ) == width
2137
+ assert len (weights ) <= 255 # arbitrary practical limit
2138
+ # We can and should eventually support total weights. But this
2139
+ # complicates shrinking as we can no longer assume we can force
2140
+ # a value to the unmapped probability mass if that mass might be 0.
2141
+ assert sum (weights .values ()) < 1
2142
+ # similarly, things get simpler if we assume every value is possible.
2143
+ # we'll want to drop this restriction eventually.
2144
+ assert all (w != 0 for w in weights .values ())
2133
2145
2134
2146
if forced is not None and (min_value is None or max_value is None ):
2135
2147
# We draw `forced=forced - shrink_towards` here internally, after clamping.
@@ -2365,18 +2377,7 @@ def _pooled_kwargs(self, ir_type, kwargs):
2365
2377
if self .provider .avoid_realization :
2366
2378
return kwargs
2367
2379
2368
- key = []
2369
- for k , v in kwargs .items ():
2370
- if ir_type == "float" and k in ["min_value" , "max_value" ]:
2371
- # handle -0.0 vs 0.0, etc.
2372
- v = float_to_int (v )
2373
- elif ir_type == "integer" and k == "weights" :
2374
- # make hashable
2375
- v = v if v is None else tuple (v )
2376
- key .append ((k , v ))
2377
-
2378
- key = (ir_type , * sorted (key ))
2379
-
2380
+ key = (ir_type , * ir_kwargs_key (ir_type , kwargs ))
2380
2381
try :
2381
2382
return POOLED_KWARGS_CACHE [key ]
2382
2383
except KeyError :
0 commit comments