Skip to content

Commit 98635eb

Browse files
committed
Added theano.gof.graph.Constant into the type checks done in distributions.py.
1 parent 879cb49 commit 98635eb

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
- Fixed a bug in `Categorical.logp`. In the case of multidimensional `p`'s, the indexing was done wrong leading to incorrectly shaped tensors that consumed `O(n**2)` memory instead of `O(n)`. This fixes issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535)
2424
- Fixed a defect in `OrderedLogistic.__init__` that unnecessarily increased the dimensionality of the underlying `p`. Related to issue issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535) but was not the true cause of it.
2525
- Wrapped `DensityDist.rand` with `generate_samples` to make it aware of the distribution's shape. Added control flow attributes to still be able to behave as in earlier versions, and to control how to interpret the `size` parameter in the `random` callable signature. Fixes [3553](https://github.com/pymc-devs/pymc3/issues/3553)
26+
- Added `theano.gof.graph.Constant` to type checks done in `_draw_value` (fixes issue [3595](https://github.com/pymc-devs/pymc3/issues/3595))
2627

2728

2829
## PyMC3 3.7 (May 29 2019)

pymc3/distributions/distribution.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def getattr_value(self, val):
9292
if isinstance(val, tt.TensorVariable):
9393
return val.tag.test_value
9494

95-
if isinstance(val, tt.TensorConstant):
95+
if isinstance(val, (tt.TensorConstant, theano.gof.graph.Constant)):
9696
return val.value
9797

9898
return val
@@ -503,6 +503,7 @@ def is_fast_drawable(var):
503503
return isinstance(var, (numbers.Number,
504504
np.ndarray,
505505
tt.TensorConstant,
506+
theano.gof.graph.Constant,
506507
tt.sharedvar.SharedVariable))
507508

508509

@@ -593,6 +594,7 @@ def draw_values(params, point=None, size=None):
593594
# If the node already has a givens value, skip it
594595
continue
595596
elif isinstance(next_, (tt.TensorConstant,
597+
theano.gof.graph.Constant,
596598
tt.sharedvar.SharedVariable)):
597599
# If the node is a theano.tensor.TensorConstant or a
598600
# theano.tensor.sharedvar.SharedVariable, its value will be
@@ -783,7 +785,7 @@ def _draw_value(param, point=None, givens=None, size=None):
783785
"""
784786
if isinstance(param, (numbers.Number, np.ndarray)):
785787
return param
786-
elif isinstance(param, tt.TensorConstant):
788+
elif isinstance(param, (tt.TensorConstant, theano.gof.graph.Constant)):
787789
return param.value
788790
elif isinstance(param, tt.sharedvar.SharedVariable):
789791
return param.get_value()

pymc3/tests/test_random.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ def test_dep_vars(self):
9191
np.all(val1 != val4), np.all(val2 != val3),
9292
np.all(val2 != val4), np.all(val3 != val4)])
9393

94+
def test_gof_constant(self):
95+
# Issue 3595 pointed out that slice(None) can introduce
96+
# theano.gof.graph.Constant into the compute graph, which wasn't
97+
# handled correctly by draw_values
98+
n_d = 500
99+
n_x = 2
100+
n_y = 1
101+
n_g = 10
102+
g = np.random.randint(0, n_g, (n_d,)) # group
103+
x = np.random.randint(0, n_x, (n_d,)) # x factor
104+
with pm.Model():
105+
multi_dim_rv = pm.Normal('multi_dim_rv', mu=0, sd=1, shape=(n_x, n_g, n_y))
106+
indexed_rv = multi_dim_rv[x, g, :]
107+
i = draw_values([indexed_rv])
108+
assert i is not None
109+
94110

95111
class TestJointDistributionDrawValues(SeededTest):
96112
def test_joint_distribution(self):

0 commit comments

Comments
 (0)