Skip to content

Commit 0a0a3cd

Browse files
ferrinemichaelosthege
authored andcommitted
add scaling for VI
1 parent dd623da commit 0a0a3cd

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

pymc3/model.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,7 @@ def register_rv(
11651165
name = self.name_for(name)
11661166
rv_var.name = name
11671167
rv_var.tag.total_size = total_size
1168+
rv_var.tag.scaling = _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim)
11681169

11691170
# Associate previously unknown dimension names with
11701171
# the length of the corresponding RV dimension.
@@ -1824,6 +1825,68 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
18241825
return var
18251826

18261827

1828+
def _get_scaling(total_size, shape, ndim):
1829+
"""
1830+
Gets scaling constant for logp
1831+
Parameters
1832+
----------
1833+
total_size: int or list[int]
1834+
shape: shape
1835+
shape to scale
1836+
ndim: int
1837+
ndim hint
1838+
Returns
1839+
-------
1840+
scalar
1841+
"""
1842+
if total_size is None:
1843+
coef = 1.
1844+
elif isinstance(total_size, int):
1845+
if ndim >= 1:
1846+
denom = shape[0]
1847+
else:
1848+
denom = 1
1849+
coef = total_size / denom
1850+
elif isinstance(total_size, (list, tuple)):
1851+
if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
1852+
raise TypeError(
1853+
"Unrecognized `total_size` type, expected "
1854+
"int or list of ints, got %r" % total_size
1855+
)
1856+
if Ellipsis in total_size:
1857+
sep = total_size.index(Ellipsis)
1858+
begin = total_size[:sep]
1859+
end = total_size[sep + 1 :]
1860+
if Ellipsis in end:
1861+
raise ValueError(
1862+
"Double Ellipsis in `total_size` is restricted, got %r" % total_size
1863+
)
1864+
else:
1865+
begin = total_size
1866+
end = []
1867+
if (len(begin) + len(end)) > ndim:
1868+
raise ValueError(
1869+
"Length of `total_size` is too big, "
1870+
"number of scalings is bigger that ndim, got %r" % total_size
1871+
)
1872+
elif (len(begin) + len(end)) == 0:
1873+
coef = 1.
1874+
if len(end) > 0:
1875+
shp_end = shape[-len(end) :]
1876+
else:
1877+
shp_end = np.asarray([])
1878+
shp_begin = shape[: len(begin)]
1879+
begin_coef = [t / shp_begin[i] for i, t in enumerate(begin) if t is not None]
1880+
end_coef = [t / shp_end[i] for i, t in enumerate(end) if t is not None]
1881+
coefs = begin_coef + end_coef
1882+
coef = at.prod(coefs)
1883+
else:
1884+
raise TypeError(
1885+
"Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
1886+
)
1887+
return at.as_tensor(coef, dtype=aesara.config.floatX)
1888+
1889+
18271890
def Potential(name, var, model=None):
18281891
"""Add an arbitrary factor potential to the model likelihood
18291892
@@ -1838,7 +1901,7 @@ def Potential(name, var, model=None):
18381901
"""
18391902
model = modelcontext(model)
18401903
var.name = model.name_for(name)
1841-
var.tag.scaling = None
1904+
var.tag.scaling = 1.
18421905
model.potentials.append(var)
18431906
model.add_random_variable(var)
18441907
return var

pymc3/variational/opvi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,7 +1217,7 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
12171217
@node_property
12181218
def symbolic_normalizing_constant(self):
12191219
"""*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`"""
1220-
t = self.to_flat_input(at.max([v.scaling for v in self.group]))
1220+
t = self.to_flat_input(at.max([v.tag.scaling for v in self.group]))
12211221
t = self.symbolic_single_sample(t)
12221222
return pm.floatX(t)
12231223

@@ -1370,7 +1370,7 @@ def symbolic_normalizing_constant(self):
13701370
"""
13711371
t = at.max(
13721372
self.collect("symbolic_normalizing_constant")
1373-
+ [var.scaling for var in self.model.observed_RVs]
1373+
+ [var.tag.scaling for var in self.model.observed_RVs]
13741374
)
13751375
t = at.switch(self._scale_cost_to_minibatch, t, at.constant(1, dtype=t.dtype))
13761376
return pm.floatX(t)

0 commit comments

Comments
 (0)