Skip to content

Commit d5345d2

Browse files
committed
Introduce signature instead of ndim_supp and ndims_params
1 parent da5f7f0 commit d5345d2

File tree

9 files changed

+198
-222
lines changed

9 files changed

+198
-222
lines changed

pytensor/tensor/random/basic.py

Lines changed: 57 additions & 108 deletions
Large diffs are not rendered by default.

pytensor/tensor/random/op.py

Lines changed: 90 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Sequence
23
from copy import copy
34
from typing import cast
@@ -28,6 +29,7 @@
2829
from pytensor.tensor.shape import shape_tuple
2930
from pytensor.tensor.type import TensorType, all_dtypes
3031
from pytensor.tensor.type_other import NoneConst
32+
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
3133
from pytensor.tensor.variable import TensorVariable
3234

3335

@@ -42,61 +44,81 @@ class RandomVariable(Op):
4244

4345
_output_type_depends_on_input_value = True
4446

45-
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
47+
__props__ = ("name", "signature", "dtype", "inplace")
4648
default_output = 1
4749

4850
def __init__(
4951
self,
5052
name=None,
5153
ndim_supp=None,
5254
ndims_params=None,
53-
dtype=None,
55+
dtype: str | None = None,
5456
inplace=None,
57+
signature: str | None = None,
5558
):
5659
"""Create a random variable `Op`.
5760
5861
Parameters
5962
----------
6063
name: str
6164
The `Op`'s display name.
62-
ndim_supp: int
63-
Total number of dimensions for a single draw of the random variable
64-
(e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
65-
ndims_params: list of int
66-
Number of dimensions for each distribution parameter when the
67-
parameters only specify a single drawn of the random variable
68-
(e.g. a multivariate normal's mean is 1D and covariance is 2D, so
69-
``ndims_params = [1, 2]``).
65+
signature: str
66+
Numpy-like vectorized signature of the random variable.
7067
dtype: str (optional)
7168
The dtype of the sampled output. If the value ``"floatX"`` is
7269
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
7370
``None`` (the default), the `dtype` keyword must be set when
7471
`RandomVariable.make_node` is called.
7572
inplace: boolean (optional)
76-
Determine whether or not the underlying rng state is updated
77-
in-place or not (i.e. copied).
73+
Determine whether the underlying rng state is mutated or copied.
7874
7975
"""
8076
super().__init__()
8177

8278
self.name = name or getattr(self, "name")
83-
self.ndim_supp = (
84-
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp")
79+
80+
ndim_supp = (
81+
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp", None)
8582
)
86-
self.ndims_params = (
87-
ndims_params if ndims_params is not None else getattr(self, "ndims_params")
83+
if ndim_supp is not None:
84+
warnings.warn(
85+
"ndim_supp is deprecated. Provide signature instead.", FutureWarning
86+
)
87+
self.ndim_supp = ndim_supp
88+
ndims_params = (
89+
ndims_params
90+
if ndims_params is not None
91+
else getattr(self, "ndims_params", None)
8892
)
93+
if ndims_params is not None:
94+
warnings.warn(
95+
"ndims_params is deprecated. Provide signature instead.", FutureWarning
96+
)
97+
if not isinstance(ndims_params, Sequence):
98+
raise TypeError("Parameter ndims_params must be sequence type.")
99+
self.ndims_params = tuple(ndims_params)
100+
101+
self.signature = signature or getattr(self, "signature", None)
102+
if self.signature is not None:
103+
# Assume a single output. Several methods need to be updated to handle multiple outputs.
104+
self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature)
105+
self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig]
106+
self.ndim_supp = len(self.output_sig)
107+
else:
108+
if (
109+
getattr(self, "ndim_supp", None) is None
110+
or getattr(self, "ndims_params", None) is None
111+
):
112+
raise ValueError("signature must be provided")
113+
else:
114+
self.signature = safe_signature(self.ndims_params, [self.ndim_supp])
115+
89116
self.dtype = dtype or getattr(self, "dtype", None)
90117

91118
self.inplace = (
92119
inplace if inplace is not None else getattr(self, "inplace", False)
93120
)
94121

95-
if not isinstance(self.ndims_params, Sequence):
96-
raise TypeError("Parameter ndims_params must be sequence type.")
97-
98-
self.ndims_params = tuple(self.ndims_params)
99-
100122
if self.inplace:
101123
self.destroy_map = {0: [0]}
102124

