@@ -1165,6 +1165,7 @@ def register_rv(
1165
1165
name = self .name_for (name )
1166
1166
rv_var .name = name
1167
1167
rv_var .tag .total_size = total_size
1168
+ rv_var .tag .scaling = _get_scaling (total_size , shape = rv_var .shape , ndim = rv_var .ndim )
1168
1169
1169
1170
# Associate previously unknown dimension names with
1170
1171
# the length of the corresponding RV dimension.
@@ -1824,6 +1825,68 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
1824
1825
return var
1825
1826
1826
1827
1828
+ def _get_scaling (total_size , shape , ndim ):
1829
+ """
1830
+ Gets scaling constant for logp
1831
+ Parameters
1832
+ ----------
1833
+ total_size: int or list[int]
1834
+ shape: shape
1835
+ shape to scale
1836
+ ndim: int
1837
+ ndim hint
1838
+ Returns
1839
+ -------
1840
+ scalar
1841
+ """
1842
+ if total_size is None :
1843
+ coef = 1.
1844
+ elif isinstance (total_size , int ):
1845
+ if ndim >= 1 :
1846
+ denom = shape [0 ]
1847
+ else :
1848
+ denom = 1
1849
+ coef = total_size / denom
1850
+ elif isinstance (total_size , (list , tuple )):
1851
+ if not all (isinstance (i , int ) for i in total_size if (i is not Ellipsis and i is not None )):
1852
+ raise TypeError (
1853
+ "Unrecognized `total_size` type, expected "
1854
+ "int or list of ints, got %r" % total_size
1855
+ )
1856
+ if Ellipsis in total_size :
1857
+ sep = total_size .index (Ellipsis )
1858
+ begin = total_size [:sep ]
1859
+ end = total_size [sep + 1 :]
1860
+ if Ellipsis in end :
1861
+ raise ValueError (
1862
+ "Double Ellipsis in `total_size` is restricted, got %r" % total_size
1863
+ )
1864
+ else :
1865
+ begin = total_size
1866
+ end = []
1867
+ if (len (begin ) + len (end )) > ndim :
1868
+ raise ValueError (
1869
+ "Length of `total_size` is too big, "
1870
+ "number of scalings is bigger that ndim, got %r" % total_size
1871
+ )
1872
+ elif (len (begin ) + len (end )) == 0 :
1873
+ coef = 1.
1874
+ if len (end ) > 0 :
1875
+ shp_end = shape [- len (end ) :]
1876
+ else :
1877
+ shp_end = np .asarray ([])
1878
+ shp_begin = shape [: len (begin )]
1879
+ begin_coef = [t / shp_begin [i ] for i , t in enumerate (begin ) if t is not None ]
1880
+ end_coef = [t / shp_end [i ] for i , t in enumerate (end ) if t is not None ]
1881
+ coefs = begin_coef + end_coef
1882
+ coef = at .prod (coefs )
1883
+ else :
1884
+ raise TypeError (
1885
+ "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
1886
+ )
1887
+ return at .as_tensor (coef , dtype = aesara .config .floatX )
1888
+
1889
+
1827
1890
def Potential (name , var , model = None ):
1828
1891
"""Add an arbitrary factor potential to the model likelihood
1829
1892
@@ -1838,7 +1901,7 @@ def Potential(name, var, model=None):
1838
1901
"""
1839
1902
model = modelcontext (model )
1840
1903
var .name = model .name_for (name )
1841
- var .tag .scaling = None
1904
+ var .tag .scaling = 1.
1842
1905
model .potentials .append (var )
1843
1906
model .add_random_variable (var )
1844
1907
return var
0 commit comments