
Commit 8aeda39

Armavica authored and ricardoV94 committed
Fix RUF005
Automated fixes for RUF005 by Ruff, and an update of the TensorConstructorType type in scan/op.py because mypy didn't like something there.
1 parent b142cb5 commit 8aeda39
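
For readers unfamiliar with the rule: RUF005 flags list and tuple concatenation that can be written as a single literal with iterable unpacking (PEP 448). A minimal sketch of the before/after pattern, using made-up names rather than anything from this diff:

# Hypothetical example of the RUF005 rewrite; `condition` and `monitored` are illustrative.
condition = "cond"
monitored = ["a", "b"]

# Before: concatenation builds an intermediate sequence.
inputs_old = [condition] + list(monitored)
shape_old = (2, 3) + tuple(monitored)

# After: one literal with unpacking, producing an equal result.
inputs_new = [condition, *monitored]
shape_new = (2, 3, *monitored)

assert inputs_old == inputs_new and shape_old == shape_new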


69 files changed: +392 −330 lines

pytensor/breakpoint.py
Lines changed: 2 additions & 2 deletions

@@ -92,7 +92,7 @@ def make_node(self, condition, *monitored_vars):
             new_op.inp_types.append(monitored_vars[i].type)

         # Build the Apply node
-        inputs = [condition] + list(monitored_vars)
+        inputs = [condition, *list(monitored_vars)]
         outputs = [inp.type() for inp in monitored_vars]
         return Apply(op=new_op, inputs=inputs, outputs=outputs)

@@ -142,7 +142,7 @@ def perform(self, node, inputs, output_storage):
             output_storage[i][0] = inputs[i + 1]

     def grad(self, inputs, output_gradients):
-        return [DisconnectedType()()] + output_gradients
+        return [DisconnectedType()(), *output_gradients]

     def infer_shape(self, fgraph, inputs, input_shapes):
         # Return the shape of every input but the condition (first input)

pytensor/compile/debugmode.py
Lines changed: 2 additions & 2 deletions

@@ -892,9 +892,9 @@ def _get_preallocated_maps(

         # Use the same step on all dimensions before the last check_ndim.
         if all(s == 1 for s in out_shape[:-check_ndim]):
-            step_signs_list = [(1,)] + step_signs_list
+            step_signs_list = [(1,), *step_signs_list]
         else:
-            step_signs_list = [(-1, 1)] + step_signs_list
+            step_signs_list = [(-1, 1), *step_signs_list]

         for step_signs in itertools_product(*step_signs_list):
             for step_size in (1, 2):

pytensor/compile/sharedvalue.py
Lines changed: 1 addition & 1 deletion

@@ -209,7 +209,7 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
         add_tag_trace(var)
         return var
     except MemoryError as e:
-        e.args = e.args + ("Consider using `pytensor.shared(..., borrow=True)`",)
+        e.args = (*e.args, "Consider using `pytensor.shared(..., borrow=True)`")
         raise
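
The same rewrite shows up below in pytensor/link/basic.py and pytensor/graph/utils.py for augmenting exception arguments. A small self-contained sketch of that pattern, with placeholder message strings:

try:
    raise MemoryError("allocation failed")
except MemoryError as e:
    # Old style: e.args = e.args + ("extra hint",)
    # RUF005 style: rebuild the tuple with unpacking; the result is identical.
    e.args = (*e.args, "extra hint")
    assert e.args == ("allocation failed", "extra hint")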

pytensor/configdefaults.py
Lines changed: 6 additions & 3 deletions

@@ -1382,7 +1382,8 @@ def add_caching_dir_configvars():
     "fft_tiling",
     "winograd",
     "winograd_non_fused",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)

 SUPPORTED_DNN_CONV_ALGO_BWD_DATA = (
     "none",
@@ -1391,7 +1392,8 @@ def add_caching_dir_configvars():
     "fft_tiling",
     "winograd",
     "winograd_non_fused",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)

 SUPPORTED_DNN_CONV_ALGO_BWD_FILTER = (
     "none",
@@ -1400,7 +1402,8 @@ def add_caching_dir_configvars():
     "small",
     "winograd_non_fused",
     "fft_tiling",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)

 SUPPORTED_DNN_CONV_PRECISION = (
     "as_input_f32",

pytensor/gradient.py
Lines changed: 1 addition & 1 deletion

@@ -1972,7 +1972,7 @@ def inner_function(*args):
     jacobs, updates = pytensor.scan(
         inner_function,
         sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression] + wrt,
+        non_sequences=[expression, *wrt],
     )
     assert not updates, "Scan has returned a list of updates; this should not happen."
     return as_list_or_tuple(using_list, using_tuple, jacobs)

pytensor/graph/features.py
Lines changed: 7 additions & 5 deletions

@@ -508,11 +508,13 @@ def consistent_(self, fgraph):


 class ReplaceValidate(History, Validator):
-    pickle_rm_attr = (
-        ["replace_validate", "replace_all_validate", "replace_all_validate_remove"]
-        + History.pickle_rm_attr
-        + Validator.pickle_rm_attr
-    )
+    pickle_rm_attr = [
+        "replace_validate",
+        "replace_all_validate",
+        "replace_all_validate_remove",
+        *History.pickle_rm_attr,
+        *Validator.pickle_rm_attr,
+    ]

     def on_attach(self, fgraph):
         for attr in (

pytensor/graph/rewriting/basic.py
Lines changed: 3 additions & 2 deletions

@@ -405,7 +405,7 @@ def print_profile(cls, stream, prof, level=0):
             else:
                 name = rewrite.name
             idx = rewrites.index(rewrite)
-            ll.append((name, rewrite.__class__.__name__, idx) + nb_n)
+            ll.append((name, rewrite.__class__.__name__, idx, *nb_n))
         lll = sorted(zip(prof, ll), key=lambda a: a[0])

         for t, rewrite in lll[::-1]:
@@ -1138,7 +1138,8 @@ def decorator(f):
         req = requirements
         if inplace:
             dh_handler = dh.DestroyHandler
-            req = tuple(requirements) + (
+            req = (
+                *tuple(requirements),
                 lambda fgraph: fgraph.attach_feature(dh_handler()),
             )
         rval = FromFunctionNodeRewriter(f, tracks, req)

pytensor/graph/rewriting/db.py
Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def register(

         if use_db_name_as_tag:
             if self.name is not None:
-                tags = tags + (self.name,)
+                tags = (*tags, self.name)

         rewriter.name = name
         # This restriction is there because in many place we suppose that

pytensor/graph/utils.py
Lines changed: 1 addition & 1 deletion

@@ -171,7 +171,7 @@ def __init__(self, *args, **kwargs):
             assert list(kwargs.keys()) == ["variable"]
             error_msg = get_variable_trace_string(kwargs["variable"])
             if error_msg:
-                args = args + (error_msg,)
+                args = (*args, error_msg)
         s = "\n".join(args)  # Needed to have the new line print correctly
         super().__init__(s)

pytensor/ifelse.py
Lines changed: 10 additions & 10 deletions

@@ -227,7 +227,7 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any):

         return Apply(
             self,
-            [condition] + new_inputs_true_branch + new_inputs_false_branch,
+            [condition, *new_inputs_true_branch, *new_inputs_false_branch],
             output_vars,
         )

@@ -275,11 +275,11 @@ def grad(self, ins, grads):
         # condition + epsilon always triggers the same branch as condition
         condition_grad = condition.zeros_like().astype(config.floatX)

-        return (
-            [condition_grad]
-            + if_true_op(*inputs_true_grad, return_list=True)
-            + if_false_op(*inputs_false_grad, return_list=True)
-        )
+        return [
+            condition_grad,
+            *if_true_op(*inputs_true_grad, return_list=True),
+            *if_false_op(*inputs_false_grad, return_list=True),
+        ]

     def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         cond = node.inputs[0]
@@ -397,7 +397,7 @@ def ifelse(

     new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, name=name)

-    ins = [condition] + list(then_branch) + list(else_branch)
+    ins = [condition, *list(then_branch), *list(else_branch)]
     rval = new_ifelse(*ins, return_list=True)

     if rval_type is None:
@@ -611,7 +611,7 @@ def apply(self, fgraph):
             mn_fs = merging_node.inputs[1:][merging_node.op.n_outs :]
             pl_ts = proposal.inputs[1:][: proposal.op.n_outs]
             pl_fs = proposal.inputs[1:][proposal.op.n_outs :]
-            new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs
+            new_ins = [merging_node.inputs[0], *mn_ts, *pl_ts, *mn_fs, *pl_fs]
             mn_name = "?"
             if merging_node.op.name:
                 mn_name = merging_node.op.name
@@ -673,7 +673,7 @@ def cond_remove_identical(fgraph, node):

     new_ifelse = IfElse(n_outs=len(nw_ts), as_view=op.as_view, name=op.name)

-    new_ins = [node.inputs[0]] + nw_ts + nw_fs
+    new_ins = [node.inputs[0], *nw_ts, *nw_fs]
     new_outs = new_ifelse(*new_ins, return_list=True)

     rval = []
@@ -711,7 +711,7 @@ def cond_merge_random_op(fgraph, main_node):
             mn_fs = merging_node.inputs[1:][merging_node.op.n_outs :]
             pl_ts = proposal.inputs[1:][: proposal.op.n_outs]
             pl_fs = proposal.inputs[1:][proposal.op.n_outs :]
-            new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs
+            new_ins = [merging_node.inputs[0], *mn_ts, *pl_ts, *mn_fs, *pl_fs]
             mn_name = "?"
             if merging_node.op.name:
                 mn_name = merging_node.op.name

pytensor/link/basic.py
Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ def __set__(self, value: Any) -> None:
             self.storage[0] = self.type.filter(value, **kwargs)

         except Exception as e:
-            e.args = e.args + (f'Container name "{self.name}"',)
+            e.args = (*e.args, f'Container name "{self.name}"')
             raise

     data = property(__get__, __set__)

pytensor/link/c/cmodule.py
Lines changed: 1 addition & 1 deletion

@@ -1890,7 +1890,7 @@ def _try_compile_tmp(
             os.close(fd)
             fd = None
         out, err, p_ret = output_subprocess_Popen(
-            [compiler] + args + [path, "-o", exe_path] + flags
+            [compiler, *args, path, "-o", exe_path, *flags]
         )
         if p_ret != 0:
             compilation_ok = False

pytensor/link/c/params_type.py
Lines changed: 3 additions & 3 deletions

@@ -293,7 +293,7 @@ def __hash__(self):
             .signature()
             for i in range(self.__params_type__.length)
         )
-        return hash((type(self), self.__params_type__) + self.__signatures__)
+        return hash((type(self), self.__params_type__, *self.__signatures__))

     def __eq__(self, other):
         return (
@@ -437,7 +437,7 @@ def __eq__(self, other):
         )

     def __hash__(self):
-        return hash((type(self),) + self.fields + self.types)
+        return hash((type(self), *self.fields, *self.types))

     def generate_struct_name(self):
         # This method tries to generate an unique name for the current instance.
@@ -807,7 +807,7 @@ def c_support_code(self, **kwargs):
             )
         )

-        return sorted(c_support_code_set) + [final_struct_code]
+        return [*sorted(c_support_code_set), final_struct_code]

     def c_code_cache_version(self):
         return ((3,), tuple(t.c_code_cache_version() for t in self.types))

pytensor/link/c/type.py
Lines changed: 7 additions & 4 deletions

@@ -267,7 +267,7 @@ def c_compile_args(self, **kwargs):
     def c_code_cache_version(self):
         v = (3,)
         if self.version is not None:
-            v = v + (self.version,)
+            v = (*v, self.version)
         return v

     def __str__(self):
@@ -505,9 +505,12 @@ def __delitem__(self, key):
     def __hash__(self):
         # All values are Python basic types, then easy to hash.
         return hash(
-            (type(self), self.ctype)
-            + tuple((k, self[k]) for k in sorted(self.keys()))
-            + tuple((a, self.aliases[a]) for a in sorted(self.aliases.keys()))
+            (
+                type(self),
+                self.ctype,
+                *tuple((k, self[k]) for k in sorted(self.keys())),
+                *tuple((a, self.aliases[a]) for a in sorted(self.aliases.keys())),
+            )
         )

     def __eq__(self, other):

pytensor/link/jax/dispatch/nlinalg.py
Lines changed: 2 additions & 1 deletion

@@ -130,7 +130,8 @@ def maxandargmax(x, axis=axis):

         # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
         # Otherwise reshape would complain citing float arg
-        new_shape = kept_shape + (
+        new_shape = (
+            *kept_shape,
             jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
         )
         reshaped_x = transposed_x.reshape(new_shape)

pytensor/link/jax/dispatch/scalar.py
Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ def jax_funcify_ScalarOp(op, node, **kwargs):

     func_name = nfunc_spec[0]
     if "." in func_name:
-        jax_func = functools.reduce(getattr, [jax] + func_name.split("."))
+        jax_func = functools.reduce(getattr, [jax, *func_name.split(".")])
     else:
         jax_func = getattr(jnp, func_name)

pytensor/link/numba/dispatch/basic.py
Lines changed: 1 addition & 1 deletion

@@ -330,7 +330,7 @@ def creator(args):

         @numba_njit
         def creator(args, creator=creator, i=i):
-            return creator(args) + (f(i, *args),)
+            return (*creator(args), f(i, *args))

     return numba_njit(lambda *args: creator(args))

pytensor/link/numba/dispatch/elemwise.py
Lines changed: 2 additions & 2 deletions

@@ -447,7 +447,7 @@ def jit_compile_reducer(

 def create_axis_apply_fn(fn, axis, ndim, dtype):
     axis = normalize_axis_index(axis, ndim)

-    reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
+    reaxis_first = (*tuple(i for i in range(ndim) if i != axis), axis)

     @numba_basic.numba_njit(boundscheck=False)
     def axis_apply_fn(x):
@@ -1042,7 +1042,7 @@ def maxandargmax(x):

         # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
         # Otherwise reshape would complain citing float arg
-        new_shape = kept_shape + (reduced_size,)
+        new_shape = (*kept_shape, reduced_size)
         reshaped_x = transposed_x.reshape(new_shape)

         max_idx_res = argmax_axis(reshaped_x)

pytensor/link/numba/dispatch/extra_ops.py
Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     if axis < 0 or axis >= ndim:
         raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")

-    reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
+    reaxis_first = (axis, *tuple(i for i in range(ndim) if i != axis))
     reaxis_first_inv = tuple(np.argsort(reaxis_first))

     if mode == "add":

pytensor/link/numba/dispatch/random.py
Lines changed: 1 addition & 9 deletions

@@ -240,15 +240,7 @@ def create_numba_random_fn(
     np_global_env["numba_vectorize"] = numba_basic.numba_vectorize

     unique_names = unique_name_generator(
-        [
-            np_random_fn_name,
-        ]
-        + list(np_global_env.keys())
-        + [
-            "rng",
-            "size",
-            "dtype",
-        ],
+        [np_random_fn_name, *list(np_global_env.keys()), "rng", "size", "dtype"],
         suffix_sep="_",
     )

pytensor/link/numba/dispatch/scalar.py
Lines changed: 1 addition & 1 deletion

@@ -115,7 +115,7 @@ def {scalar_op_fn_name}({input_names}):
     global_env.update(input_tmp_dtype_names)

     unique_names = unique_name_generator(
-        [scalar_op_fn_name, "scalar_func_numba"] + list(global_env.keys()),
+        [scalar_op_fn_name, "scalar_func_numba", *list(global_env.keys())],
         suffix_sep="_",
     )

pytensor/link/numba/dispatch/slinalg.py
Lines changed: 2 additions & 2 deletions

@@ -144,8 +144,8 @@ def _check_scipy_linalg_matrix(a, func_name):
         msg = "{}.{}() only supported for array types".format(*interp)
         raise numba.TypingError(msg, highlighting=False)
     if a.ndim not in [1, 2]:
-        msg = "%s.%s() only supported on 1d or 2d arrays, found %s." % (
-            interp + (a.ndim,)
+        msg = "{}.{}() only supported on 1d or 2d arrays, found {}.".format(
+            *interp, a.ndim
         )
         raise numba.TypingError(msg, highlighting=False)
     if not isinstance(a.dtype, (types.Float, types.Complex)):
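
This hunk does a bit more than the mechanical rewrite: %-formatting takes a single tuple on its right-hand side, so the old code had to concatenate, while str.format accepts positional arguments that can be unpacked directly. A rough equivalence check with placeholder values, not taken from the library:

interp = ("scipy.linalg", "solve")
ndim = 3
old = "%s.%s() only supported on 1d or 2d arrays, found %s." % (interp + (ndim,))
new = "{}.{}() only supported on 1d or 2d arrays, found {}.".format(*interp, ndim)
assert old == new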

pytensor/link/numba/dispatch/tensor_basic.py
Lines changed: 3 additions & 3 deletions

@@ -174,14 +174,14 @@ def extract_diag(x):
         else:
             diag_len = min(x.shape[axis2], max(0, x.shape[axis1] + offset))
         base_shape = x.shape[:axis1] + x.shape[axis1p1:axis2] + x.shape[axis2p1:]
-        out_shape = base_shape + (diag_len,)
+        out_shape = (*base_shape, diag_len)
         out = np.empty(out_shape)

         for i in range(diag_len):
             if offset >= 0:
-                new_entry = x[leading_dims + (i,) + middle_dims + (i + offset,)]
+                new_entry = x[(*leading_dims, i, *middle_dims, i + offset)]
             else:
-                new_entry = x[leading_dims + (i - offset,) + middle_dims + (i,)]
+                new_entry = x[(*leading_dims, i - offset, *middle_dims, i)]
             out[..., i] = new_entry
         return out

pytensor/raise_op.py
Lines changed: 1 addition & 1 deletion

@@ -87,7 +87,7 @@ def make_node(self, value: Variable, *conds: Variable):

         return Apply(
             self,
-            [value] + conds,
+            [value, *conds],
             [value.type()],
         )
