@@ -7,19 +7,7 @@

 from collections import UserDict
 from contextlib import AbstractContextManager
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-    cast,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, cast, overload

 import numpy as np
 import theano.graph.basic
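A note on the annotation style this change moves to: PEP 585 builtin generics (`list[str]`, `dict[str, np.ndarray]`) are only subscriptable at runtime on Python 3.9+, and PEP 604 unions (`Model | None`) only evaluate on 3.10+; with `from __future__ import annotations` (PEP 563) the annotations are stored as strings and the module still imports on older interpreters. A minimal sketch of the pattern, independent of the PyMC3 code in this diff (the `summarize` helper is hypothetical):

```python
# Sketch only: illustrates the builtin-generic / union-pipe annotation style.
# Assumption: either Python >= 3.10 is required, or the __future__ import below
# is present so the annotations are not evaluated at import time.
from __future__ import annotations  # PEP 563: lazy (string) annotations

import numpy as np


def summarize(draws: dict[str, np.ndarray], var_names: list[str] | None = None) -> dict[str, float]:
    """Return the mean of each requested variable's draws."""
    names = var_names if var_names is not None else list(draws)
    return {name: float(draws[name].mean()) for name in names}
```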
@@ -69,15 +57,15 @@ class _TraceDict(UserDict):
     ~~~~~~~~~~
     varnames: list of strings"""

-    varnames: List[str]
+    varnames: list[str]
     _len: int
     data: Point

     def __init__(
         self,
-        point_list: Optional[List[Point]] = None,
-        multi_trace: Optional[MultiTrace] = None,
-        dict_: Optional[Point] = None,
+        point_list: list[Point] | None = None,
+        multi_trace: MultiTrace | None = None,
+        dict_: Point | None = None,
     ):
         """"""
         if multi_trace:
@@ -134,11 +122,11 @@ def apply_slice(arr: np.ndarray) -> np.ndarray:
         return _TraceDict(dict_=sliced_dict)

     @overload
-    def __getitem__(self, item: Union[str, HasName]) -> np.ndarray:
+    def __getitem__(self, item: str | HasName) -> np.ndarray:
         ...

     @overload
-    def __getitem__(self, item: Union[slice, int]) -> _TraceDict:
+    def __getitem__(self, item: slice | int) -> _TraceDict:
         ...

     def __getitem__(self, item):
@@ -155,13 +143,13 @@ def __getitem__(self, item):


 def fast_sample_posterior_predictive(
-    trace: Union[MultiTrace, Dataset, InferenceData, List[Dict[str, np.ndarray]]],
-    samples: Optional[int] = None,
-    model: Optional[Model] = None,
-    var_names: Optional[List[str]] = None,
+    trace: MultiTrace | Dataset | InferenceData | list[dict[str, np.ndarray]],
+    samples: int | None = None,
+    model: Model | None = None,
+    var_names: list[str] | None = None,
     keep_size: bool = False,
     random_seed=None,
-) -> Dict[str, np.ndarray]:
+) -> dict[str, np.ndarray]:
     """Generate posterior predictive samples from a model given a trace.

     This is a vectorized alternative to the standard ``sample_posterior_predictive`` function.
@@ -250,7 +238,7 @@ def fast_sample_posterior_predictive(

     assert isinstance(_trace, _TraceDict)

-    _samples: List[int] = []
+    _samples: list[int] = []
     # temporary replacement for more complicated logic.
     max_samples: int = len_trace
     if samples is None or samples == max_samples:
@@ -289,7 +277,7 @@ def fast_sample_posterior_predictive(
         _ETPParent = UserDict

     class _ExtendableTrace(_ETPParent):
-        def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
+        def extend_trace(self, trace: dict[str, np.ndarray]) -> None:
             for k, v in trace.items():
                 if k in self.data:
                     self.data[k] = np.concatenate((self.data[k], v))
@@ -301,7 +289,7 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
         strace = _trace if s == len_trace else _trace[slice(0, s)]
         try:
             values = posterior_predictive_draw_values(cast(List[Any], vars), strace, s)
-            new_trace: Dict[str, np.ndarray] = {k.name: v for (k, v) in zip(vars, values)}
+            new_trace: dict[str, np.ndarray] = {k.name: v for (k, v) in zip(vars, values)}
             ppc_trace.extend_trace(new_trace)
         except KeyboardInterrupt:
             pass
@@ -313,8 +301,8 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:


 def posterior_predictive_draw_values(
-    vars: List[Any], trace: _TraceDict, samples: int
-) -> List[np.ndarray]:
+    vars: list[Any], trace: _TraceDict, samples: int
+) -> list[np.ndarray]:
     with _PosteriorPredictiveSampler(vars, trace, samples, None) as sampler:
         return sampler.draw_values()

@@ -323,25 +311,25 @@ class _PosteriorPredictiveSampler(AbstractContextManager):
     """The process of posterior predictive sampling is quite complicated so this provides a central data store."""

     # inputs
-    vars: List[Any]
+    vars: list[Any]
     trace: _TraceDict
     samples: int
-    size: Optional[int]  # not supported!
+    size: int | None  # not supported!

     # other slots
     logger: logging.Logger

     # for the search
-    evaluated: Dict[int, np.ndarray]
-    symbolic_params: List[Tuple[int, Any]]
+    evaluated: dict[int, np.ndarray]
+    symbolic_params: list[tuple[int, Any]]

     # set by make_graph...
-    leaf_nodes: Dict[str, Any]
-    named_nodes_parents: Dict[str, Any]
-    named_nodes_children: Dict[str, Any]
+    leaf_nodes: dict[str, Any]
+    named_nodes_parents: dict[str, Any]
+    named_nodes_children: dict[str, Any]
     _tok: contextvars.Token

-    def __init__(self, vars, trace: _TraceDict, samples, model: Optional[Model], size=None):
+    def __init__(self, vars, trace: _TraceDict, samples, model: Model | None, size=None):
         if size is not None:
             raise NotImplementedError(
                 "sample_posterior_predictive does not support the size argument at this time."
@@ -361,7 +349,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
         vectorized_ppc.reset(self._tok)
         return False

-    def draw_values(self) -> List[np.ndarray]:
+    def draw_values(self) -> list[np.ndarray]:
         vars = self.vars
         trace = self.trace
         samples = self.samples
@@ -438,8 +426,8 @@ def draw_values(self) -> List[np.ndarray]:
         # the below makes sure the graph is evaluated in order
         # test_distributions_random::TestDrawValues::test_draw_order fails without it
         # The remaining params that must be drawn are all hashable
-        to_eval: Set[int] = set()
-        missing_inputs: Set[int] = {j for j, p in self.symbolic_params}
+        to_eval: set[int] = set()
+        missing_inputs: set[int] = {j for j, p in self.symbolic_params}

         while to_eval or missing_inputs:
             if to_eval == missing_inputs:
@@ -477,19 +465,19 @@ def init(self) -> None:
         from the posterior predictive distribution. Notably it initializes the
         ``_DrawValuesContext`` bookkeeping object and evaluates the "fast drawable"
         parts of the model."""
-        vars: List[Any] = self.vars
+        vars: list[Any] = self.vars
         trace: _TraceDict = self.trace
         samples: int = self.samples
-        leaf_nodes: Dict[str, Any]
-        named_nodes_parents: Dict[str, Any]
-        named_nodes_children: Dict[str, Any]
+        leaf_nodes: dict[str, Any]
+        named_nodes_parents: dict[str, Any]
+        named_nodes_children: dict[str, Any]

         # initialization phase
         context = _DrawValuesContext.get_context()
         assert isinstance(context, _DrawValuesContext)
         with context:
             drawn = context.drawn_vars
-            evaluated: Dict[int, Any] = {}
+            evaluated: dict[int, Any] = {}
             symbolic_params = []
             for i, var in enumerate(vars):
                 if is_fast_drawable(var):
@@ -534,7 +522,7 @@ def make_graph(self) -> None:
                 else:
                     self.named_nodes_children[k].update(nnc[k])

-    def draw_value(self, param, trace: Optional[_TraceDict] = None, givens=None):
+    def draw_value(self, param, trace: _TraceDict | None = None, givens=None):
         """Draw a set of random values from a distribution or return a constant.

         Parameters
@@ -559,7 +547,7 @@ def random_sample(
             param,
             point: _TraceDict,
             size: int,
-            shape: Tuple[int, ...],
+            shape: tuple[int, ...],
         ) -> np.ndarray:
             val = meth(point=point, size=size)
             try:
@@ -591,7 +579,7 @@ def random_sample(
         elif hasattr(param, "random") and param.random is not None:
             model = modelcontext(None)
             assert isinstance(model, Model)
-            shape: Tuple[int, ...] = tuple(_param_shape(param, model))
+            shape: tuple[int, ...] = tuple(_param_shape(param, model))
             return random_sample(param.random, param, point=trace, size=samples, shape=shape)
         elif (
             hasattr(param, "distribution")
@@ -602,7 +590,7 @@ def random_sample(
             # shape inspection for ObservedRV
             dist_tmp = param.distribution
             try:
-                distshape: Tuple[int, ...] = tuple(param.observations.shape.eval())
+                distshape: tuple[int, ...] = tuple(param.observations.shape.eval())
             except AttributeError:
                 distshape = tuple(param.observations.shape)

@@ -689,7 +677,7 @@ def random_sample(
         raise ValueError("Unexpected type in draw_value: %s" % type(param))


-def _param_shape(var_desig, model: Model) -> Tuple[int, ...]:
+def _param_shape(var_desig, model: Model) -> tuple[int, ...]:
     if isinstance(var_desig, str):
         v = model[var_desig]
     else:
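For readers landing on this diff, here is a minimal usage sketch of `fast_sample_posterior_predictive`, whose signature changes in the hunk above. The model, data, sampler settings, and the top-level `pymc3` export are illustrative assumptions (the export matches PyMC3 >= 3.9), not something this diff shows; only the `trace`, `var_names`, and `random_seed` parameters and the plain-`dict` return type come from the signature itself.

```python
# Hypothetical example model; assumes a PyMC3 (Theano-based) installation.
import numpy as np
import pymc3 as pm
from pymc3 import fast_sample_posterior_predictive  # assumed top-level export (PyMC3 >= 3.9)

observed = np.random.normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    sigma = pm.HalfNormal("sigma", sigma=5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)
    trace = pm.sample(500, tune=500, chains=2, return_inferencedata=False)

    # Per the `-> dict[str, np.ndarray]` annotation in the diff, the result is a
    # plain dict mapping variable names to arrays of predictive draws.
    ppc = fast_sample_posterior_predictive(trace, var_names=["y"], random_seed=42)

print(ppc["y"].shape)  # (n_samples, 200)
```

The vectorized path draws all requested samples in one pass over the trace rather than looping point by point, which is why it returns a dict of stacked arrays instead of an `InferenceData` object.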