Skip to content

Commit e6767ab

Browse files
juanitorduztwiecki
andauthored
pre-commit update ruff 0.9.1 (#7648)
Co-authored-by: Thomas Wiecki <[email protected]>
1 parent bd519d4 commit e6767ab

File tree

18 files changed

+51
-54
lines changed

18 files changed

+51
-54
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ repos:
4848
# - --exclude=binder/
4949
# - --exclude=versioneer.py
5050
- repo: https://github.com/astral-sh/ruff-pre-commit
51-
rev: v0.8.4
51+
rev: v0.9.1
5252
hooks:
5353
- id: ruff
5454
args: [--fix, --show-fixes]

docs/source/learn/core_notebooks/pymc_pytensor.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1849,7 +1849,7 @@
18491849
"print(\n",
18501850
" f\"\"\"\n",
18511851
"mu_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=2)}\n",
1852-
"sigma_log_value -> {- 10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n",
1852+
"sigma_log_value -> {-10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n",
18531853
"x_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=np.exp(-10))}\n",
18541854
"\"\"\"\n",
18551855
")"

pymc/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def determine_coords(
257257
if isinstance(value, np.ndarray) and dims is not None:
258258
if len(dims) != value.ndim:
259259
raise pm.exceptions.ShapeError(
260-
"Invalid data shape. The rank of the dataset must match the " "length of `dims`.",
260+
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
261261
actual=value.shape,
262262
expected=value.ndim,
263263
)

pymc/distributions/continuous.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -992,8 +992,7 @@ def get_mu_lam_phi(mu, lam, phi):
992992
return mu, lam, lam / mu
993993

994994
raise ValueError(
995-
"Wald distribution must specify either mu only, "
996-
"mu and lam, mu and phi, or lam and phi."
995+
"Wald distribution must specify either mu only, mu and lam, mu and phi, or lam and phi."
997996
)
998997

999998
def logp(value, mu, lam, alpha):
@@ -1603,8 +1602,7 @@ def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs):
16031602
def get_kappa(cls, kappa=None, q=None):
16041603
if kappa is not None and q is not None:
16051604
raise ValueError(
1606-
"Incompatible parameterization. Either use "
1607-
"kappa or q to specify the distribution."
1605+
"Incompatible parameterization. Either use kappa or q to specify the distribution."
16081606
)
16091607
elif q is not None:
16101608
if isinstance(q, Variable):
@@ -3483,7 +3481,7 @@ def get_nu_b(cls, nu, b, sigma):
34833481
elif nu is not None and b is None:
34843482
b = nu / sigma
34853483
return nu, b, sigma
3486-
raise ValueError("Rice distribution must specify either nu" " or b.")
3484+
raise ValueError("Rice distribution must specify either nu or b.")
34873485

34883486
def support_point(rv, size, nu, sigma):
34893487
nu_sigma_ratio = -(nu**2) / (2 * sigma**2)

pymc/distributions/multivariate.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ class MvNormal(Continuous):
247247
data = np.random.multivariate_normal(mu, true_cov, 10)
248248
249249
sd_dist = pm.Exponential.dist(1.0, shape=3)
250-
chol, corr, stds = pm.LKJCholeskyCov("chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True)
250+
chol, corr, stds = pm.LKJCholeskyCov(
251+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
252+
)
251253
vals = pm.MvNormal("vals", mu=mu, chol=chol, observed=data)
252254
253255
For unobserved values it can be better to use a non-centered
@@ -2793,9 +2795,9 @@ def dist(cls, sigma=1.0, n_zerosum_axes=None, support_shape=None, **kwargs):
27932795

27942796
support_shape = pt.as_tensor(support_shape, dtype="int64", ndim=1)
27952797

2796-
assert n_zerosum_axes == pt.get_vector_length(
2797-
support_shape
2798-
), "support_shape has to be as long as n_zerosum_axes"
2798+
assert n_zerosum_axes == pt.get_vector_length(support_shape), (
2799+
"support_shape has to be as long as n_zerosum_axes"
2800+
)
27992801

28002802
return super().dist([sigma, support_shape], **kwargs)
28012803

pymc/gp/cov.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,7 @@ def power_spectral_density(self, omega: TensorLike) -> TensorVariable:
328328
check = Counter([isinstance(factor, Covariance) for factor in self._factor_list])
329329
if check.get(True, 0) >= 2:
330330
raise NotImplementedError(
331-
"The power spectral density of products of covariance "
332-
"functions is not implemented."
331+
"The power spectral density of products of covariance functions is not implemented."
333332
)
334333
return reduce(mul, self._merge_factors_psd(omega))
335334

pymc/gp/util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,7 @@ def plot_gp_dist(
211211
samples_kwargs = {}
212212
if np.any(np.isnan(samples)):
213213
warnings.warn(
214-
"There are `nan` entries in the [samples] arguments. "
215-
"The plot will not contain a band!",
214+
"There are `nan` entries in the [samples] arguments. The plot will not contain a band!",
216215
UserWarning,
217216
)
218217

pymc/sampling/jax.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl
108108

109109
if any(var.default_update is not None for var in shared_variables):
110110
raise ValueError(
111-
"Graph contains shared variables with default_update which cannot "
112-
"be safely replaced."
111+
"Graph contains shared variables with default_update which cannot be safely replaced."
113112
)
114113

115114
replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
@@ -360,7 +359,7 @@ def _sample_blackjax_nuts(
360359
map_fn = jax.vmap
361360
else:
362361
raise ValueError(
363-
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
362+
"Only supporting the following methods to draw chains: 'parallel' or 'vectorized'"
364363
)
365364

366365
if chains == 1:

pymc/sampling/mcmc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ def _sample_return(
10001000
total_draws = draws_per_chain.sum()
10011001

10021002
_log.info(
1003-
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations '
1003+
f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations "
10041004
f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) "
10051005
f"took {t_sampling:.0f} seconds."
10061006
)
@@ -1062,8 +1062,8 @@ def _sample_return(
10621062

10631063
n_chains = len(mtrace.chains)
10641064
_log.info(
1065-
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
1066-
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
1065+
f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {n_tune:_d} tune and {n_draws:_d} draw iterations "
1066+
f"({n_tune * n_chains:_d} + {n_draws * n_chains:_d} draws total) "
10671067
f"took {t_sampling:.0f} seconds."
10681068
)
10691069

pymc/sampling/population.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,9 @@ def _prepare_iter_population(
386386

387387
# 2. Set up the steppers
388388
steppers: list[Step] = []
389-
assert (
390-
len(rngs) == nchains
391-
), f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}"
389+
assert len(rngs) == nchains, (
390+
f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}"
391+
)
392392
for c, rng in enumerate(rngs):
393393
# need independent samplers for each chain
394394
# it is important to copy the actual steppers (but not the delta_logp)

pymc/step_methods/compound.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,9 @@ def sampling_state(self) -> DataClassState:
282282

283283
@sampling_state.setter
284284
def sampling_state(self, state: DataClassState):
285-
assert isinstance(
286-
state, self._state_class
287-
), f"Invalid sampling state class {type(state)}. Expected {self._state_class}"
285+
assert isinstance(state, self._state_class), (
286+
f"Invalid sampling state class {type(state)}. Expected {self._state_class}"
287+
)
288288
for method, state_method in zip(self.methods, state.methods):
289289
method.sampling_state = state_method
290290

pymc/step_methods/state.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ def sampling_state(self) -> DataClassState:
9090
@sampling_state.setter
9191
def sampling_state(self, state: DataClassState):
9292
state_class = self._state_class
93-
assert isinstance(
94-
state, state_class
95-
), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
93+
assert isinstance(state, state_class), (
94+
f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
95+
)
9696
for field in fields(state_class):
9797
is_tensor_name = field.metadata.get("tensor_name", False)
9898
state_val = deepcopy(getattr(state, field.name))

pymc/testing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -964,9 +964,9 @@ def check_rv_size(self):
964964
assert actual == expected_symbolic == expected
965965

966966
def validate_tests_list(self):
967-
assert len(self.checks_to_run) == len(
968-
set(self.checks_to_run)
969-
), "There are duplicates in the list of checks_to_run"
967+
assert len(self.checks_to_run) == len(set(self.checks_to_run)), (
968+
"There are duplicates in the list of checks_to_run"
969+
)
970970

971971

972972
def seeded_scipy_distribution_builder(dist_name: str) -> Callable:

pymc/variational/opvi.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -710,9 +710,9 @@ class Group(WithMemoization):
710710

711711
@classmethod
712712
def register(cls, sbcls):
713-
assert (
714-
frozenset(sbcls.__param_spec__) not in cls.__param_registry
715-
), "Duplicate __param_spec__"
713+
assert frozenset(sbcls.__param_spec__) not in cls.__param_registry, (
714+
"Duplicate __param_spec__"
715+
)
716716
cls.__param_registry[frozenset(sbcls.__param_spec__)] = sbcls
717717
assert sbcls.short_name not in cls.__name_registry, "Duplicate short_name"
718718
cls.__name_registry[sbcls.short_name] = sbcls
@@ -1234,7 +1234,7 @@ def __init__(self, groups, model=None):
12341234
for g in groups:
12351235
if g.group is None:
12361236
if rest is not None:
1237-
raise GroupError("More than one group is specified for " "the rest variables")
1237+
raise GroupError("More than one group is specified for the rest variables")
12381238
else:
12391239
rest = g
12401240
else:

pymc/variational/updates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,7 @@ def norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-7):
10061006
elif ndim in [3, 4, 5]: # Conv{1,2,3}DLayer
10071007
sum_over = tuple(range(1, ndim))
10081008
else:
1009-
raise ValueError(f"Unsupported tensor dimensionality {ndim}." "Must specify `norm_axes`")
1009+
raise ValueError(f"Unsupported tensor dimensionality {ndim}. Must specify `norm_axes`")
10101010

10111011
dtype = np.dtype(pytensor.config.floatX).type
10121012
norms = pt.sqrt(pt.sum(pt.sqr(tensor_var), axis=sum_over, keepdims=True))

tests/distributions/test_multivariate.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,14 +1531,14 @@ class TestZeroSumNormal:
15311531
def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True):
15321532
if check_zerosum_axes:
15331533
for ax in axes_to_check:
1534-
assert np.isclose(
1535-
random_samples.mean(axis=ax), 0
1536-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1534+
assert np.isclose(random_samples.mean(axis=ax), 0).all(), (
1535+
f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1536+
)
15371537
else:
15381538
for ax in axes_to_check:
1539-
assert not np.isclose(
1540-
random_samples.mean(axis=ax), 0
1541-
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1539+
assert not np.isclose(random_samples.mean(axis=ax), 0).all(), (
1540+
f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1541+
)
15421542

15431543
@pytest.mark.parametrize(
15441544
"dims, n_zerosum_axes",

tests/gp/test_hsgp_approx.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ def test_mean_invariance(self):
135135
with model:
136136
pm.set_data({"X": x_new})
137137

138-
assert np.allclose(
139-
gp._X_center, original_center
140-
), "gp._X_center should not change after updating data for out-of-sample predictions."
138+
assert np.allclose(gp._X_center, original_center), (
139+
"gp._X_center should not change after updating data for out-of-sample predictions."
140+
)
141141

142142
def test_parametrization(self):
143143
err_msg = (
@@ -188,9 +188,9 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first):
188188

189189
n_coeffs = model.f1_hsgp_coeffs.type.shape[0]
190190
if drop_first:
191-
assert (
192-
n_coeffs == n_basis - 1
193-
), f"one basis vector should have been dropped, {n_coeffs}"
191+
assert n_coeffs == n_basis - 1, (
192+
f"one basis vector should have been dropped, {n_coeffs}"
193+
)
194194
else:
195195
assert n_coeffs == n_basis, "one was dropped when it shouldn't have been"
196196

tests/test_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ def test_explicit_coords(self, seeded_test):
318318
N_cols = 7
319319
data = np.random.uniform(size=(N_rows, N_cols))
320320
coords = {
321-
"rows": [f"R{r+1}" for r in range(N_rows)],
322-
"columns": [f"C{c+1}" for c in range(N_cols)],
321+
"rows": [f"R{r + 1}" for r in range(N_rows)],
322+
"columns": [f"C{c + 1}" for c in range(N_cols)],
323323
}
324324
# pass coordinates explicitly, use numpy array in Data container
325325
with pm.Model(coords=coords) as pmodel:
@@ -391,7 +391,7 @@ def test_implicit_coords_dataframe(self, seeded_test):
391391
N_cols = 7
392392
df_data = pd.DataFrame()
393393
for c in range(N_cols):
394-
df_data[f"Column {c+1}"] = np.random.normal(size=(N_rows,))
394+
df_data[f"Column {c + 1}"] = np.random.normal(size=(N_rows,))
395395
df_data.index.name = "rows"
396396
df_data.columns.name = "columns"
397397

0 commit comments

Comments
 (0)