1
+ import warnings
1
2
from collections .abc import Sequence
2
3
from copy import copy
3
4
from typing import cast
28
29
from pytensor .tensor .shape import shape_tuple
29
30
from pytensor .tensor .type import TensorType , all_dtypes
30
31
from pytensor .tensor .type_other import NoneConst
32
+ from pytensor .tensor .utils import _parse_gufunc_signature , safe_signature
31
33
from pytensor .tensor .variable import TensorVariable
32
34
33
35
@@ -42,61 +44,81 @@ class RandomVariable(Op):
42
44
43
45
_output_type_depends_on_input_value = True
44
46
45
- __props__ = ("name" , "ndim_supp" , "ndims_params " , "dtype" , "inplace" )
47
+ __props__ = ("name" , "signature " , "dtype" , "inplace" )
46
48
default_output = 1
47
49
48
50
def __init__ (
49
51
self ,
50
52
name = None ,
51
53
ndim_supp = None ,
52
54
ndims_params = None ,
53
- dtype = None ,
55
+ dtype : str | None = None ,
54
56
inplace = None ,
57
+ signature : str | None = None ,
55
58
):
56
59
"""Create a random variable `Op`.
57
60
58
61
Parameters
59
62
----------
60
63
name: str
61
64
The `Op`'s display name.
62
- ndim_supp: int
63
- Total number of dimensions for a single draw of the random variable
64
- (e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
65
- ndims_params: list of int
66
- Number of dimensions for each distribution parameter when the
67
- parameters only specify a single drawn of the random variable
68
- (e.g. a multivariate normal's mean is 1D and covariance is 2D, so
69
- ``ndims_params = [1, 2]``).
65
+ signature: str
66
+ Numpy-like vectorized signature of the random variable.
70
67
dtype: str (optional)
71
68
The dtype of the sampled output. If the value ``"floatX"`` is
72
69
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
73
70
``None`` (the default), the `dtype` keyword must be set when
74
71
`RandomVariable.make_node` is called.
75
72
inplace: boolean (optional)
76
- Determine whether or not the underlying rng state is updated
77
- in-place or not (i.e. copied).
73
+ Determine whether the underlying rng state is mutated or copied.
78
74
79
75
"""
80
76
super ().__init__ ()
81
77
82
78
self .name = name or getattr (self , "name" )
83
- self .ndim_supp = (
84
- ndim_supp if ndim_supp is not None else getattr (self , "ndim_supp" )
79
+
80
+ ndim_supp = (
81
+ ndim_supp if ndim_supp is not None else getattr (self , "ndim_supp" , None )
85
82
)
86
- self .ndims_params = (
87
- ndims_params if ndims_params is not None else getattr (self , "ndims_params" )
83
+ if ndim_supp is not None :
84
+ warnings .warn (
85
+ "ndim_supp is deprecated. Provide signature instead." , FutureWarning
86
+ )
87
+ self .ndim_supp = ndim_supp
88
+ ndims_params = (
89
+ ndims_params
90
+ if ndims_params is not None
91
+ else getattr (self , "ndims_params" , None )
88
92
)
93
+ if ndims_params is not None :
94
+ warnings .warn (
95
+ "ndims_params is deprecated. Provide signature instead." , FutureWarning
96
+ )
97
+ if not isinstance (ndims_params , Sequence ):
98
+ raise TypeError ("Parameter ndims_params must be sequence type." )
99
+ self .ndims_params = tuple (ndims_params )
100
+
101
+ self .signature = signature or getattr (self , "signature" , None )
102
+ if self .signature is not None :
103
+ # Assume a single output. Several methods need to be updated to handle multiple outputs.
104
+ self .inputs_sig , [self .output_sig ] = _parse_gufunc_signature (self .signature )
105
+ self .ndims_params = [len (input_sig ) for input_sig in self .inputs_sig ]
106
+ self .ndim_supp = len (self .output_sig )
107
+ else :
108
+ if (
109
+ getattr (self , "ndim_supp" , None ) is None
110
+ or getattr (self , "ndims_params" , None ) is None
111
+ ):
112
+ raise ValueError ("signature must be provided" )
113
+ else :
114
+ self .signature = safe_signature (self .ndims_params , [self .ndim_supp ])
115
+
89
116
self .dtype = dtype or getattr (self , "dtype" , None )
90
117
91
118
self .inplace = (
92
119
inplace if inplace is not None else getattr (self , "inplace" , False )
93
120
)
94
121
95
- if not isinstance (self .ndims_params , Sequence ):
96
- raise TypeError ("Parameter ndims_params must be sequence type." )
97
-
98
- self .ndims_params = tuple (self .ndims_params )
99
-
100
122
if self .inplace :
101
123
self .destroy_map = {0 : [0 ]}
102
124
@@ -120,16 +142,56 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
120
142
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
121
143
might have `support_shape=(steps,)`.
122
144
"""
145
+ if self .signature is not None :
146
+ # Signature could indicate fixed numerical shapes
147
+ # As per https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html
148
+ output_sig = self .output_sig
149
+ core_out_shape = {
150
+ dim : int (dim ) if str .isnumeric (dim ) else None for dim in self .output_sig
151
+ }
152
+
153
+ # Try to infer missing support dims from signature of params
154
+ for param , param_sig , ndim_params in zip (
155
+ dist_params , self .inputs_sig , self .ndims_params
156
+ ):
157
+ if ndim_params == 0 :
158
+ continue
159
+ for param_dim , dim in zip (param .shape [- ndim_params :], param_sig ):
160
+ if dim in core_out_shape and core_out_shape [dim ] is None :
161
+ core_out_shape [dim ] = param_dim
162
+
163
+ if all (dim is not None for dim in core_out_shape .values ()):
164
+ # We have all we need
165
+ return [core_out_shape [dim ] for dim in output_sig ]
166
+
123
167
raise NotImplementedError (
124
- "`_supp_shape_from_params` must be implemented for multivariate RVs"
168
+ "`_supp_shape_from_params` must be implemented for multivariate RVs "
169
+ "when signature is not sufficient to infer the support shape"
125
170
)
126
171
127
172
def rng_fn (self , rng , * args , ** kwargs ) -> int | float | np .ndarray :
128
173
"""Sample a numeric random variate."""
129
174
return getattr (rng , self .name )(* args , ** kwargs )
130
175
131
176
def __str__ (self ):
132
- props_str = ", " .join (f"{ getattr (self , prop )} " for prop in self .__props__ [1 :])
177
+ # Only show signature from core props
178
+ if signature := self .signature :
179
+ # inp, out = signature.split("->")
180
+ # extended_signature = f"[rng],[size],{inp}->[rng],{out}"
181
+ # core_props = [extended_signature]
182
+ core_props = [f'"{ signature } "' ]
183
+ else :
184
+ # Far back compat
185
+ core_props = [str (self .ndim_supp ), str (self .ndims_params )]
186
+
187
+ # Add any extra props that the subclass may have
188
+ extra_props = [
189
+ str (getattr (self , prop ))
190
+ for prop in self .__props__
191
+ if prop not in RandomVariable .__props__
192
+ ]
193
+
194
+ props_str = ", " .join (core_props + extra_props )
133
195
return f"{ self .name } _rv{{{ props_str } }}"
134
196
135
197
def _infer_shape (
@@ -298,11 +360,11 @@ def make_node(self, rng, size, dtype, *dist_params):
298
360
dtype_idx = constant (all_dtypes .index (dtype ), dtype = "int64" )
299
361
else :
300
362
dtype_idx = constant (dtype , dtype = "int64" )
301
- dtype = all_dtypes [dtype_idx .data ]
302
363
303
- outtype = TensorType ( dtype = dtype , shape = static_shape )
304
- out_var = outtype ()
364
+ dtype = all_dtypes [ dtype_idx . data ]
365
+
305
366
inputs = (rng , size , dtype_idx , * dist_params )
367
+ out_var = TensorType (dtype = dtype , shape = static_shape )()
306
368
outputs = (rng .type (), out_var )
307
369
308
370
return Apply (self , inputs , outputs )
@@ -395,9 +457,8 @@ def vectorize_random_variable(
395
457
# We extend it to accommodate the new input batch dimensions.
396
458
# Otherwise, we assume the new size already has the right values
397
459
398
- # Need to make parameters implicit broadcasting explicit
399
- original_dist_params = node .inputs [3 :]
400
- old_size = node .inputs [1 ]
460
+ original_dist_params = op .dist_params (node )
461
+ old_size = op .size_param (node )
401
462
len_old_size = get_vector_length (old_size )
402
463
403
464
original_expanded_dist_params = explicit_expand_dims (
0 commit comments