@@ -157,12 +157,12 @@ def get_values(self, varname, burn=0, thin=1):
157
157
"""
158
158
raise NotImplementedError
159
159
160
- def get_sampler_stats (self , varname , sampler_idx = None , burn = 0 , thin = 1 ):
160
+ def get_sampler_stats (self , stat_name , sampler_idx = None , burn = 0 , thin = 1 ):
161
161
"""Get sampler statistics from the trace.
162
162
163
163
Parameters
164
164
----------
165
- varname : str
165
+ stat_name : str
166
166
sampler_idx : int or None
167
167
burn : int
168
168
thin : int
@@ -179,21 +179,21 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
179
179
raise ValueError ("This backend does not support sampler stats" )
180
180
181
181
if sampler_idx is not None :
182
- return self ._get_sampler_stats (varname , sampler_idx , burn , thin )
182
+ return self ._get_sampler_stats (stat_name , sampler_idx , burn , thin )
183
183
184
184
sampler_idxs = [i for i , s in enumerate (self .sampler_vars )
185
- if varname in s ]
185
+ if stat_name in s ]
186
186
if not sampler_idxs :
187
- raise KeyError ("Unknown sampler stat %s" % varname )
187
+ raise KeyError ("Unknown sampler stat %s" % stat_name )
188
188
189
- vals = np .stack ([self ._get_sampler_stats (varname , i , burn , thin )
189
+ vals = np .stack ([self ._get_sampler_stats (stat_name , i , burn , thin )
190
190
for i in sampler_idxs ], axis = - 1 )
191
191
if vals .shape [- 1 ] == 1 :
192
192
return vals [..., 0 ]
193
193
else :
194
194
return vals
195
195
196
- def _get_sampler_stats (self , varname , sampler_idx , burn , thin ):
196
+ def _get_sampler_stats (self , stat_name , sampler_idx , burn , thin ):
197
197
"""Get sampler statistics."""
198
198
raise NotImplementedError ()
199
199
@@ -458,13 +458,13 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
458
458
results = [self ._straces [chains ].get_values (varname , burn , thin )]
459
459
return _squeeze_cat (results , combine , squeeze )
460
460
461
- def get_sampler_stats (self , varname , burn = 0 , thin = 1 , combine = True ,
461
+ def get_sampler_stats (self , stat_name , burn = 0 , thin = 1 , combine = True ,
462
462
chains = None , squeeze = True ):
463
463
"""Get sampler statistics from the trace.
464
464
465
465
Parameters
466
466
----------
467
- varname : str
467
+ stat_name : str
468
468
sampler_idx : int or None
469
469
burn : int
470
470
thin : int
@@ -477,8 +477,8 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
477
477
a numpy array of shape (m, n), where `m` is the number of
478
478
such samplers, and `n` is the number of samples.
479
479
"""
480
- if varname not in self .stat_names :
481
- raise KeyError ("Unknown sampler statistic %s" % varname )
480
+ if stat_name not in self .stat_names :
481
+ raise KeyError ("Unknown sampler statistic %s" % stat_name )
482
482
483
483
if chains is None :
484
484
chains = self .chains
@@ -487,7 +487,7 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
487
487
except TypeError :
488
488
chains = [chains ]
489
489
490
- results = [self ._straces [chain ].get_sampler_stats (varname , None , burn , thin )
490
+ results = [self ._straces [chain ].get_sampler_stats (stat_name , None , burn , thin )
491
491
for chain in chains ]
492
492
return _squeeze_cat (results , combine , squeeze )
493
493
0 commit comments