Skip to content

Commit a72f112

Browse files
bsipoczjunpenglao
authored andcommitted
Rename varname to stat_name for clarity (#3418)
1 parent f113178 commit a72f112

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

pymc3/backends/base.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,12 @@ def get_values(self, varname, burn=0, thin=1):
157157
"""
158158
raise NotImplementedError
159159

160-
def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
160+
def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
161161
"""Get sampler statistics from the trace.
162162
163163
Parameters
164164
----------
165-
varname : str
165+
stat_name : str
166166
sampler_idx : int or None
167167
burn : int
168168
thin : int
@@ -179,21 +179,21 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
179179
raise ValueError("This backend does not support sampler stats")
180180

181181
if sampler_idx is not None:
182-
return self._get_sampler_stats(varname, sampler_idx, burn, thin)
182+
return self._get_sampler_stats(stat_name, sampler_idx, burn, thin)
183183

184184
sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
185-
if varname in s]
185+
if stat_name in s]
186186
if not sampler_idxs:
187-
raise KeyError("Unknown sampler stat %s" % varname)
187+
raise KeyError("Unknown sampler stat %s" % stat_name)
188188

189-
vals = np.stack([self._get_sampler_stats(varname, i, burn, thin)
189+
vals = np.stack([self._get_sampler_stats(stat_name, i, burn, thin)
190190
for i in sampler_idxs], axis=-1)
191191
if vals.shape[-1] == 1:
192192
return vals[..., 0]
193193
else:
194194
return vals
195195

196-
def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
196+
def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
197197
"""Get sampler statistics."""
198198
raise NotImplementedError()
199199

@@ -458,13 +458,13 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
458458
results = [self._straces[chains].get_values(varname, burn, thin)]
459459
return _squeeze_cat(results, combine, squeeze)
460460

461-
def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
461+
def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True,
462462
chains=None, squeeze=True):
463463
"""Get sampler statistics from the trace.
464464
465465
Parameters
466466
----------
467-
varname : str
467+
stat_name : str
468468
sampler_idx : int or None
469469
burn : int
470470
thin : int
@@ -477,8 +477,8 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
477477
a numpy array of shape (m, n), where `m` is the number of
478478
such samplers, and `n` is the number of samples.
479479
"""
480-
if varname not in self.stat_names:
481-
raise KeyError("Unknown sampler statistic %s" % varname)
480+
if stat_name not in self.stat_names:
481+
raise KeyError("Unknown sampler statistic %s" % stat_name)
482482

483483
if chains is None:
484484
chains = self.chains
@@ -487,7 +487,7 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
487487
except TypeError:
488488
chains = [chains]
489489

490-
results = [self._straces[chain].get_sampler_stats(varname, None, burn, thin)
490+
results = [self._straces[chain].get_sampler_stats(stat_name, None, burn, thin)
491491
for chain in chains]
492492
return _squeeze_cat(results, combine, squeeze)
493493

0 commit comments

Comments
 (0)