@@ -120,16 +142,56 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
120142
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
121143
might have `support_shape=(steps,)`.
122144
"""
145+
if self.signature is not None:
146+
# Signature could indicate fixed numerical shapes
147+
# As per https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html
148+
output_sig = self.output_sig
149+
core_out_shape = {
150+
dim: int(dim) if str.isnumeric(dim) else None for dim in self.output_sig
151+
}
152+
153+
# Try to infer missing support dims from signature of params
154+
for param, param_sig, ndim_params in zip(
155+
dist_params, self.inputs_sig, self.ndims_params
156+
):
157+
if ndim_params == 0:
158+
continue
159+
for param_dim, dim in zip(param.shape[-ndim_params:], param_sig):
160+
if dim in core_out_shape and core_out_shape[dim] is None:
161+
core_out_shape[dim] = param_dim
162+
163+
if all(dim is not None for dim in core_out_shape.values()):
164+
# We have all we need
165+
return [core_out_shape[dim] for dim in output_sig]
166+
123167
raise NotImplementedError(
124-
"`_supp_shape_from_params` must be implemented for multivariate RVs"
168+
"`_supp_shape_from_params` must be implemented for multivariate RVs "
169+
"when signature is not sufficient to infer the support shape"
125170
)
126171

127172
def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray:
128173
"""Sample a numeric random variate."""
129174
return getattr(rng, self.name)(*args, **kwargs)
130175

131176
def __str__(self):
132-
props_str = ", ".join(f"{getattr(self, prop)}" for prop in self.__props__[1:])
177+
# Only show signature from core props
178+
if signature := self.signature:
179+
# inp, out = signature.split("->")
180+
# extended_signature = f"[rng],[size],{inp}->[rng],{out}"
181+
# core_props = [extended_signature]
182+
core_props = [f'"{signature}"']
183+
else:
184+
# Backward compatibility: without a signature, show the legacy ndim_supp and ndims_params
185+
core_props = [str(self.ndim_supp), str(self.ndims_params)]
186+
187+
# Add any extra props that the subclass may have
188+
extra_props = [
189+
str(getattr(self, prop))
190+
for prop in self.__props__
191+
if prop not in RandomVariable.__props__
192+
]
193+
194+
props_str = ", ".join(core_props + extra_props)
133195
return f"{self.name}_rv{{{props_str}}}"
134196

135197
def _infer_shape(
@@ -298,11 +360,11 @@ def make_node(self, rng, size, dtype, *dist_params):
298360
dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
299361
else:
300362
dtype_idx = constant(dtype, dtype="int64")
301-
dtype = all_dtypes[dtype_idx.data]
302363

303-
outtype = TensorType(dtype=dtype, shape=static_shape)
304-
out_var = outtype()
364+
dtype = all_dtypes[dtype_idx.data]
365+
305366
inputs = (rng, size, dtype_idx, *dist_params)
367+
out_var = TensorType(dtype=dtype, shape=static_shape)()
306368
outputs = (rng.type(), out_var)
307369

308370
return Apply(self, inputs, outputs)
@@ -395,9 +457,8 @@ def vectorize_random_variable(
395457
# We extend it to accommodate the new input batch dimensions.
396458
# Otherwise, we assume the new size already has the right values
397459

398-
# Need to make parameters implicit broadcasting explicit
399-
original_dist_params = node.inputs[3:]
400-
old_size = node.inputs[1]
460+
original_dist_params = op.dist_params(node)
461+
old_size = op.size_param(node)
401462
len_old_size = get_vector_length(old_size)
402463

403464
original_expanded_dist_params = explicit_expand_dims(

pytensor/tensor/random/rewriting/jax.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
from pytensor.compile import optdb
24
from pytensor.graph.rewriting.basic import in2out, node_rewriter
35
from pytensor.graph.rewriting.db import SequenceDB
@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
164166

165167
a_vector_param = arange(a_scalar_param)
166168
new_props_dict = op._props_dict().copy()
167-
new_ndims_params = list(op.ndims_params)
168-
new_ndims_params[0] += 1
169-
new_props_dict["ndims_params"] = new_ndims_params
169+
# Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
170+
# I.e., we substitute the first `()` by `(a)`
171+
new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1)
170172
new_op = type(op)(**new_props_dict)
171173
return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs
172174

pytensor/tensor/random/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
123123

124124
def explicit_expand_dims(
125125
params: Sequence[TensorVariable],
126-
ndim_params: tuple[int],
126+
ndim_params: Sequence[int],
127127
size_length: int = 0,
128128
) -> list[TensorVariable]:
129129
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
@@ -137,7 +137,7 @@ def explicit_expand_dims(
137137
# See: https://github.com/pymc-devs/pytensor/issues/568
138138
max_batch_dims = size_length
139139
else:
140-
max_batch_dims = max(batch_dims)
140+
max_batch_dims = max(batch_dims, default=0)
141141

142142
new_params = []
143143
for new_param, batch_dim in zip(params, batch_dims):
@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape(
354354
out: tuple
355355
Representing the support shape for a `RandomVariable` with the given `dist_params`.
356356
357+
Notes
358+
-----
359+
This helper is no longer necessary when using signatures in `RandomVariable` subclasses.
360+
361+
357362
"""
358363
if ndim_supp <= 0:
359364
raise ValueError("ndim_supp must be greater than 0")

