Skip to content

Commit 2a86c6b

Browse files
Armavicatwiecki
authored andcommitted
Apply safe ruff fixes
1 parent 6955631 commit 2a86c6b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+500
-546
lines changed

pymc/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import subprocess
3030
import sys
3131

32-
from typing import Callable
32+
from collections.abc import Callable
3333

3434

3535
def get_keywords():

pymc/backends/__init__.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,11 @@
6363

6464
from collections.abc import Mapping, Sequence
6565
from copy import copy
66-
from typing import Optional, Union
66+
from typing import Optional, TypeAlias, Union
6767

6868
import numpy as np
6969

7070
from pytensor.tensor.variable import TensorVariable
71-
from typing_extensions import TypeAlias
7271

7372
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
7473
from pymc.backends.base import BaseTrace, IBaseTrace
@@ -98,9 +97,9 @@ def _init_trace(
9897
expected_length: int,
9998
chain_number: int,
10099
stats_dtypes: list[dict[str, type]],
101-
trace: Optional[BaseTrace],
100+
trace: BaseTrace | None,
102101
model: Model,
103-
trace_vars: Optional[list[TensorVariable]] = None,
102+
trace_vars: list[TensorVariable] | None = None,
104103
) -> BaseTrace:
105104
"""Initializes a trace backend for a chain."""
106105
strace: BaseTrace
@@ -119,14 +118,14 @@ def _init_trace(
119118

120119
def init_traces(
121120
*,
122-
backend: Optional[TraceOrBackend],
121+
backend: TraceOrBackend | None,
123122
chains: int,
124123
expected_length: int,
125-
step: Union[BlockedStep, CompoundStep],
124+
step: BlockedStep | CompoundStep,
126125
initial_point: Mapping[str, np.ndarray],
127126
model: Model,
128-
trace_vars: Optional[list[TensorVariable]] = None,
129-
) -> tuple[Optional[RunType], Sequence[IBaseTrace]]:
127+
trace_vars: list[TensorVariable] | None = None,
128+
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
130129
"""Initializes a trace recorder for each chain."""
131130
if HAS_MCB and isinstance(backend, Backend):
132131
return init_chain_adapters(

pymc/backends/arviz.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ def insert(self, k: str, v, idx: int):
163163
class InferenceDataConverter:
164164
"""Encapsulate InferenceData specific logic."""
165165

166-
model: Optional[Model] = None
167-
posterior_predictive: Optional[Mapping[str, np.ndarray]] = None
168-
predictions: Optional[Mapping[str, np.ndarray]] = None
169-
prior: Optional[Mapping[str, np.ndarray]] = None
166+
model: Model | None = None
167+
posterior_predictive: Mapping[str, np.ndarray] | None = None
168+
predictions: Mapping[str, np.ndarray] | None = None
169+
prior: Mapping[str, np.ndarray] | None = None
170170

171171
def __init__(
172172
self,
@@ -177,11 +177,11 @@ def __init__(
177177
log_likelihood=False,
178178
log_prior=False,
179179
predictions=None,
180-
coords: Optional[CoordSpec] = None,
181-
dims: Optional[DimSpec] = None,
182-
sample_dims: Optional[list] = None,
180+
coords: CoordSpec | None = None,
181+
dims: DimSpec | None = None,
182+
sample_dims: list | None = None,
183183
model=None,
184-
save_warmup: Optional[bool] = None,
184+
save_warmup: bool | None = None,
185185
include_transformed: bool = False,
186186
):
187187
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
@@ -466,15 +466,15 @@ def to_inference_data(self):
466466
def to_inference_data(
467467
trace: Optional["MultiTrace"] = None,
468468
*,
469-
prior: Optional[Mapping[str, Any]] = None,
470-
posterior_predictive: Optional[Mapping[str, Any]] = None,
471-
log_likelihood: Union[bool, Iterable[str]] = False,
472-
log_prior: Union[bool, Iterable[str]] = False,
473-
coords: Optional[CoordSpec] = None,
474-
dims: Optional[DimSpec] = None,
475-
sample_dims: Optional[list] = None,
469+
prior: Mapping[str, Any] | None = None,
470+
posterior_predictive: Mapping[str, Any] | None = None,
471+
log_likelihood: bool | Iterable[str] = False,
472+
log_prior: bool | Iterable[str] = False,
473+
coords: CoordSpec | None = None,
474+
dims: DimSpec | None = None,
475+
sample_dims: list | None = None,
476476
model: Optional["Model"] = None,
477-
save_warmup: Optional[bool] = None,
477+
save_warmup: bool | None = None,
478478
include_transformed: bool = False,
479479
) -> InferenceData:
480480
"""Convert pymc data into an InferenceData object.
@@ -543,10 +543,10 @@ def predictions_to_inference_data(
543543
predictions,
544544
posterior_trace: Optional["MultiTrace"] = None,
545545
model: Optional["Model"] = None,
546-
coords: Optional[CoordSpec] = None,
547-
dims: Optional[DimSpec] = None,
548-
sample_dims: Optional[list] = None,
549-
idata_orig: Optional[InferenceData] = None,
546+
coords: CoordSpec | None = None,
547+
dims: DimSpec | None = None,
548+
sample_dims: list | None = None,
549+
idata_orig: InferenceData | None = None,
550550
inplace: bool = False,
551551
) -> InferenceData:
552552
"""Translate out-of-sample predictions into ``InferenceData``.

pymc/backends/base.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525
from collections.abc import Mapping, Sequence, Sized
2626
from typing import (
2727
Any,
28-
Optional,
2928
TypeVar,
30-
Union,
3129
cast,
3230
)
3331

@@ -53,7 +51,7 @@ class IBaseTrace(ABC, Sized):
5351
varnames: list[str]
5452
"""Names of tracked variables."""
5553

56-
sampler_vars: list[dict[str, Union[type, np.dtype]]]
54+
sampler_vars: list[dict[str, type | np.dtype]]
5755
"""Sampler stats for each sampler."""
5856

5957
def __len__(self):
@@ -75,7 +73,7 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
7573
raise NotImplementedError()
7674

7775
def get_sampler_stats(
78-
self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
76+
self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
7977
) -> np.ndarray:
8078
"""Get sampler statistics from the trace.
8179
@@ -219,7 +217,7 @@ def __getitem__(self, idx):
219217
raise ValueError("Can only index with slice or integer")
220218

221219
def get_sampler_stats(
222-
self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
220+
self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
223221
) -> np.ndarray:
224222
"""Get sampler statistics from the trace.
225223
@@ -443,7 +441,7 @@ def get_values(
443441
burn: int = 0,
444442
thin: int = 1,
445443
combine: bool = True,
446-
chains: Optional[Union[int, Sequence[int]]] = None,
444+
chains: int | Sequence[int] | None = None,
447445
squeeze: bool = True,
448446
) -> list[np.ndarray]:
449447
"""Get values from traces.
@@ -482,9 +480,9 @@ def get_sampler_stats(
482480
burn: int = 0,
483481
thin: int = 1,
484482
combine: bool = True,
485-
chains: Optional[Union[int, Sequence[int]]] = None,
483+
chains: int | Sequence[int] | None = None,
486484
squeeze: bool = True,
487-
) -> Union[list[np.ndarray], np.ndarray]:
485+
) -> list[np.ndarray] | np.ndarray:
488486
"""Get sampler statistics from the trace.
489487
490488
Note: This implementation attempts to squeeze object arrays into a consistent dtype,
@@ -534,7 +532,7 @@ def _slice(self, slice: slice):
534532
trace._report = self._report._slice(*idxs)
535533
return trace
536534

537-
def point(self, idx: int, chain: Optional[int] = None) -> dict[str, np.ndarray]:
535+
def point(self, idx: int, chain: int | None = None) -> dict[str, np.ndarray]:
538536
"""Return a dictionary of point values at `idx`.
539537
540538
Parameters

pymc/backends/mcbackend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pickle
1818

1919
from collections.abc import Mapping, Sequence
20-
from typing import Any, Optional, Union, cast
20+
from typing import Any, cast
2121

2222
import hagelkorn
2323
import mcbackend as mcb
@@ -144,7 +144,7 @@ def _get_stats(self, fname: str, slc: slice) -> np.ndarray:
144144
return values
145145

146146
def get_sampler_stats(
147-
self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
147+
self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
148148
) -> np.ndarray:
149149
slc = slice(burn, None, thin)
150150
# When there's just one sampler, default to remove the sampler dimension
@@ -204,7 +204,7 @@ def point(self, idx: int) -> dict[str, np.ndarray]:
204204
def make_runmeta_and_point_fn(
205205
*,
206206
initial_point: Mapping[str, np.ndarray],
207-
step: Union[CompoundStep, BlockedStep],
207+
step: CompoundStep | BlockedStep,
208208
model: Model,
209209
) -> tuple[mcb.RunMeta, PointFunc]:
210210
variables, point_fn = get_variables_and_point_fn(model, initial_point)
@@ -254,7 +254,7 @@ def init_chain_adapters(
254254
backend: mcb.Backend,
255255
chains: int,
256256
initial_point: Mapping[str, np.ndarray],
257-
step: Union[CompoundStep, BlockedStep],
257+
step: CompoundStep | BlockedStep,
258258
model: Model,
259259
) -> tuple[mcb.Run, list[ChainRecordAdapter]]:
260260
"""Create an McBackend metadata description for the MCMC run.

pymc/backends/ndarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Store sampling values in memory as a NumPy array.
1818
"""
1919

20-
from typing import Any, Optional
20+
from typing import Any
2121

2222
import numpy as np
2323

@@ -210,7 +210,7 @@ def _slice_as_ndarray(strace, idx):
210210

211211

212212
def point_list_to_multitrace(
213-
point_list: list[dict[str, np.ndarray]], model: Optional[Model] = None
213+
point_list: list[dict[str, np.ndarray]], model: Model | None = None
214214
) -> MultiTrace:
215215
"""transform point list into MultiTrace"""
216216
_model = modelcontext(model)

pymc/backends/report.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import itertools
1717
import logging
1818

19-
from typing import Optional
20-
2119
from pymc.stats.convergence import _LEVELS, SamplerWarning
2220

2321
logger = logging.getLogger(__name__)
@@ -44,17 +42,17 @@ def ok(self):
4442
return all(_LEVELS[warn.level] < _LEVELS["warn"] for warn in self._warnings)
4543

4644
@property
47-
def n_tune(self) -> Optional[int]:
45+
def n_tune(self) -> int | None:
4846
"""Number of tune iterations - not necessarily kept in trace!"""
4947
return self._n_tune
5048

5149
@property
52-
def n_draws(self) -> Optional[int]:
50+
def n_draws(self) -> int | None:
5351
"""Number of draw iterations."""
5452
return self._n_draws
5553

5654
@property
57-
def t_sampling(self) -> Optional[float]:
55+
def t_sampling(self) -> float | None:
5856
"""
5957
Number of seconds that the sampling procedure took.
6058

pymc/blocking.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,27 @@
2020

2121
from __future__ import annotations
2222

23-
from collections.abc import Sequence
23+
from collections.abc import Callable, Sequence
2424
from functools import partial
2525
from typing import (
2626
Any,
27-
Callable,
2827
Generic,
2928
NamedTuple,
30-
Optional,
29+
TypeAlias,
3130
TypeVar,
32-
Union,
3331
)
3432

3533
import numpy as np
3634

37-
from typing_extensions import TypeAlias
38-
3935
__all__ = ["DictToArrayBijection"]
4036

4137

4238
T = TypeVar("T")
4339
PointType: TypeAlias = dict[str, np.ndarray]
4440
StatsDict: TypeAlias = dict[str, Any]
4541
StatsType: TypeAlias = list[StatsDict]
46-
StatDtype: TypeAlias = Union[type, np.dtype]
47-
StatShape: TypeAlias = Optional[Sequence[Optional[int]]]
42+
StatDtype: TypeAlias = type | np.dtype
43+
StatShape: TypeAlias = Sequence[int | None] | None
4844

4945

5046
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for

pymc/data.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from collections.abc import Sequence
2020
from copy import copy
21-
from typing import Optional, Union, cast
21+
from typing import cast
2222

2323
import numpy as np
2424
import pandas as pd
@@ -203,10 +203,10 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
203203

204204
def determine_coords(
205205
model,
206-
value: Union[pd.DataFrame, pd.Series, xr.DataArray],
207-
dims: Optional[Sequence[Optional[str]]] = None,
208-
coords: Optional[dict[str, Union[Sequence, np.ndarray]]] = None,
209-
) -> tuple[dict[str, Union[Sequence, np.ndarray]], Sequence[Optional[str]]]:
206+
value: pd.DataFrame | pd.Series | xr.DataArray,
207+
dims: Sequence[str | None] | None = None,
208+
coords: dict[str, Sequence | np.ndarray] | None = None,
209+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]:
210210
"""Determines coordinate values from data or the model (via ``dims``)."""
211211
if coords is None:
212212
coords = {}
@@ -260,8 +260,8 @@ def ConstantData(
260260
name: str,
261261
value,
262262
*,
263-
dims: Optional[Sequence[str]] = None,
264-
coords: Optional[dict[str, Union[Sequence, np.ndarray]]] = None,
263+
dims: Sequence[str] | None = None,
264+
coords: dict[str, Sequence | np.ndarray] | None = None,
265265
infer_dims_and_coords=False,
266266
**kwargs,
267267
) -> TensorConstant:
@@ -290,8 +290,8 @@ def MutableData(
290290
name: str,
291291
value,
292292
*,
293-
dims: Optional[Sequence[str]] = None,
294-
coords: Optional[dict[str, Union[Sequence, np.ndarray]]] = None,
293+
dims: Sequence[str] | None = None,
294+
coords: dict[str, Sequence | np.ndarray] | None = None,
295295
infer_dims_and_coords=False,
296296
**kwargs,
297297
) -> SharedVariable:
@@ -320,12 +320,12 @@ def Data(
320320
name: str,
321321
value,
322322
*,
323-
dims: Optional[Sequence[str]] = None,
324-
coords: Optional[dict[str, Union[Sequence, np.ndarray]]] = None,
323+
dims: Sequence[str] | None = None,
324+
coords: dict[str, Sequence | np.ndarray] | None = None,
325325
infer_dims_and_coords=False,
326-
mutable: Optional[bool] = None,
326+
mutable: bool | None = None,
327327
**kwargs,
328-
) -> Union[SharedVariable, TensorConstant]:
328+
) -> SharedVariable | TensorConstant:
329329
"""Data container that registers a data variable with the model.
330330
331331
Depending on the ``mutable`` setting (default: True), the variable

0 commit comments

Comments
 (0)