31
31
from ..step_methods .metropolis import MultivariateNormalProposal
32
32
from ..backends .ndarray import NDArray
33
33
from ..backends .base import MultiTrace
34
- from ..util import is_transformed_name
35
34
36
35
EXPERIMENTAL_WARNING = (
37
36
"Warning: SMC-ABC methods are experimental step methods and not yet"
@@ -53,7 +52,7 @@ def __init__(
53
52
threshold = 0.5 ,
54
53
epsilon = 1.0 ,
55
54
dist_func = "absolute_error" ,
56
- sum_stat = False ,
55
+ sum_stat = "Identity" ,
57
56
progressbar = False ,
58
57
model = None ,
59
58
random_seed = - 1 ,
@@ -140,6 +139,7 @@ def setup_kernel(self):
140
139
self .epsilon ,
141
140
simulator .observations ,
142
141
simulator .distribution .function ,
142
+ [v .name for v in simulator .distribution .params ],
143
143
self .model ,
144
144
self .var_info ,
145
145
self .variables ,
@@ -281,7 +281,7 @@ def mutate(self):
281
281
self .priors [draw ],
282
282
self .likelihoods [draw ],
283
283
draw ,
284
- * parameters
284
+ * parameters ,
285
285
)
286
286
for draw in iterator
287
287
]
@@ -307,7 +307,7 @@ def posterior_to_trace(self):
307
307
size = 0
308
308
for var in varnames :
309
309
shape , new_size = self .var_info [var ]
310
- value .append (self .posterior [i ][size : size + new_size ].reshape (shape ))
310
+ value .append (self .posterior [i ][size : size + new_size ].reshape (shape ))
311
311
size += new_size
312
312
strace .record ({k : v for k , v in zip (varnames , value )})
313
313
return MultiTrace ([strace ])
@@ -389,7 +389,16 @@ class PseudoLikelihood:
389
389
"""
390
390
391
391
def __init__ (
392
- self , epsilon , observations , function , model , var_info , variables , distance , sum_stat
392
+ self ,
393
+ epsilon ,
394
+ observations ,
395
+ function ,
396
+ params ,
397
+ model ,
398
+ var_info ,
399
+ variables ,
400
+ distance ,
401
+ sum_stat ,
393
402
):
394
403
"""
395
404
epsilon: float
@@ -398,34 +407,48 @@ def __init__(
398
407
observed data
399
408
function: python function
400
409
data simulator
410
+ params: list
411
+ names of the variables parameterizing the simulator.
401
412
model: PyMC3 model
402
413
var_info: dict
403
414
generated by ``SMC.initialize_population``
404
- distance: str
405
- Distance function. Available options are ``absolute_error`` (default) and
406
- ``sum_of_squared_distance``.
407
- sum_stat: bool
408
- Whether to use or not a summary statistics.
415
+ distance : str or callable
416
+ Distance function. The only available option is ``gaussian_kernel``
417
+ sum_stat: str or callable
418
+ Summary statistics. Available options are ``indentity``, ``sorted``, ``mean``,
419
+ ``median``. The user can pass any valid Python function
409
420
"""
410
421
self .epsilon = epsilon
411
- self .observations = observations
412
422
self .function = function
423
+ self .params = params
413
424
self .model = model
414
425
self .var_info = var_info
415
426
self .variables = variables
416
427
self .varnames = [v .name for v in self .variables ]
417
428
self .unobserved_RVs = [v .name for v in self .model .unobserved_RVs ]
418
- self .kernel = self .gauss_kernel
419
- self .dist_func = distance
420
- self .sum_stat = sum_stat
421
429
self .get_unobserved_fn = self .model .fastfn (self .model .unobserved_RVs )
422
430
423
- if distance == "absolute_error" :
424
- self .dist_func = self .absolute_error
425
- elif distance == "sum_of_squared_distance" :
426
- self .dist_func = self .sum_of_squared_distance
431
+ if sum_stat == "identity" :
432
+ self .sum_stat = lambda x : x
433
+ elif sum_stat == "sorted" :
434
+ self .sum_stat = np .sort
435
+ elif sum_stat == "mean" :
436
+ self .sum_stat = np .mean
437
+ elif sum_stat == "median" :
438
+ self .sum_stat = np .median
439
+ elif hasattr (sum_stat , "__call__" ):
440
+ self .sum_stat = sum_stat
441
+ else :
442
+ raise ValueError (f"The summary statistics { sum_stat } is not implemented" )
443
+
444
+ self .observations = self .sum_stat (observations )
445
+
446
+ if distance == "gaussian_kernel" :
447
+ self .distance = self .gaussian_kernel
448
+ elif hasattr (distance , "__call__" ):
449
+ self .distance = distance
427
450
else :
428
- raise ValueError ("Distance metric not understood " )
451
+ raise ValueError (f"The distance metric { distance } is not implemented " )
429
452
430
453
def posterior_to_function (self , posterior ):
431
454
model = self .model
@@ -436,32 +459,18 @@ def posterior_to_function(self, posterior):
436
459
size = 0
437
460
for var in self .variables :
438
461
shape , new_size = var_info [var .name ]
439
- varvalues .append (posterior [size : size + new_size ].reshape (shape ))
462
+ varvalues .append (posterior [size : size + new_size ].reshape (shape ))
440
463
size += new_size
441
464
point = {k : v for k , v in zip (self .varnames , varvalues )}
442
465
for varname , value in zip (self .unobserved_RVs , self .get_unobserved_fn (point )):
443
- if not is_transformed_name ( varname ) :
466
+ if varname in self . params :
444
467
samples [varname ] = value
445
468
return samples
446
469
447
- def gauss_kernel (self , value ):
448
- epsilon = self .epsilon
449
- return (- (value ** 2 ) / epsilon ** 2 + np .log (1 / (2 * np .pi * epsilon ** 2 ))) / 2.0
450
-
451
- def absolute_error (self , a , b ):
452
- if self .sum_stat :
453
- return np .abs (a .mean () - b .mean ())
454
- else :
455
- return np .mean (np .atleast_2d (np .abs (a - b )))
456
-
457
- def sum_of_squared_distance (self , a , b ):
458
- if self .sum_stat :
459
- return np .sum (np .atleast_2d ((a .mean () - b .mean ()) ** 2 ))
460
- else :
461
- return np .mean (np .sum (np .atleast_2d ((a - b ) ** 2 )))
470
+ def gaussian_kernel (self , obs_data , sim_data ):
471
+ return np .sum (- 0.5 * ((obs_data - sim_data ) / self .epsilon ) ** 2 )
462
472
463
473
def __call__ (self , posterior ):
464
474
func_parameters = self .posterior_to_function (posterior )
465
- sim_data = self .function (** func_parameters )
466
- value = self .dist_func (self .observations , sim_data )
467
- return self .kernel (value )
475
+ sim_data = self .sum_stat (self .function (** func_parameters ))
476
+ return self .distance (self .observations , sim_data )
0 commit comments