
Commit 7bb6bf8

Armavica authored and twiecki committed
Fix UP038 (isinstance(..., X | Y))
1 parent 2a86c6b commit 7bb6bf8

38 files changed with 75 additions and 79 deletions
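For context, UP038 is the Ruff lint rule (from the pyupgrade-derived `UP` rule set) that flags `isinstance`/`issubclass` calls whose second argument is a tuple of types and suggests a PEP 604 union (`X | Y`) instead; the union form is only valid inside `isinstance()` on Python 3.10+. A minimal sketch of the rewrite applied throughout this commit (the `value` variable is purely illustrative):

# Before: the second argument is a tuple of types (flagged by Ruff's UP038).
value = 3.14
assert isinstance(value, (int, float))

# After: PEP 604 union syntax, accepted by isinstance() on Python 3.10+.
assert isinstance(value, int | float)

Both spellings give the same result for these checks; the union form simply matches modern annotation syntax.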

pymc/backends/arviz.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def is_data(name, var, model) -> bool:
         and var not in model.potentials
         and var not in model.value_vars
         and name not in observations
-        and isinstance(var, (Constant, SharedVariable))
+        and isinstance(var, Constant | SharedVariable)
     )

     # The assumption is that constants (like pm.Data) are named

pymc/distributions/censored.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ class Censored(Distribution):
     @classmethod
     def dist(cls, dist, lower, upper, **kwargs):
         if not isinstance(dist, TensorVariable) or not isinstance(
-            dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+            dist.owner.op, RandomVariable | SymbolicRandomVariable
         ):
             raise ValueError(
                 f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"

pymc/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
@@ -102,7 +102,7 @@ def rewrite_support_point_scan_node(self, node):

         for nd in local_fgraph_topo:
             if nd not in to_replace_set and isinstance(
-                nd.op, (RandomVariable, SymbolicRandomVariable)
+                nd.op, RandomVariable | SymbolicRandomVariable
             ):
                 replace_with_support_point.append(nd.out)
                 to_replace_set.add(nd)
@@ -132,7 +132,7 @@ def add_requirements(self, fgraph):

     def apply(self, fgraph):
         for node in fgraph.toposort():
-            if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
+            if isinstance(node.op, RandomVariable | SymbolicRandomVariable):
                 fgraph.replace(node.out, support_point(node.out))
             elif isinstance(node.op, Scan):
                 new_node = self.rewrite_support_point_scan_node(node)
@@ -837,7 +837,7 @@ def custom_dist_get_support_point(op, rv, size, *params):
         *[
             p
             for p in params
-            if not isinstance(p.type, (RandomType, RandomGeneratorType))
+            if not isinstance(p.type, RandomType | RandomGeneratorType)
         ],
     )

pymc/distributions/mixture.py

Lines changed: 2 additions & 2 deletions
@@ -178,7 +178,7 @@ class Mixture(Distribution):

     @classmethod
     def dist(cls, w, comp_dists, **kwargs):
-        if not isinstance(comp_dists, (tuple, list)):
+        if not isinstance(comp_dists, tuple | list):
             # comp_dists is a single component
             comp_dists = [comp_dists]
         elif len(comp_dists) == 1:
@@ -204,7 +204,7 @@ def dist(cls, w, comp_dists, **kwargs):
         # TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
         # and resize them
         if not isinstance(dist, TensorVariable) or not isinstance(
-            dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+            dist.owner.op, RandomVariable | SymbolicRandomVariable
         ):
             raise ValueError(
                 f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"

pymc/distributions/multivariate.py

Lines changed: 3 additions & 3 deletions
@@ -1083,7 +1083,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv

 def _lkj_normalizing_constant(eta, n):
     # TODO: This is mixing python branching with the potentially symbolic n and eta variables
-    if not isinstance(eta, (int, float)):
+    if not isinstance(eta, int | float):
         raise NotImplementedError("eta must be an int or float")
     if not isinstance(n, int):
         raise NotImplementedError("n must be an integer")
@@ -1185,7 +1185,7 @@ def dist(cls, n, eta, sd_dist, **kwargs):
         if not (
             isinstance(sd_dist, Variable)
             and sd_dist.owner is not None
-            and isinstance(sd_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+            and isinstance(sd_dist.owner.op, RandomVariable | SymbolicRandomVariable)
             and sd_dist.owner.op.ndim_supp < 2
         ):
             raise TypeError("sd_dist must be a scalar or vector distribution variable")
@@ -2262,7 +2262,7 @@ def logp(value, mu, W, alpha, tau):
         TensorVariable
         """

-        sparse = isinstance(W, (pytensor.sparse.SparseConstant, pytensor.sparse.SparseVariable))
+        sparse = isinstance(W, pytensor.sparse.SparseConstant | pytensor.sparse.SparseVariable)

         if sparse:
             D = sp_sum(W, axis=0)

pymc/distributions/shape_utils.py

Lines changed: 3 additions & 3 deletions
@@ -193,7 +193,7 @@ def convert_dims(dims: Dims | None) -> StrongDims | None:

     if isinstance(dims, str):
         dims = (dims,)
-    elif isinstance(dims, (list, tuple)):
+    elif isinstance(dims, list | tuple):
         dims = tuple(dims)
     else:
         raise ValueError(f"The `dims` parameter must be a tuple, str or list. Actual: {type(dims)}")
@@ -209,7 +209,7 @@ def convert_shape(shape: Shape) -> StrongShape | None:
         shape = (shape,)
     elif isinstance(shape, TensorVariable) and shape.ndim == 1:
         shape = tuple(shape)
-    elif isinstance(shape, (list, tuple)):
+    elif isinstance(shape, list | tuple):
         shape = tuple(shape)
     else:
         raise ValueError(
@@ -227,7 +227,7 @@ def convert_size(size: Size) -> StrongSize | None:
         size = (size,)
     elif isinstance(size, TensorVariable) and size.ndim == 1:
         size = tuple(size)
-    elif isinstance(size, (list, tuple)):
+    elif isinstance(size, list | tuple):
         size = tuple(size)
     else:
         raise ValueError(

pymc/distributions/timeseries.py

Lines changed: 5 additions & 5 deletions
@@ -88,15 +88,15 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> pt.TensorVari
         if not (
             isinstance(init_dist, pt.TensorVariable)
             and init_dist.owner is not None
-            and isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+            and isinstance(init_dist.owner.op, RandomVariable | SymbolicRandomVariable)
         ):
             raise TypeError("init_dist must be a distribution variable")
         check_dist_not_registered(init_dist)

         if not (
             isinstance(innovation_dist, pt.TensorVariable)
             and innovation_dist.owner is not None
-            and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+            and isinstance(innovation_dist.owner.op, RandomVariable | SymbolicRandomVariable)
         ):
             raise TypeError("innovation_dist must be a distribution variable")
         check_dist_not_registered(innovation_dist)
@@ -129,7 +129,7 @@ def get_steps(cls, innovation_dist, steps, shape, dims, observed):
         if not (
             isinstance(innovation_dist, pt.TensorVariable)
             and innovation_dist.owner is not None
-            and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+            and isinstance(innovation_dist.owner.op, RandomVariable | SymbolicRandomVariable)
         ):
             raise TypeError("innovation_dist must be a distribution variable")

@@ -549,7 +549,7 @@ def dist(

         if init_dist is not None:
             if not isinstance(init_dist, TensorVariable) or not isinstance(
-                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+                init_dist.owner.op, RandomVariable | SymbolicRandomVariable
             ):
                 raise ValueError(
                     f"Init dist must be a distribution created via the `.dist()` API, "
@@ -948,7 +948,7 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):

         if init_dist is not None:
             if not isinstance(init_dist, TensorVariable) or not isinstance(
-                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+                init_dist.owner.op, RandomVariable | SymbolicRandomVariable
             ):
                 raise ValueError(
                     f"Init dist must be a distribution created via the `.dist()` API, "

pymc/gp/cov.py

Lines changed: 2 additions & 6 deletions
@@ -256,11 +256,7 @@ def _merge_factors_cov(self, X, Xs=None, diag=False):

             elif isinstance(
                 factor,
-                (
-                    TensorConstant,
-                    TensorVariable,
-                    TensorSharedVariable,
-                ),
+                TensorConstant | TensorVariable | TensorSharedVariable,
             ):
                 if factor.ndim == 2 and diag:
                     factor_list.append(pt.diag(factor))
@@ -524,7 +520,7 @@ def __init__(
         if (ls is None and ls_inv is None) or (ls is not None and ls_inv is not None):
             raise ValueError("Only one of 'ls' or 'ls_inv' must be provided")
         elif ls_inv is not None:
-            if isinstance(ls_inv, (list, tuple)):
+            if isinstance(ls_inv, list | tuple):
                 ls = 1.0 / np.asarray(ls_inv)
             else:
                 ls = 1.0 / ls_inv
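The pymc/gp/cov.py hunk above collapses a multi-line tuple of PyTensor classes into a single union. This works because on Python 3.10+ the `|` operator on arbitrary classes (not just builtins) produces a `types.UnionType` object that `isinstance` accepts directly; a small sketch with hypothetical stand-in classes, not the actual PyMC/PyTensor types:

import types


class FooVariable: ...  # hypothetical stand-in class
class BarVariable: ...  # hypothetical stand-in class


union = FooVariable | BarVariable
assert isinstance(union, types.UnionType)  # the union is itself a runtime object
assert isinstance(FooVariable(), union)    # and can be passed straight to isinstance()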

pymc/gp/hsgp_approx.py

Lines changed: 1 addition & 1 deletion
@@ -328,7 +328,7 @@ def prior_linearized(self, Xs: TensorLike):

         # If not provided, use Xs and c to set L
         if self._L is None:
-            assert isinstance(self._c, (numbers.Real, np.ndarray, pt.TensorVariable))
+            assert isinstance(self._c, numbers.Real | np.ndarray | pt.TensorVariable)
             self.L = pt.as_tensor(set_boundary(Xs, self._c))
         else:
             self.L = self._L

pymc/gp/util.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def kmeans_inducing_points(n_inducing, X, **kmeans_kwargs):
     # first whiten X
     if isinstance(X, TensorConstant):
         X = X.value
-    elif isinstance(X, (np.ndarray, tuple, list)):
+    elif isinstance(X, np.ndarray | tuple | list):
         X = np.asarray(X)
     else:
         raise TypeError(

pymc/logprob/abstract.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def _logprob_helper(rv, *values, **kwargs):
     if (not name) and (len(values) == 1):
         name = values[0].name
     if name:
-        if isinstance(logprob, (list, tuple)):
+        if isinstance(logprob, list | tuple):
            for i, term in enumerate(logprob):
                term.name = f"{name}_logprob.{i}"
        else:

pymc/logprob/basic.py

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ def _find_unallowed_rvs_in_graph(graph):
     return {
         rv
         for rv in rvs_in_graph(graph)
-        if not isinstance(rv.owner.op, (SimulatorRV, MinibatchIndexRV))
+        if not isinstance(rv.owner.op, SimulatorRV | MinibatchIndexRV)
     }


@@ -545,7 +545,7 @@ def conditional_logp(
             **kwargs,
         )

-        if not isinstance(q_logprob_vars, (list, tuple)):
+        if not isinstance(q_logprob_vars, list | tuple):
             q_logprob_vars = [q_logprob_vars]

         for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars):

pymc/logprob/binary.py

Lines changed: 2 additions & 2 deletions
@@ -102,9 +102,9 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):

     condn_exp = pt.eq(value, np.array(True))

-    if isinstance(op.scalar_op, (GT, GE)):
+    if isinstance(op.scalar_op, GT | GE):
         logprob = pt.switch(condn_exp, logccdf, logcdf)
-    elif isinstance(op.scalar_op, (LT, LE)):
+    elif isinstance(op.scalar_op, LT | LE):
         logprob = pt.switch(condn_exp, logcdf, logccdf)
     else:
         raise TypeError(f"Unsupported scalar_op {op.scalar_op}")

pymc/logprob/mixture.py

Lines changed: 3 additions & 3 deletions
@@ -248,7 +248,7 @@ def get_stack_mixture_vars(
     joined_rvs = node.inputs[0]

     # First, make sure that it's some sort of concatenation
-    if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, (MakeVector, Join))):
+    if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, MakeVector | Join)):
         return None, None

     if isinstance(joined_rvs.owner.op, MakeVector):
@@ -284,7 +284,7 @@ def find_measurable_index_mixture(fgraph, node):
     mixing_indices = node.inputs[1:]

     # TODO: Add check / test case for Advanced Boolean indexing
-    if isinstance(node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
+    if isinstance(node.op, AdvancedSubtensor | AdvancedSubtensor1):
         # We don't support (non-scalar) integer array indexing as it can pick repeated values,
         # but the Mixture logprob assumes all mixture values are independent
         if any(
@@ -298,7 +298,7 @@ def find_measurable_index_mixture(fgraph, node):
     mixture_rvs, join_axis = get_stack_mixture_vars(node)

     # We don't support symbolic join axis
-    if mixture_rvs is None or not isinstance(join_axis, (NoneTypeT, Constant)):
+    if mixture_rvs is None or not isinstance(join_axis, NoneTypeT | Constant):
         return None

     if rv_map_feature.request_measurable(mixture_rvs) != mixture_rvs:

pymc/logprob/transforms.py

Lines changed: 1 addition & 1 deletion
@@ -779,7 +779,7 @@ class PowerTransform(Transform):
     name = "power"

     def __init__(self, power=None):
-        if not isinstance(power, (int, float)):
+        if not isinstance(power, int | float):
             raise TypeError(f"Power must be integer or float, got {type(power)}")
         if power == 0:
             raise ValueError("Power cannot be 0")

pymc/logprob/utils.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def expand(r):
     return {
         node
        for node in walk(makeiter(vars), expand, False)
-        if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable))
+        if node.owner and isinstance(node.owner.op, RandomVariable | MeasurableVariable)
     }

pymc/model/core.py

Lines changed: 6 additions & 6 deletions
@@ -192,7 +192,7 @@ def resolve_type(c: type | str) -> type:
         assert cls is not None
         if isinstance(cls._context_class, str):
             cls._context_class = resolve_type(cls._context_class)
-        if not isinstance(cls._context_class, (str, type)):
+        if not isinstance(cls._context_class, str | type):
             raise ValueError(
                 f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type"
             )
@@ -695,7 +695,7 @@ def logp(
         varlist: list[TensorVariable]
         if vars is None:
             varlist = self.free_RVs + self.observed_RVs + self.potentials
-        elif not isinstance(vars, (list, tuple)):
+        elif not isinstance(vars, list | tuple):
             varlist = [vars]
         else:
             varlist = cast(list[TensorVariable], vars)
@@ -770,7 +770,7 @@ def dlogp(
         if vars is None:
             value_vars = None
         else:
-            if not isinstance(vars, (list, tuple)):
+            if not isinstance(vars, list | tuple):
                 vars = [vars]

             value_vars = []
@@ -809,7 +809,7 @@ def d2logp(
         if vars is None:
             value_vars = None
         else:
-            if not isinstance(vars, (list, tuple)):
+            if not isinstance(vars, list | tuple):
                 vars = [vars]

             value_vars = []
@@ -998,7 +998,7 @@ def add_coord(
         if name in self.coords:
             if not np.array_equal(values, self.coords[name]):
                 raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
-        if length is not None and not isinstance(length, (int, Variable)):
+        if length is not None and not isinstance(length, int | Variable):
             raise ValueError(
                 f"The `length` passed for the '{name}' coord must be an int, PyTensor Variable or None."
             )
@@ -1070,7 +1070,7 @@ def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.nd

     def set_initval(self, rv_var, initval):
         """Sets an initial value (strategy) for a random variable."""
-        if initval is not None and not isinstance(initval, (Variable, str)):
+        if initval is not None and not isinstance(initval, Variable | str):
             # Convert scalars or array-like inputs to ndarrays
             initval = rv_var.type.filter(initval)

pymc/model/fgraph.py

Lines changed: 1 addition & 1 deletion
@@ -260,7 +260,7 @@ def fgraph_from_model(
     inverse_memo = {v: k for k, v in memo.items()}
     for var, model_var in replacements:
         if not inlined_views and (
-            model_var.owner and isinstance(model_var.owner.op, (ModelDeterministic, ModelNamed))
+            model_var.owner and isinstance(model_var.owner.op, ModelDeterministic | ModelNamed)
         ):
             # Ignore extra identity that will be removed at the end
             var = var.owner.inputs[0]

pymc/model/transform/basic.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def prune_vars_detached_from_observed(model: Model) -> Model:


 def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> list[Variable]:
-    if isinstance(vars, (list, tuple)):
+    if isinstance(vars, list | tuple):
         vars_seq = vars
     else:
         vars_seq = (vars,)

pymc/model/transform/conditioning.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def observe(
         model_var = memo[var]

         # Just a sanity check
-        assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic))
+        assert isinstance(model_var.owner.op, ModelFreeRV | ModelDeterministic)
         assert model_var in fgraph.variables

         var = model_var.owner.inputs[0]

pymc/ode/ode.py

Lines changed: 2 additions & 2 deletions
@@ -149,9 +149,9 @@ def make_node(self, y0, theta):
         return Apply(self, inputs, (states, sens))

     def __call__(self, y0, theta, return_sens=False, **kwargs):
-        if isinstance(y0, (list, tuple)) and not len(y0) == self.n_states:
+        if isinstance(y0, list | tuple) and not len(y0) == self.n_states:
             raise ShapeError("Length of y0 is wrong.", actual=(len(y0),), expected=(self.n_states,))
-        if isinstance(theta, (list, tuple)) and not len(theta) == self.n_theta:
+        if isinstance(theta, list | tuple) and not len(theta) == self.n_theta:
             raise ShapeError(
                 "Length of theta is wrong.", actual=(len(theta),), expected=(self.n_theta,)
             )

pymc/ode/utils.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def augment_system(ode_func, n_states, n_theta):
         t_yhat = pt.atleast_1d(yhat)
     else:
         # Stack the results of the ode_func into a single tensor variable
-        if not isinstance(yhat, (list, tuple)):
+        if not isinstance(yhat, list | tuple):
             raise TypeError(
                 f"Unexpected type, {type(yhat)}, returned by ode_func. TensorVariable, list or tuple is expected."
             )
