13
13
from ..model import modelcontext
14
14
from .report import SamplerReport , merge_reports
15
15
16
- logger = logging .getLogger (' pymc3' )
16
+ logger = logging .getLogger (" pymc3" )
17
17
18
18
19
19
class BackendError (Exception ):
@@ -58,10 +58,8 @@ def __init__(self, name, model=None, vars=None, test_point=None):
58
58
test_point_ .update (test_point )
59
59
test_point = test_point_
60
60
var_values = list (zip (self .varnames , self .fn (test_point )))
61
- self .var_shapes = {var : value .shape
62
- for var , value in var_values }
63
- self .var_dtypes = {var : value .dtype
64
- for var , value in var_values }
61
+ self .var_shapes = {var : value .shape for var , value in var_values }
62
+ self .var_dtypes = {var : value .dtype for var , value in var_values }
65
63
self .chain = None
66
64
self ._is_base_setup = False
67
65
self .sampler_vars = None
@@ -87,8 +85,9 @@ def _set_sampler_vars(self, sampler_vars):
87
85
for stats in sampler_vars :
88
86
for key , dtype in stats .items ():
89
87
if dtypes .setdefault (key , dtype ) != dtype :
90
- raise ValueError ("Sampler statistic %s appears with "
91
- "different types." % key )
88
+ raise ValueError (
89
+ "Sampler statistic %s appears with " "different types." % key
90
+ )
92
91
93
92
self .sampler_vars = sampler_vars
94
93
@@ -137,7 +136,7 @@ def __getitem__(self, idx):
137
136
try :
138
137
return self .point (int (idx ))
139
138
except (ValueError , TypeError ): # Passed variable or variable name.
140
- raise ValueError (' Can only index with slice or integer' )
139
+ raise ValueError (" Can only index with slice or integer" )
141
140
142
141
def __len__ (self ):
143
142
raise NotImplementedError
@@ -181,13 +180,14 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
181
180
if sampler_idx is not None :
182
181
return self ._get_sampler_stats (varname , sampler_idx , burn , thin )
183
182
184
- sampler_idxs = [i for i , s in enumerate (self .sampler_vars )
185
- if varname in s ]
183
+ sampler_idxs = [i for i , s in enumerate (self .sampler_vars ) if varname in s ]
186
184
if not sampler_idxs :
187
185
raise KeyError ("Unknown sampler stat %s" % varname )
188
186
189
- vals = np .stack ([self ._get_sampler_stats (varname , i , burn , thin )
190
- for i in sampler_idxs ], axis = - 1 )
187
+ vals = np .stack (
188
+ [self ._get_sampler_stats (varname , i , burn , thin ) for i in sampler_idxs ],
189
+ axis = - 1 ,
190
+ )
191
191
if vals .shape [- 1 ] == 1 :
192
192
return vals [..., 0 ]
193
193
else :
@@ -267,13 +267,14 @@ def __init__(self, straces):
267
267
268
268
self ._report = SamplerReport ()
269
269
for strace in straces :
270
- if hasattr (strace , ' _warnings' ):
270
+ if hasattr (strace , " _warnings" ):
271
271
self ._report ._add_warnings (strace ._warnings , strace .chain )
272
272
273
273
def __repr__ (self ):
274
- template = '<{}: {} chains, {} iterations, {} variables>'
275
- return template .format (self .__class__ .__name__ ,
276
- self .nchains , len (self ), len (self .varnames ))
274
+ template = "<{}: {} chains, {} iterations, {} variables>"
275
+ return template .format (
276
+ self .__class__ .__name__ , self .nchains , len (self ), len (self .varnames )
277
+ )
277
278
278
279
@property
279
280
def nchains (self ):
@@ -310,16 +311,26 @@ def __getitem__(self, idx):
310
311
var = str (var )
311
312
if var in self .varnames :
312
313
if var in self .stat_names :
313
- warnings .warn ("Attribute access on a trace object is ambigous. "
314
- "Sampler statistic and model variable share a name. Use "
315
- "trace.get_values or trace.get_sampler_stats." )
314
+ warnings .warn (
315
+ "Attribute access on a trace object is ambigous. "
316
+ "Sampler statistic and model variable share a name. Use "
317
+ "trace.get_values or trace.get_sampler_stats."
318
+ )
316
319
return self .get_values (var , burn = burn , thin = thin )
317
320
if var in self .stat_names :
318
321
return self .get_sampler_stats (var , burn = burn , thin = thin )
319
322
raise KeyError ("Unknown variable %s" % var )
320
323
321
- _attrs = set (['_straces' , 'varnames' , 'chains' , 'stat_names' ,
322
- 'supports_sampler_stats' , '_report' ])
324
+ _attrs = set (
325
+ [
326
+ "_straces" ,
327
+ "varnames" ,
328
+ "chains" ,
329
+ "stat_names" ,
330
+ "supports_sampler_stats" ,
331
+ "_report" ,
332
+ ]
333
+ )
323
334
324
335
def __getattr__ (self , name ):
325
336
# Avoid infinite recursion when called before __init__
@@ -330,14 +341,17 @@ def __getattr__(self, name):
330
341
name = str (name )
331
342
if name in self .varnames :
332
343
if name in self .stat_names :
333
- warnings .warn ("Attribute access on a trace object is ambigous. "
334
- "Sampler statistic and model variable share a name. Use "
335
- "trace.get_values or trace.get_sampler_stats." )
344
+ warnings .warn (
345
+ "Attribute access on a trace object is ambigous. "
346
+ "Sampler statistic and model variable share a name. Use "
347
+ "trace.get_values or trace.get_sampler_stats."
348
+ )
336
349
return self .get_values (name )
337
350
if name in self .stat_names :
338
351
return self .get_sampler_stats (name )
339
- raise AttributeError ("'{}' object has no attribute '{}'" .format (
340
- type (self ).__name__ , name ))
352
+ raise AttributeError (
353
+ "'{}' object has no attribute '{}'" .format (type (self ).__name__ , name )
354
+ )
341
355
342
356
def __len__ (self ):
343
357
chain = self .chains [- 1 ]
@@ -392,10 +406,12 @@ def add_values(self, vals, overwrite=False):
392
406
l_samples = len (self ) * len (self .chains )
393
407
l_v = len (v )
394
408
if l_v != l_samples :
395
- warnings .warn ("The length of the values you are trying to "
396
- "add ({}) does not match the number ({}) of "
397
- "total samples in the trace "
398
- "(chains * iterations)" .format (l_v , l_samples ))
409
+ warnings .warn (
410
+ "The length of the values you are trying to "
411
+ "add ({}) does not match the number ({}) of "
412
+ "total samples in the trace "
413
+ "(chains * iterations)" .format (l_v , l_samples )
414
+ )
399
415
400
416
v = np .squeeze (v .reshape (len (chains ), len (self ), - 1 ))
401
417
@@ -424,8 +440,9 @@ def remove_values(self, name):
424
440
chain .vars .remove (va )
425
441
del chain .samples [name ]
426
442
427
- def get_values (self , varname , burn = 0 , thin = 1 , combine = True , chains = None ,
428
- squeeze = True ):
443
+ def get_values (
444
+ self , varname , burn = 0 , thin = 1 , combine = True , chains = None , squeeze = True
445
+ ):
429
446
"""Get values from traces.
430
447
431
448
Parameters
@@ -452,14 +469,16 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
452
469
chains = self .chains
453
470
varname = str (varname )
454
471
try :
455
- results = [self ._straces [chain ].get_values (varname , burn , thin )
456
- for chain in chains ]
472
+ results = [
473
+ self ._straces [chain ].get_values (varname , burn , thin ) for chain in chains
474
+ ]
457
475
except TypeError : # Single chain passed.
458
476
results = [self ._straces [chains ].get_values (varname , burn , thin )]
459
477
return _squeeze_cat (results , combine , squeeze )
460
478
461
- def get_sampler_stats (self , varname , burn = 0 , thin = 1 , combine = True ,
462
- chains = None , squeeze = True ):
479
+ def get_sampler_stats (
480
+ self , varname , burn = 0 , thin = 1 , combine = True , chains = None , squeeze = True
481
+ ):
463
482
"""Get sampler statistics from the trace.
464
483
465
484
Parameters
@@ -487,8 +506,10 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
487
506
except TypeError :
488
507
chains = [chains ]
489
508
490
- results = [self ._straces [chain ].get_sampler_stats (varname , None , burn , thin )
491
- for chain in chains ]
509
+ results = [
510
+ self ._straces [chain ].get_sampler_stats (varname , None , burn , thin )
511
+ for chain in chains
512
+ ]
492
513
return _squeeze_cat (results , combine , squeeze )
493
514
494
515
def _slice (self , slice ):
0 commit comments