Skip to content

Commit 5746788

Browse files
authored
Merge pull request #3599 from lucianopaz/iss3595
Added theano.gof.graph.Constant into the type checks done in distributions.py
2 parents 219652c + 88300b5 commit 5746788

File tree

4 files changed

+49
-23
lines changed

4 files changed

+49
-23
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
- Fixed a bug in `Categorical.logp`. In the case of multidimensional `p`'s, the indexing was done wrong leading to incorrectly shaped tensors that consumed `O(n**2)` memory instead of `O(n)`. This fixes issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535)
2525
- Fixed a defect in `OrderedLogistic.__init__` that unnecessarily increased the dimensionality of the underlying `p`. Related to issue issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535) but was not the true cause of it.
2626
- Wrapped `DensityDist.rand` with `generate_samples` to make it aware of the distribution's shape. Added control flow attributes to still be able to behave as in earlier versions, and to control how to interpret the `size` parameter in the `random` callable signature. Fixes [3553](https://github.com/pymc-devs/pymc3/issues/3553)
27+
- Added `theano.gof.graph.Constant` to type checks done in `_draw_value` (fixes issue [3595](https://github.com/pymc-devs/pymc3/issues/3595))
2728

2829

2930
## PyMC3 3.7 (May 29 2019)

pymc3/distributions/distribution.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
Model, get_named_nodes_and_relations, FreeRV,
1010
ObservedRV, MultiObservedRV, Context, InitContextMeta
1111
)
12-
from ..vartypes import string_types
12+
from ..vartypes import string_types, theano_constant
1313
from .shape_utils import (
1414
to_tuple,
1515
get_broadcastable_dist_samples,
@@ -92,7 +92,7 @@ def getattr_value(self, val):
9292
if isinstance(val, tt.TensorVariable):
9393
return val.tag.test_value
9494

95-
if isinstance(val, tt.TensorConstant):
95+
if isinstance(val, theano_constant):
9696
return val.value
9797

9898
return val
@@ -502,7 +502,7 @@ def __init__(self):
502502
def is_fast_drawable(var):
503503
return isinstance(var, (numbers.Number,
504504
np.ndarray,
505-
tt.TensorConstant,
505+
theano_constant,
506506
tt.sharedvar.SharedVariable))
507507

508508

@@ -592,7 +592,7 @@ def draw_values(params, point=None, size=None):
592592
if (next_, size) in drawn:
593593
# If the node already has a givens value, skip it
594594
continue
595-
elif isinstance(next_, (tt.TensorConstant,
595+
elif isinstance(next_, (theano_constant,
596596
tt.sharedvar.SharedVariable)):
597597
# If the node is a theano.tensor.TensorConstant or a
598598
# theano.tensor.sharedvar.SharedVariable, its value will be
@@ -783,7 +783,7 @@ def _draw_value(param, point=None, givens=None, size=None):
783783
"""
784784
if isinstance(param, (numbers.Number, np.ndarray)):
785785
return param
786-
elif isinstance(param, tt.TensorConstant):
786+
elif isinstance(param, theano_constant):
787787
return param.value
788788
elif isinstance(param, tt.sharedvar.SharedVariable):
789789
return param.get_value()

pymc3/tests/test_random.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ def test_dep_vars(self):
9191
np.all(val1 != val4), np.all(val2 != val3),
9292
np.all(val2 != val4), np.all(val3 != val4)])
9393

94+
def test_gof_constant(self):
95+
# Issue 3595 pointed out that slice(None) can introduce
96+
# theano.gof.graph.Constant into the compute graph, which wasn't
97+
# handled correctly by draw_values
98+
n_d = 500
99+
n_x = 2
100+
n_y = 1
101+
n_g = 10
102+
g = np.random.randint(0, n_g, (n_d,)) # group
103+
x = np.random.randint(0, n_x, (n_d,)) # x factor
104+
with pm.Model():
105+
multi_dim_rv = pm.Normal('multi_dim_rv', mu=0, sd=1, shape=(n_x, n_g, n_y))
106+
indexed_rv = multi_dim_rv[x, g, :]
107+
i = draw_values([indexed_rv])
108+
assert i is not None
109+
94110

95111
class TestJointDistributionDrawValues(SeededTest):
96112
def test_joint_distribution(self):

pymc3/vartypes.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
1-
__all__ = ['bool_types', 'int_types', 'float_types', 'complex_types', 'continuous_types',
2-
'discrete_types', 'typefilter', 'isgenerator']
3-
4-
bool_types = set(['int8'])
5-
6-
int_types = set(['int8',
7-
'int16',
8-
'int32',
9-
'int64',
10-
'uint8',
11-
'uint16',
12-
'uint32',
13-
'uint64'])
14-
float_types = set(['float32',
15-
'float64'])
16-
complex_types = set(['complex64',
17-
'complex128'])
1+
from theano.tensor import Constant as tensor_constant
2+
from theano.gof.graph import Constant as graph_constant
3+
4+
5+
__all__ = [
6+
"bool_types",
7+
"int_types",
8+
"float_types",
9+
"complex_types",
10+
"continuous_types",
11+
"discrete_types",
12+
"typefilter",
13+
"isgenerator",
14+
"theano_constant",
15+
]
16+
17+
bool_types = set(["int8"])
18+
19+
int_types = set(
20+
["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
21+
)
22+
float_types = set(["float32", "float64"])
23+
complex_types = set(["complex64", "complex128"])
1824
continuous_types = float_types | complex_types
1925
discrete_types = bool_types | int_types
2026

@@ -27,4 +33,7 @@ def typefilter(vars, types):
2733

2834

2935
def isgenerator(obj):
30-
return hasattr(obj, '__next__')
36+
return hasattr(obj, "__next__")
37+
38+
39+
theano_constant = (tensor_constant, graph_constant)

0 commit comments

Comments
 (0)