|
9 | 9 | Model, get_named_nodes_and_relations, FreeRV,
|
10 | 10 | ObservedRV, MultiObservedRV, Context, InitContextMeta
|
11 | 11 | )
|
12 |
| -from ..vartypes import string_types |
| 12 | +from ..vartypes import string_types, theano_constant |
13 | 13 | from .shape_utils import (
|
14 | 14 | to_tuple,
|
15 | 15 | get_broadcastable_dist_samples,
|
@@ -92,7 +92,7 @@ def getattr_value(self, val):
|
92 | 92 | if isinstance(val, tt.TensorVariable):
|
93 | 93 | return val.tag.test_value
|
94 | 94 |
|
95 |
| - if isinstance(val, (tt.TensorConstant, theano.gof.graph.Constant)): |
| 95 | + if isinstance(val, theano_constant): |
96 | 96 | return val.value
|
97 | 97 |
|
98 | 98 | return val
|
@@ -502,8 +502,7 @@ def __init__(self):
|
502 | 502 | def is_fast_drawable(var):
|
503 | 503 | return isinstance(var, (numbers.Number,
|
504 | 504 | np.ndarray,
|
505 |
| - tt.TensorConstant, |
506 |
| - theano.gof.graph.Constant, |
| 505 | + theano_constant, |
507 | 506 | tt.sharedvar.SharedVariable))
|
508 | 507 |
|
509 | 508 |
|
@@ -593,8 +592,7 @@ def draw_values(params, point=None, size=None):
|
593 | 592 | if (next_, size) in drawn:
|
594 | 593 | # If the node already has a givens value, skip it
|
595 | 594 | continue
|
596 |
| - elif isinstance(next_, (tt.TensorConstant, |
597 |
| - theano.gof.graph.Constant, |
| 595 | + elif isinstance(next_, (theano_constant, |
598 | 596 | tt.sharedvar.SharedVariable)):
|
599 | 597 | # If the node is a theano.tensor.TensorConstant or a
|
600 | 598 | # theano.tensor.sharedvar.SharedVariable, its value will be
|
@@ -785,7 +783,7 @@ def _draw_value(param, point=None, givens=None, size=None):
|
785 | 783 | """
|
786 | 784 | if isinstance(param, (numbers.Number, np.ndarray)):
|
787 | 785 | return param
|
788 |
| - elif isinstance(param, (tt.TensorConstant, theano.gof.graph.Constant)): |
| 786 | + elif isinstance(param, theano_constant): |
789 | 787 | return param.value
|
790 | 788 | elif isinstance(param, tt.sharedvar.SharedVariable):
|
791 | 789 | return param.get_value()
|
|
0 commit comments