@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings

-from typing import Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union

 import aesara
 import aesara.tensor as at
@@ -31,13 +31,20 @@
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.utils import normalize_size_param

-from pymc.aesaraf import change_rv_size, floatX, intX
+from pymc.aesaraf import change_rv_size, convert_observed_data, floatX, intX
 from pymc.distributions import distribution, multivariate
 from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
 from pymc.distributions.logprob import ignore_logprob, logp
-from pymc.distributions.shape_utils import Shape, rv_size_is_none, to_tuple
+from pymc.distributions.shape_utils import (
+    Dims,
+    Shape,
+    convert_dims,
+    rv_size_is_none,
+    to_tuple,
+)
+from pymc.model import modelcontext
 from pymc.util import check_dist_not_registered

 __all__ = [
@@ -50,51 +57,61 @@
 ]


-def get_steps_from_shape(
+def get_steps(
     steps: Optional[Union[int, np.ndarray, TensorVariable]],
-    shape: Optional[Shape],
+    *,
+    shape: Optional[Shape] = None,
+    dims: Optional[Dims] = None,
+    observed: Optional[Any] = None,
     step_shape_offset: int = 0,
 ):
-    """Extract number of steps from shape information
+    """Extract number of steps from shape / dims / observed information

     Parameters
     ----------
     steps:
         User specified steps for timeseries distribution
     shape:
         User specified shape for timeseries distribution
+    dims:
+        User specified dims for timeseries distribution
+    observed:
+        User specified observed data from timeseries distribution
     step_shape_offset:
         Difference between last shape dimension and number of steps in timeseries
         distribution, defaults to 0

-    Raises
-    ------
-    ValueError
-        If neither shape nor steps are provided
-
     Returns
     -------
     steps
         Steps, if specified directly by user, or inferred from the last dimension of
-        shape. When both steps and shape are provided, a symbolic Assert is added
-        to make sure they are consistent.
+        shape / dims / observed. When two sources of step information are provided,
+        a symbolic Assert is added to ensure they are consistent.
     """
-    steps_from_shape = None
+    inferred_steps = None
     if shape is not None:
         shape = to_tuple(shape)
         if shape[-1] is not ...:
-            steps_from_shape = shape[-1] - step_shape_offset
-    if steps is None:
-        if steps_from_shape is not None:
-            steps = steps_from_shape
-        else:
-            raise ValueError("Must specify steps or shape parameter")
-    elif steps_from_shape is not None:
-        # Assert that steps and shape are consistent
-        steps = Assert(msg="Steps do not match last shape dimension")(
-            steps, at.eq(steps, steps_from_shape)
+            inferred_steps = shape[-1] - step_shape_offset
+
+    if inferred_steps is None and dims is not None:
+        dims = convert_dims(dims)
+        if dims[-1] is not ...:
+            model = modelcontext(None)
+            inferred_steps = model.dim_lengths[dims[-1]] - step_shape_offset
+
+    if inferred_steps is None and observed is not None:
+        observed = convert_observed_data(observed)
+        inferred_steps = observed.shape[-1] - step_shape_offset
+
+    if inferred_steps is None:
+        inferred_steps = steps
+    # If there are two sources of information for the steps, assert they are consistent
+    elif steps is not None:
+        inferred_steps = Assert(msg="Steps do not match last shape dimension")(
+            inferred_steps, at.eq(inferred_steps, steps)
         )
-    return steps
+    return inferred_steps


 class GaussianRandomWalkRV(RandomVariable):
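(Not part of the diff: a minimal sketch of the eager `shape` path of the renamed helper, assuming it is importable from `pymc.distributions.timeseries` on this branch; the `dims` and `observed` paths need a model context and appear with the distributions further down.)

```python
# Hypothetical usage sketch of the helper changed above (not part of the patch).
from pymc.distributions.timeseries import get_steps

# Steps come from the last dimension of shape minus step_shape_offset: (5, 11) -> 10
assert get_steps(steps=None, shape=(5, 11), step_shape_offset=1) == 10

# When both steps and shape are given, the result is wrapped in a symbolic Assert
# that checks the two sources agree at runtime.
checked = get_steps(steps=10, shape=(5, 11), step_shape_offset=1)
print(checked.eval())  # 10
```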
@@ -212,26 +229,38 @@ class GaussianRandomWalk(distribution.Continuous):

     .. warning:: init will be cloned, rendering them independent of the ones passed as input.

-    steps : int
-        Number of steps in Gaussian Random Walks (steps > 0).
+    steps : int, optional
+        Number of steps in Gaussian Random Walk (steps > 0). Only needed if size is
+        used to specify distribution
     """

     rv_op = gaussianrandomwalk

-    def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps=None, **kwargs):
-        if init is not None:
-            check_dist_not_registered(init)
-        return super().__new__(cls, name, mu, sigma, init, steps, **kwargs)
+    def __new__(cls, *args, steps=None, **kwargs):
+        steps = get_steps(
+            steps=steps,
+            shape=None,  # Shape will be checked in `cls.dist`
+            dims=kwargs.get("dims", None),
+            observed=kwargs.get("observed", None),
+            step_shape_offset=1,
+        )
+        return super().__new__(cls, *args, steps=steps, **kwargs)

     @classmethod
     def dist(
-        cls, mu=0.0, sigma=1.0, init=None, steps=None, size=None, **kwargs
+        cls, mu=0.0, sigma=1.0, *, init=None, steps=None, size=None, **kwargs
     ) -> at.TensorVariable:

         mu = at.as_tensor_variable(floatX(mu))
         sigma = at.as_tensor_variable(floatX(sigma))

-        steps = get_steps_from_shape(steps, kwargs.get("shape", None), step_shape_offset=1)
+        steps = get_steps(
+            steps=steps,
+            shape=kwargs.get("shape", None),
+            step_shape_offset=1,
+        )
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
         steps = at.as_tensor_variable(intX(steps))

         # If no scalar distribution is passed then initialize with a Normal of same mu and sigma
@@ -245,6 +274,7 @@ def dist(
                 and init.owner.op.ndim_supp == 0
             ):
                 raise TypeError("init must be a univariate distribution variable")
+            check_dist_not_registered(init)

         # Ignores logprob of init var because that's accounted for in the logp method
         init = ignore_logprob(init)
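(Not part of the diff: an illustrative sketch of what the `__new__` hook above enables, assuming this branch of pymc is installed and `GaussianRandomWalk` is used through the usual `pm` namespace; the variable names are made up.)

```python
# Hypothetical example: steps inferred from dims / observed instead of being passed.
import numpy as np
import pymc as pm

data = np.random.normal(size=10).cumsum()

with pm.Model(coords={"time": range(10)}):
    # "time" has 10 coordinates, so steps = 10 - 1 = 9 (step_shape_offset=1)
    prior_rw = pm.GaussianRandomWalk("prior_rw", mu=0.0, sigma=1.0, dims="time")
    # The same inference applies to the last dimension of observed data
    obs_rw = pm.GaussianRandomWalk("obs_rw", mu=0.0, sigma=1.0, observed=data)
```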
@@ -340,6 +370,9 @@ class AR(SymbolicDistribution):
     ar_order: int, optional
         Order of the AR process. Inferred from length of the last dimension of rho, if
         possible. ar_order = rho.shape[-1] if constant else rho.shape[-1] - 1
+    steps : int, optional
+        Number of steps in AR process (steps > 0). Only needed if size is used to
+        specify distribution

     Notes
     -----
@@ -360,6 +393,15 @@ class AR(SymbolicDistribution):

     """

+    def __new__(cls, *args, steps=None, **kwargs):
+        steps = get_steps(
+            steps=steps,
+            shape=None,  # Shape will be checked in `cls.dist`
+            dims=kwargs.get("dims", None),
+            observed=kwargs.get("observed", None),
+        )
+        return super().__new__(cls, *args, steps=steps, **kwargs)
+
     @classmethod
     def dist(
         cls,
@@ -384,7 +426,9 @@ def dist(
             )
             init_dist = kwargs["init"]

-        steps = get_steps_from_shape(steps, kwargs.get("shape", None))
+        steps = get_steps(steps=steps, shape=kwargs.get("shape", None))
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
         steps = at.as_tensor_variable(intX(steps), ndim=0)

         if ar_order is None:
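(Not part of the diff: a similar sketch for `AR`, which uses no `step_shape_offset`, so the inferred steps equal the length of the series; assumes this branch of pymc and the `init_dist` keyword shown in the deprecation shim above.)

```python
# Hypothetical example: AR steps inferred from the observed series.
import numpy as np
import pymc as pm

y = np.random.normal(size=100)

with pm.Model():
    rho = pm.Normal("rho", 0.0, 0.2, size=2)
    # 100 observations and no step_shape_offset, so steps = 100
    ar = pm.AR("ar", rho=rho, sigma=1.0, init_dist=pm.Normal.dist(0.0, 1.0), observed=y)
```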