diff --git a/pytensor/tensor/random/rewriting/jax.py b/pytensor/tensor/random/rewriting/jax.py index d25c9e16ea..fa30e10c18 100644 --- a/pytensor/tensor/random/rewriting/jax.py +++ b/pytensor/tensor/random/rewriting/jax.py @@ -174,7 +174,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): new_props_dict = op._props_dict().copy() # Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)" # I.e., we substitute the first `()` by `(a)` - new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1) + new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, count=1) new_op = type(op)(**new_props_dict) return new_op.make_node(rng, size, a_vector_param, *other_params).outputs