@@ -444,17 +444,20 @@ def __init__(self, distribution, lower, upper,
444
444
else :
445
445
transform = transforms .lowerbound (lower )
446
446
default = lower + 1
447
+ else :
448
+ default = None
447
449
448
- # We only change logp and testval for
449
- # discrete distributions
450
+ # We don't use transformations for dicrete variables
450
451
if issubclass (distribution , Discrete ):
451
452
transform = None
452
- if default is not None :
453
- default = default .astype (self .dist .default .type ())
454
453
454
+ kwargs ['transform' ] = transform
455
455
self ._wrapped = distribution .dist (* args , ** kwargs )
456
456
self ._default = default
457
457
458
+ if issubclass (distribution , Discrete ) and default is not None :
459
+ default = default .astype (str (self ._wrapped .default ().dtype ))
460
+
458
461
if default is None :
459
462
defaults = self ._wrapped .defaults
460
463
for name in defaults :
@@ -470,16 +473,16 @@ def __init__(self, distribution, lower, upper,
470
473
transform = self ._wrapped .transform )
471
474
472
475
def _random (self , lower , upper , point = None , size = None ):
473
- if lower is None :
474
- lower = - np .inf
475
- if upper is None :
476
- upper = np . inf
477
-
478
- samples = np .zeros (size ).flatten ()
476
+ lower = np . asarray ( lower )
477
+ upper = np .asarray ( upper )
478
+ if lower . size > 1 or upper . size > 1 :
479
+ raise ValueError ( 'Drawing samples from distributions with '
480
+ 'array-valued bounds is not supported.' )
481
+ samples = np .zeros (size , dtype = self . dtype ).flatten ()
479
482
i , n = 0 , len (samples )
480
483
while i < len (samples ):
481
484
sample = self ._wrapped .random (point = point , size = n )
482
- select = sample [np .logical_and (sample > lower , sample <= upper )]
485
+ select = sample [np .logical_and (sample >= lower , sample <= upper )]
483
486
samples [i :(i + len (select ))] = select [:]
484
487
i += len (select )
485
488
n -= len (select )
@@ -489,18 +492,31 @@ def _random(self, lower, upper, point=None, size=None):
489
492
return samples
490
493
491
494
def random (self , point = None , size = None , repeat = None ):
492
- lower , upper = draw_values ([self .lower , self .upper ], point = point )
493
- return generate_samples (self ._random , lower , upper , point ,
494
- dist_shape = self .shape ,
495
- size = size )
495
+ if self .lower is None and self .upper is None :
496
+ return self ._wrapped .random (point = point , size = size )
497
+ elif self .lower is not None and self .upper is not None :
498
+ lower , upper = draw_values ([self .lower , self .upper ], point = point )
499
+ return generate_samples (self ._random , lower , upper , point ,
500
+ dist_shape = self .shape ,
501
+ size = size )
502
+ elif self .lower is not None :
503
+ lower = draw_values ([self .lower ], point = point )
504
+ return generate_samples (self ._random , lower , np .inf , point ,
505
+ dist_shape = self .shape ,
506
+ size = size )
507
+ else :
508
+ upper = draw_values ([self .upper ], point = point )
509
+ return generate_samples (self ._random , - np .inf , upper , point ,
510
+ dist_shape = self .shape ,
511
+ size = size )
496
512
497
513
def logp (self , value ):
498
514
logp = self ._wrapped .logp (value )
499
515
bounds = []
500
516
if self .lower is not None :
501
- bounds .append (value > self .lower )
517
+ bounds .append (value >= self .lower )
502
518
if self .upper is not None :
503
- bounds .append (value < self .upper )
519
+ bounds .append (value <= self .upper )
504
520
if len (bounds ) > 0 :
505
521
return bound (logp , * bounds )
506
522
else :
@@ -516,13 +532,16 @@ class Bound(object):
516
532
truncated distributions, use `Bound` in combination with
517
533
a `pm.Potential` with the cumulative probability function.
518
534
535
+ The bounds are inclusive for discrete distributions.
536
+
519
537
Parameters
520
538
----------
521
539
distribution : pymc3 distribution
522
- Distribution to be transformed into a bounded distribution
540
+ Distribution to be transformed into a bounded distribution.
523
541
lower : float or array like, optional
524
- Lower bound of the distribution
542
+ Lower bound of the distribution.
525
543
upper : float or array like, optional
544
+ Upper bound of the distribution.
526
545
527
546
Example
528
547
-------
0 commit comments