Skip to content

Commit 3b3be49

Browse files
ArmavicaricardoV94
authored andcommitted
Apply unsafe f-string related ruff fixes
1 parent fb7e1b1 commit 3b3be49

File tree

10 files changed

+26
-30
lines changed

10 files changed

+26
-30
lines changed

pymc/backends/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def _set_sampler_vars(self, sampler_vars):
186186
for stats in sampler_vars:
187187
for key, dtype in stats.items():
188188
if dtypes.setdefault(key, dtype) != dtype:
189-
raise ValueError("Sampler statistic %s appears with " "different types." % key)
189+
raise ValueError(f"Sampler statistic {key} appears with different types.")
190190

191191
self.sampler_vars = sampler_vars
192192

@@ -247,7 +247,7 @@ def get_sampler_stats(
247247

248248
sampler_idxs = [i for i, s in enumerate(self.sampler_vars) if stat_name in s]
249249
if not sampler_idxs:
250-
raise KeyError("Unknown sampler stat %s" % stat_name)
250+
raise KeyError(f"Unknown sampler stat {stat_name}")
251251

252252
vals = np.stack(
253253
[self._get_sampler_stats(stat_name, i, burn, thin) for i in sampler_idxs], axis=-1
@@ -388,7 +388,7 @@ def __getitem__(self, idx):
388388
return self.get_values(var, burn=burn, thin=thin)
389389
if var in self.stat_names:
390390
return self.get_sampler_stats(var, burn=burn, thin=thin)
391-
raise KeyError("Unknown variable %s" % var)
391+
raise KeyError(f"Unknown variable {var}")
392392

393393
_attrs = {"_straces", "varnames", "chains", "stat_names", "_report"}
394394

@@ -512,7 +512,7 @@ def get_sampler_stats(
512512
List or ndarray depending on parameters.
513513
"""
514514
if stat_name not in self.stat_names:
515-
raise KeyError("Unknown sampler statistic %s" % stat_name)
515+
raise KeyError(f"Unknown sampler statistic {stat_name}")
516516

517517
if chains is None:
518518
chains = self.chains

pymc/distributions/multivariate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,11 +1047,11 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
10471047
tril_testval = None
10481048

10491049
c = pt.sqrt(
1050-
ChiSquared("%s_c" % name, nu - np.arange(2, 2 + n_diag), shape=n_diag, initval=diag_testval)
1050+
ChiSquared(f"{name}_c", nu - np.arange(2, 2 + n_diag), shape=n_diag, initval=diag_testval)
10511051
)
1052-
pm._log.info("Added new variable %s_c to model diagonal of Wishart." % name)
1053-
z = Normal("%s_z" % name, 0.0, 1.0, shape=n_tril, initval=tril_testval)
1054-
pm._log.info("Added new variable %s_z to model off-diagonals of Wishart." % name)
1052+
pm._log.info(f"Added new variable {name}_c to model diagonal of Wishart.")
1053+
z = Normal(f"{name}_z", 0.0, 1.0, shape=n_tril, initval=tril_testval)
1054+
pm._log.info(f"Added new variable {name}_z to model off-diagonals of Wishart.")
10551055
# Construct A matrix
10561056
A = pt.zeros(S.shape, dtype=np.float32)
10571057
A = pt.set_subtensor(A[diag_idx], c)

pymc/sampling/forward.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def sample_prior_predictive(
419419

420420
data = {k: np.stack(v) for k, v in zip(names, values)}
421421
if data is None:
422-
raise AssertionError("No variables sampled: attempting to sample %s" % names)
422+
raise AssertionError(f"No variables sampled: attempting to sample {names}")
423423

424424
prior: dict[str, np.ndarray] = {}
425425
for var_name in vars_:
@@ -765,8 +765,7 @@ def sample_posterior_predictive(
765765
samples = len(_trace)
766766
else:
767767
raise TypeError(
768-
"Do not know how to compute number of samples for trace argument of type %s"
769-
% type(_trace)
768+
f"Do not know how to compute number of samples for trace argument of type {type(_trace)}"
770769
)
771770

772771
assert samples is not None

pymc/sampling/mcmc.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -697,10 +697,7 @@ def joined_blas_limiter():
697697
msg = "Tuning was enabled throughout the whole trace."
698698
_log.warning(msg)
699699
elif draws < 100:
700-
msg = (
701-
"Only %s samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
702-
% draws
703-
)
700+
msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
704701
_log.warning(msg)
705702

706703
auto_nuts_init = True

pymc/sampling/parallel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(self, exc, tb):
5858
tb = traceback.format_exception(type(exc), exc, tb)
5959
tb = "".join(tb)
6060
self.exc = exc
61-
self.tb = '\n"""\n%s"""' % tb
61+
self.tb = f'\n"""\n{tb}"""'
6262

6363
def __reduce__(self):
6464
return rebuild_exc, (self.exc, self.tb)
@@ -216,7 +216,7 @@ def __init__(
216216
mp_ctx,
217217
):
218218
self.chain = chain
219-
process_name = "worker_chain_%s" % chain
219+
process_name = f"worker_chain_{chain}"
220220
self._msg_pipe, remote_conn = multiprocessing.Pipe()
221221

222222
self._shared_point = {}
@@ -228,7 +228,7 @@ def __init__(
228228
size *= int(dim)
229229
size *= dtype.itemsize
230230
if size != ctypes.c_size_t(size).value:
231-
raise ValueError("Variable %s is too large" % name)
231+
raise ValueError(f"Variable {name} is too large")
232232

233233
array = mp_ctx.RawArray("c", size)
234234
self._shared_point[name] = (array, shape, dtype)
@@ -388,7 +388,7 @@ def __init__(
388388
mp_ctx=None,
389389
):
390390
if any(len(arg) != chains for arg in [seeds, start_points]):
391-
raise ValueError("Number of seeds and start_points must be %s." % chains)
391+
raise ValueError(f"Number of seeds and start_points must be {chains}.")
392392

393393
if mp_ctx is None or isinstance(mp_ctx, str):
394394
# Closes issue https://github.com/pymc-devs/pymc/issues/3849

pymc/step_methods/hmc/integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self, potential: QuadPotential, logp_dlogp_func):
5151
def compute_state(self, q: RaveledVars, p: RaveledVars):
5252
"""Compute Hamiltonian functions using a position and momentum."""
5353
if q.data.dtype != self._dtype or p.data.dtype != self._dtype:
54-
raise ValueError("Invalid dtype. Must be %s" % self._dtype)
54+
raise ValueError(f"Invalid dtype. Must be {self._dtype}")
5555

5656
logp, dlogp = self._logp_dlogp_func(q)
5757

pymc/step_methods/metropolis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(
178178
elif S.ndim == 2:
179179
self.proposal_dist = MultivariateNormalProposal(S)
180180
else:
181-
raise ValueError("Invalid rank for variance: %s" % S.ndim)
181+
raise ValueError(f"Invalid rank for variance: {S.ndim}")
182182

183183
self.scaling = np.atleast_1d(scaling).astype("d")
184184
self.tune = tune

pymc/variational/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def _maybe_score(self, score):
7272
score = returns_loss
7373
elif score and not returns_loss:
7474
warnings.warn(
75-
"method `fit` got `score == True` but %s "
76-
"does not return loss. Ignoring `score` argument" % self.objective.op
75+
f"method `fit` got `score == True` but {self.objective.op} "
76+
"does not return loss. Ignoring `score` argument"
7777
)
7878
score = False
7979
else:

pymc/variational/opvi.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def step_function(
375375
if fn_kwargs is None:
376376
fn_kwargs = {}
377377
if score and not self.op.returns_loss:
378-
raise NotImplementedError("%s does not have loss" % self.op)
378+
raise NotImplementedError(f"{self.op} does not have loss")
379379
updates = self.updates(
380380
obj_n_mc=obj_n_mc,
381381
tf_n_mc=tf_n_mc,
@@ -416,7 +416,7 @@ def score_function(
416416
if fn_kwargs is None:
417417
fn_kwargs = {}
418418
if not self.op.returns_loss:
419-
raise NotImplementedError("%s does not have loss" % self.op)
419+
raise NotImplementedError(f"{self.op} does not have loss")
420420
if more_replacements is None:
421421
more_replacements = {}
422422
loss = self(sc_n_mc, more_replacements=more_replacements)
@@ -496,13 +496,13 @@ def apply(self, f): # pragma: no cover
496496
def __call__(self, f=None):
497497
if self.has_test_function:
498498
if f is None:
499-
raise ParametrizationError("Operator %s requires TestFunction" % self)
499+
raise ParametrizationError(f"Operator {self} requires TestFunction")
500500
else:
501501
if not isinstance(f, TestFunction):
502502
f = TestFunction.from_function(f)
503503
else:
504504
if f is not None:
505-
warnings.warn("TestFunction for %s is redundant and removed" % self, stacklevel=3)
505+
warnings.warn(f"TestFunction for {self} is redundant and removed", stacklevel=3)
506506
else:
507507
pass
508508
f = TestFunction()
@@ -555,7 +555,7 @@ def setup(self, approx):
555555
@classmethod
556556
def from_function(cls, f):
557557
if not callable(f):
558-
raise ParametrizationError("Need callable, got %r" % f)
558+
raise ParametrizationError(f"Need callable, got {f!r}")
559559
obj = TestFunction()
560560
obj.__call__ = f
561561
return obj
@@ -1512,7 +1512,7 @@ def vars_names(vs):
15121512
found.name = name + "_vi_random_slice"
15131513
break
15141514
else:
1515-
raise KeyError("%r not found" % name)
1515+
raise KeyError(f"{name!r} not found")
15161516
return found
15171517

15181518
@node_property

tests/variational/test_opvi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def test_logq_mini_2_sample_2_var(parametric_grouped_approxes, three_var_model):
261261

262262
def test_logq_globals(three_var_approx):
263263
if not three_var_approx.has_logq:
264-
pytest.skip("%s does not implement logq" % three_var_approx)
264+
pytest.skip(f"{three_var_approx} does not implement logq")
265265
approx = three_var_approx
266266
logq, symbolic_logq = approx.set_size_and_deterministic(
267267
[approx.logq, approx.symbolic_logq], 1, 0

0 commit comments

Comments
 (0)