-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Convert continuous and discrete distribution parameters to floatX or int32 #3300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Needs a rebase. |
@twiecki, some tests are failing when |
By chance I saw #2366, where it says that |
…xed test_distributions.py errors. Attempted a fix to test_variational_inference.py errors, which I cannot reproduce locally.
pymc3/distributions/continuous.py
Outdated
self.sigma = self.sd = sigma = tt.as_tensor_variable(sigma) | ||
self.nu = nu = tt.as_tensor_variable(nu) | ||
self.mu = mu = tt.as_tensor_variable(floatX(mu)) | ||
self.sigma = sigma = tt.as_tensor_variable(floatX(sigma)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.sigma = sigma = tt.as_tensor_variable(floatX(sigma)) | |
self.sigma = self.sd = sigma = tt.as_tensor_variable(floatX(sigma)) |
pymc3/tests/test_distributions.py
Outdated
r'$\text{alpha} \sim \text{Normal}(\mathit{mu}=0,~\mathit{sigma}=10.0)$', | ||
r'$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$', | ||
r'$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sd}=10.0)$', | ||
r'$\text{sigma} \sim \text{HalfNormal}(\mathit{sd}=1.0)$', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
r'$\text{sigma} \sim \text{HalfNormal}(\mathit{sd}=1.0)$', | |
r'$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$', |
pymc3/tests/test_distributions.py
Outdated
@@ -1298,11 +1295,11 @@ def setup_class(self): | |||
Y_obs = Normal('Y_obs', mu=mu, sigma=sigma, observed=Y) | |||
self.distributions = [alpha, sigma, mu, b, Y_obs] | |||
self.expected = ( | |||
r'$\text{alpha} \sim \text{Normal}(\mathit{mu}=0,~\mathit{sigma}=10.0)$', | |||
r'$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$', | |||
r'$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sd}=10.0)$', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
r'$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sd}=10.0)$', | |
r'$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$', |
pymc3/tests/test_distributions.py
Outdated
r'$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$', | ||
r'$\text{beta} \sim \text{Normal}(\mathit{mu}=0,~\mathit{sigma}=10.0)$', | ||
r'$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$' | ||
r'$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sd}=10.0)$', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
r'$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sd}=10.0)$', | |
r'$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$', |
pymc3/tests/test_distributions.py
Outdated
r'$\text{beta} \sim \text{Normal}(\mathit{mu}=0,~\mathit{sigma}=10.0)$', | ||
r'$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$' | ||
r'$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sd}=10.0)$', | ||
r'$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=f(\text{mu}),~\mathit{sd}=f(\text{sigma}))$' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
r'$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=f(\text{mu}),~\mathit{sd}=f(\text{sigma}))$' | |
r'$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=f(\text{mu}),~\mathit{sigma}=f(\text{sigma}))$' |
|
Thanks @lucianopaz! |
This should solve issue #3223. While I was working on PR #3293 I also ran into this problem but decided to submit a separate PR for it. I only explicitly casted to
floatX
orint32
if the distribution's docstring said that the parameter should be a float or an int. The rest (multivariate, mixture and timeseries) I left unchanged.