
Proposal: Dist shape refactor #1125

Closed · wants to merge 17 commits
pymc3/distributions/distribution.py (5 additions, 2 deletions)

```diff
@@ -94,7 +94,7 @@ def dist(cls, *args, **kwargs):
         return dist
 
     def __init__(self, shape_supp, shape_ind, shape_reps, bcast, dtype,
-                 testval=None, defaults=None, transform=None):
+                 testval=None, defaults=None, transform=None, *args, **kwargs):
         r"""
         Distributions are specified in terms of the shape of their support, the shape
         of the space of independent instances and the shape of the space of replications.
@@ -175,6 +175,9 @@ def __init__(self, shape_supp, shape_ind, shape_reps, bcast, dtype,
         self.shape_reps = _as_tensor_shape_variable(shape_reps)
         self.ndim_reps = tt.get_vector_length(self.shape_reps)
 
+        self.bcast = bcast
+        self.dtype = dtype
+
         ndim_sum = self.ndim_supp + self.ndim_ind + self.ndim_reps
         if ndim_sum == 0:
             self.shape = tt.constant([], dtype='int64')
@@ -197,7 +200,7 @@ def __init__(self, shape_supp, shape_ind, shape_reps, bcast, dtype,
             testval = self.get_test_value(defaults=self.defaults)
 
         self.testval = testval
-        self.type = tt.TensorType(str(dtype), bcast)
+        self.type = tt.TensorType(str(dtype), self.bcast)
 
     def default(self):
        return self.get_test_value(self.testval, self.defaults)
```
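To make the three-part shape decomposition in the new `__init__` concrete, here is a small NumPy sketch. The numbers and the `reps + ind + supp` concatenation order are illustrative assumptions, not something this patch pins down:

```python
import numpy as np

# Hypothetical example: a 3-dimensional multivariate draw (support shape (3,)),
# for 4 independent parameterizations, replicated 5 times.
shape_supp = (3,)  # shape of the distribution's support
shape_ind = (4,)   # independent-instance dimensions
shape_reps = (5,)  # replication dimensions

# Assumed ordering: replications first, then independent dims, then support.
full_shape = shape_reps + shape_ind + shape_supp

assert full_shape == (5, 4, 3)
assert np.zeros(full_shape).ndim == len(shape_reps) + len(shape_ind) + len(shape_supp)
```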
pymc3/distributions/transforms.py (17 additions, 7 deletions)

```diff
@@ -46,13 +46,14 @@ class TransformedDistribution(Distribution):
     """A distribution that has been transformed from one space into another."""
 
     def __init__(self, dist, transform, *args, **kwargs):
-        """
+        r"""
         Parameters
         ----------
         dist : Distribution
             TODO
         transform : Transform
-        args, kwargs
-            arguments to Distribution"""
+            TODO
+        """
         forward = transform.forward
         testval = forward(dist.default())
```
```diff
@@ -61,9 +62,18 @@ def __init__(self, dist, transform, *args, **kwargs):
         v = forward(FreeRV(name='v', distribution=dist))
         self.type = v.type
 
+        # We can get the transformed support shape from a single dummy var in
+        # only the support (i.e. without the independent or replication dimensions).
+        shape_supp = forward(tt.alloc(1, *dist.shape_supp)).shape
+
+        # XXX: We assume these two shapes don't change under a transform.
+        shape_ind = dist.shape_ind
+        shape_reps = dist.shape_reps
+
         super(TransformedDistribution, self).__init__(
-            v.shape.tag.test_value, v.dtype,
-            testval, dist.defaults, *args, **kwargs)
+            shape_supp, shape_ind, shape_reps,
+            v.broadcastable, v.dtype,
+            testval.tag.test_value, dist.defaults, *args, **kwargs)
 
         if transform.name == 'stickbreaking':
             b = np.hstack(((np.atleast_1d(self.shape) == 1)[:-1], False))
```
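The dummy-variable trick above can be sketched without Theano: build an array of ones that lives only in the support, push it through the forward transform, and read off the resulting shape. The `forward` below is a hypothetical stand-in for `transform.forward` (it mimics only the stick-breaking length change, not the actual math):

```python
import numpy as np

def forward(x):
    # Stand-in: like stick-breaking, maps a length-k simplex point to a
    # length-(k - 1) unconstrained vector. Not the real transform.
    return x[..., :-1]

shape_supp = (3,)            # support shape of the untransformed distribution
dummy = np.ones(shape_supp)  # a dummy variable living only in the support
assert forward(dummy).shape == (2,)  # the transformed support shape
```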
```diff
@@ -193,7 +203,7 @@ class StickBreaking(Transform):
     Parameters
     ----------
     eps : float, positive value
-    A small value for numerical stability in invlogit.
+        A small value for numerical stability in invlogit.
     """
 
     name = "stickbreaking"
@@ -250,7 +260,7 @@ def backward(self, y):
 
     def forward(self, x):
         return x
-
+
     def jacobian_det(self, x):
         return 0
```
pymc3/variational/advi.py (7 additions, 7 deletions)

```diff
@@ -208,11 +208,11 @@ def logp_(input):
     r = MRG_RandomStreams(seed=random_seed)
 
     if n_mcsamples == 1:
-        n = r.normal(size=inarray.tag.test_value.shape)
+        n = r.normal(size=np.shape(inarray.tag.test_value))
         q = n * tt.exp(w) + u
         elbo = logp_(q) + tt.sum(w) + 0.5 * l * (1 + tt.log(2.0 * np.pi))
     else:
-        n = r.normal(size=(n_mcsamples, u.tag.test_value.shape[0]))
+        n = r.normal(size=(n_mcsamples, np.shape(u.tag.test_value)[0]))
         qs = n * tt.exp(w) + u
         logps, _ = theano.scan(fn=lambda q: logp_(q),
                                outputs_info=None,
```
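Background for the `elbo` line in this hunk (not part of the patch): with a fully factorized Gaussian approximation whose log-standard-deviations are `w` and whose total dimension is `l`, the entropy has the closed form

$$\mathbb{H}[q] = \sum_i w_i + \frac{l}{2}\left(1 + \log 2\pi\right),$$

which is exactly the `tt.sum(w) + 0.5 * l * (1 + tt.log(2.0 * np.pi))` term; only the expected log-posterior `logp_(q)` needs Monte Carlo samples.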
```diff
@@ -255,7 +255,7 @@ def optimizer(loss, param):
     i_int = i.astype('int64')
     value = param_.get_value(borrow=True)
     accu = theano.shared(
-        np.zeros(value.shape + (n_win,), dtype=value.dtype))
+        np.zeros(np.shape(value) + (n_win,), dtype=value.dtype))
     grad = tt.grad(loss, param_)
 
     # Append squared gradient vector to accu_new
```
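For orientation, a sketch under assumptions (not the PR's code) of what the trailing `(n_win,)` slot enables: a rolling window of squared gradients, overwritten modulo the window size, AdaGrad-style:

```python
import numpy as np

n_win = 10
value = np.zeros((3, 2))                     # a parameter's current value
accu = np.zeros(np.shape(value) + (n_win,))  # one slot per window position

# Hypothetical update at step t: overwrite slot t % n_win with grad ** 2,
# then scale the step by the windowed root-sum-of-squares.
t = 7
grad = np.random.randn(*value.shape)
accu[..., t % n_win] = grad ** 2
scale = 1.0 / (np.sqrt(accu.sum(axis=-1)) + 1e-8)
```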
```diff
@@ -324,17 +324,17 @@ def rvs(x):
     for v in global_RVs:
         u = theano.shared(vparams['means'][str(v)]).ravel()
         w = theano.shared(vparams['stds'][str(v)]).ravel()
-        n = r.normal(size=u.tag.test_value.shape)
-        updates.update({v: (n * w + u).reshape(v.tag.test_value.shape)})
+        n = r.normal(size=np.shape(u.tag.test_value))
+        updates.update({v: (n * w + u).reshape(np.shape(v.tag.test_value))})
 
     if local_RVs is not None:
         for v_, (uw, _) in local_RVs.items():
             v = get_transformed(v_)
             u = uw[0].ravel()
             w = uw[1].ravel()
-            n = r.normal(size=u.tag.test_value.shape)
+            n = r.normal(size=np.shape(u.tag.test_value))
             updates.update(
-                {v: (n * tt.exp(w) + u).reshape(v.tag.test_value.shape)})
+                {v: (n * tt.exp(w) + u).reshape(np.shape(v.tag.test_value))})
 
     # Replace some nodes of the graph with variational distributions
     vars = model.free_RVs
```
A review thread is attached to these `np.shape` changes:

**Member Author:** Why is that required?

**Contributor:** `test_value` was returning a primitive (not a NumPy array) in the tests. I'm not sure yet whether that's due to another issue; I'm still clearing out the big issues following the evaluation hack I just added. I'll comment on all of this once the tests are passing again (or sooner, if necessary).

**Member Author:** Not necessary, I was just curious. Great to see you picking this back up!
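The distinction the contributor describes is easy to demonstrate: `np.shape` accepts any array-like, including plain Python scalars, while the `.shape` attribute exists only on actual `ndarray`s:

```python
import numpy as np

x = 1.0             # a plain Python float, as test_value apparently was here
print(np.shape(x))  # () -- works, x is treated as a 0-d array
# x.shape           # AttributeError: 'float' object has no attribute 'shape'

y = np.zeros((3, 2))
assert np.shape(y) == y.shape == (3, 2)
```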

