File tree Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Original file line number Diff line number Diff line change @@ -42,6 +42,32 @@ def _as_tensor_shape_variable(var):
42
42
return res
43
43
44
44
45
+ def _as_tensor_shape_variable (var ):
46
+ """ Just a collection of useful shape stuff from
47
+ `_infer_ndim_bcast` """
48
+
49
+ if var is None :
50
+ return T .constant ([], dtype = 'int64' )
51
+
52
+ res = var
53
+ if isinstance (res , (tuple , list )):
54
+ if len (res ) == 0 :
55
+ return T .constant ([], dtype = 'int64' )
56
+ res = T .as_tensor_variable (res , ndim = 1 )
57
+
58
+ else :
59
+ if res .ndim != 1 :
60
+ raise TypeError ("shape must be a vector or list of scalar, got\
61
+ '%s'" % res )
62
+
63
+ if (not (res .dtype .startswith ('int' ) or
64
+ res .dtype .startswith ('uint' ))):
65
+
66
+ raise TypeError ('shape must be an integer vector or list' ,
67
+ res .dtype )
68
+ return res
69
+
70
+
45
71
class Distribution (object ):
46
72
"""Statistical distribution"""
47
73
def __new__ (cls , name , * args , ** kwargs ):
You can’t perform that action at this time.
0 commit comments