pytensor/tensor/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def broadcast_static_dim_lengths(
169169
_CORE_DIMENSION_LIST = f"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?"
170170
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
171171
_ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*"
172-
_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
172+
# Allow no inputs
173+
_SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$"
173174

174175

175176
def _parse_gufunc_signature(
@@ -200,6 +201,8 @@ def _parse_gufunc_signature(
200201
tuple(re.findall(_DIMENSION_NAME, arg))
201202
for arg in re.findall(_ARGUMENT, arg_list)
202203
]
204+
if arg_list # ignore no inputs
205+
else []
203206
for arg_list in signature.split("->")
204207
)
205208

tests/link/jax/test_random.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -771,8 +771,7 @@ def test_random_unimplemented():
771771

772772
class NonExistentRV(RandomVariable):
773773
name = "non-existent"
774-
ndim_supp = 0
775-
ndims_params = []
774+
signature = "->()"
776775
dtype = "floatX"
777776

778777
def __call__(self, size=None, **kwargs):
@@ -798,8 +797,7 @@ def test_random_custom_implementation():
798797

799798
class CustomRV(RandomVariable):
800799
name = "non-existent"
801-
ndim_supp = 0
802-
ndims_params = []
800+
signature = "->()"
803801
dtype = "floatX"
804802

805803
def __call__(self, size=None, **kwargs):

tests/tensor/random/rewriting/test_basic.py

Lines changed: 22 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv(
7474
return new_out, f_inputs, dist_st, f_rewritten
7575

7676

77-
def test_inplace_rewrites():
78-
out = normal(0, 1)
79-
out.owner.inputs[0].default_update = out.owner.outputs[0]
77+
class TestRVExpraProps(RandomVariable):
78+
name = "test"
79+
signature = "()->()"
80+
__props__ = ("name", "signature", "dtype", "inplace", "extra")
81+
dtype = "floatX"
82+
_print_name = ("TestExtraProps", "\\operatorname{TestExtra_props}")
8083

81-
assert out.owner.op.inplace is False
84+
def __init__(self, extra, *args, **kwargs):
85+
self.extra = extra
86+
super().__init__(*args, **kwargs)
8287

83-
f = function(
84-
[],
85-
out,
86-
mode="FAST_RUN",
87-
)
88-
89-
(new_out, new_rng) = f.maker.fgraph.outputs
90-
assert new_out.type == out.type
91-
assert isinstance(new_out.owner.op, type(out.owner.op))
92-
assert new_out.owner.op.inplace is True
93-
assert all(
94-
np.array_equal(a.data, b.data)
95-
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
96-
)
97-
assert np.array_equal(new_out.owner.inputs[1].data, [])
98-
99-
100-
def test_inplace_rewrites_extra_props():
101-
class Test(RandomVariable):
102-
name = "test"
103-
ndim_supp = 0
104-
ndims_params = [0]
105-
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace", "extra")
106-
dtype = "floatX"
107-
_print_name = ("Test", "\\operatorname{Test}")
108-
109-
def __init__(self, extra, *args, **kwargs):
110-
self.extra = extra
111-
super().__init__(*args, **kwargs)
112-
113-
def make_node(self, rng, size, dtype, sigma):
114-
return super().make_node(rng, size, dtype, sigma)
115-
116-
def rng_fn(self, rng, sigma, size):
117-
return rng.normal(scale=sigma, size=size)
88+
def rng_fn(self, rng, dtype, sigma, size):
89+
return rng.normal(scale=sigma, size=size)
11890

119-
out = Test(extra="some value")(1)
120-
out.owner.inputs[0].default_update = out.owner.outputs[0]
12191

122-
assert out.owner.op.inplace is False
92+
@pytest.mark.parametrize("rv_op", [normal, TestRVExpraProps(extra="some value")])
93+
def test_inplace_rewrites(rv_op):
94+
out = rv_op(np.e)
95+
node = out.owner
96+
op = node.op
97+
node.inputs[0].default_update = node.outputs[0]
98+
assert op.inplace is False
12399

124100
f = function(
125101
[],
@@ -129,9 +105,10 @@ def rng_fn(self, rng, sigma, size):
129105

130106
(new_out, new_rng) = f.maker.fgraph.outputs
131107
assert new_out.type == out.type
132-
assert isinstance(new_out.owner.op, type(out.owner.op))
133-
assert new_out.owner.op.inplace is True
134-
assert new_out.owner.op.extra == out.owner.op.extra
108+
new_node = new_out.owner
109+
new_op = new_node.op
110+
assert isinstance(new_op, type(op))
111+
assert new_op._props_dict() == (op._props_dict() | {"inplace": True})
135112
assert all(
136113
np.array_equal(a.data, b.data)
137114
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])

0 commit comments

Comments
 (0)