diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index 96a5b55490..277ed7936d 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -15,6 +15,7 @@ from pymc3.util import get_variable_name from .special import log_i0 from ..math import invlogit, logit, logdiffexp +from .shape_utils import broadcast_distribution_samples from .dist_math import ( alltrue_elemwise, betaln, bound, gammaln, i0e, incomplete_beta, logpow, normal_lccdf, normal_lcdf, SplineWrapper, std_cdf, zvalue, @@ -664,8 +665,13 @@ def random(self, point=None, size=None): ------- array """ - mu_v, std_v, a_v, b_v = draw_values( - [self.mu, self.sigma, self.lower, self.upper], point=point, size=size) + mu_v, std_v, a_v, b_v = broadcast_distribution_samples( + draw_values( + [self.mu, self.sigma, self.lower, self.upper], + point=point, + size=size), + size=size, + ) return generate_samples(stats.truncnorm.rvs, a=(a_v - mu_v)/std_v, b=(b_v - mu_v) / std_v,