Skip to content

Commit 45eb6ef

Browse files
committed
Add gufunc_signature to SymbolicRandomVariables
1 parent a2988c7 commit 45eb6ef

File tree

6 files changed

+171
-36
lines changed

6 files changed

+171
-36
lines changed

pymc/distributions/censored.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class CensoredRV(SymbolicRandomVariable):
3030
"""Censored random variable"""
3131

3232
inline_logprob = True
33+
signature = "(),(),()->()"
34+
ndim_supp = 0
3335
_print_name = ("Censored", "\\operatorname{Censored}")
3436

3537

@@ -115,7 +117,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):
115117
return CensoredRV(
116118
inputs=[dist_, lower_, upper_],
117119
outputs=[censored_rv_],
118-
ndim_supp=0,
119120
)(dist, lower, upper)
120121

121122

pymc/distributions/distribution.py

Lines changed: 114 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@
3333
from pytensor.graph.utils import MetaType
3434
from pytensor.scan.op import Scan
3535
from pytensor.tensor.basic import as_tensor_variable
36+
from pytensor.tensor.blockwise import safe_signature
3637
from pytensor.tensor.random.op import RandomVariable
3738
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
3839
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
3940
from pytensor.tensor.random.utils import normalize_size_param
4041
from pytensor.tensor.rewriting.shape import ShapeFeature
42+
from pytensor.tensor.utils import _parse_gufunc_signature
4143
from pytensor.tensor.variable import TensorVariable
4244
from typing_extensions import TypeAlias
4345

@@ -261,6 +263,12 @@ class SymbolicRandomVariable(OpFromGraph):
261263
(0 for scalar, 1 for vector, ...)
262264
"""
263265

266+
ndims_params: Optional[Sequence[int]] = None
267+
"""Number of core dimensions of the distribution's parameters."""
268+
269+
signature: str = None
270+
"""Numpy-like vectorized signature of the distribution."""
271+
264272
inline_logprob: bool = False
265273
"""Specifies whether the logprob function is derived automatically by introspection
266274
of the inner graph.
@@ -271,9 +279,25 @@ class SymbolicRandomVariable(OpFromGraph):
271279
_print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
272280
"""Tuple of (name, latex name) used for for pretty-printing variables of this type"""
273281

274-
def __init__(self, *args, ndim_supp, **kwargs):
275-
"""Initialitze a SymbolicRandomVariable class."""
276-
self.ndim_supp = ndim_supp
282+
def __init__(
283+
self,
284+
*args,
285+
**kwargs,
286+
):
287+
"""Initialize a SymbolicRandomVariable class."""
288+
if self.signature is None:
289+
self.signature = kwargs.get("signature", None)
290+
291+
if self.signature is not None:
292+
inputs_sig, outputs_sig = _parse_gufunc_signature(self.signature)
293+
self.ndims_params = [len(sig) for sig in inputs_sig]
294+
self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig)
295+
296+
if self.ndim_supp is None:
297+
self.ndim_supp = kwargs.get("ndim_supp", None)
298+
if self.ndim_supp is None:
299+
raise ValueError("ndim_supp or gufunc_signature must be provided")
300+
277301
kwargs.setdefault("inline", True)
278302
super().__init__(*args, **kwargs)
279303

@@ -286,6 +310,11 @@ def update(self, node: Node):
286310
"""
287311
return {}
288312

313+
def batch_ndim(self, node: Node) -> int:
314+
"""Number of dimensions of the distribution's batch shape."""
315+
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
316+
return out_ndim - self.ndim_supp
317+
289318

290319
class Distribution(metaclass=DistributionMeta):
291320
"""Statistical distribution"""
@@ -558,23 +587,29 @@ def dist(
558587
logcdf: Optional[Callable] = None,
559588
random: Optional[Callable] = None,
560589
support_point: Optional[Callable] = None,
561-
ndim_supp: int = 0,
590+
ndim_supp: Optional[int] = None,
562591
ndims_params: Optional[Sequence[int]] = None,
592+
signature: Optional[str] = None,
563593
dtype: str = "floatX",
564594
class_name: str = "CustomDist",
565595
**kwargs,
566596
):
597+
if ndim_supp is None or ndims_params is None:
598+
if signature is None:
599+
ndim_supp = 0
600+
ndims_params = [0] * len(dist_params)
601+
else:
602+
inputs, outputs = _parse_gufunc_signature(signature)
603+
ndim_supp = max(len(out) for out in outputs)
604+
ndims_params = [len(inp) for inp in inputs]
605+
567606
if ndim_supp > 0:
568607
raise NotImplementedError(
569608
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
570609
)
571610

572611
dist_params = [as_tensor_variable(param) for param in dist_params]
573612

574-
# Assume scalar ndims_params
575-
if ndims_params is None:
576-
ndims_params = [0] * len(dist_params)
577-
578613
if logp is None:
579614
logp = default_not_implemented(class_name, "logp")
580615

@@ -614,7 +649,7 @@ def rv_op(
614649
random: Optional[Callable],
615650
support_point: Optional[Callable],
616651
ndim_supp: int,
617-
ndims_params: Optional[Sequence[int]],
652+
ndims_params: Sequence[int],
618653
dtype: str,
619654
class_name: str,
620655
**kwargs,
@@ -702,7 +737,9 @@ def dist(
702737
logp: Optional[Callable] = None,
703738
logcdf: Optional[Callable] = None,
704739
support_point: Optional[Callable] = None,
705-
ndim_supp: int = 0,
740+
ndim_supp: Optional[int] = None,
741+
ndims_params: Optional[Sequence[int]] = None,
742+
signature: Optional[str] = None,
706743
dtype: str = "floatX",
707744
class_name: str = "CustomDist",
708745
**kwargs,
@@ -712,14 +749,24 @@ def dist(
712749
if logcdf is None:
713750
logcdf = default_not_implemented(class_name, "logcdf")
714751

752+
if signature is None:
753+
if ndim_supp is None:
754+
ndim_supp = 0
755+
if ndims_params is None:
756+
ndims_params = [0] * len(dist_params)
757+
signature = safe_signature(
758+
core_inputs=[pt.tensor(shape=(None,) * ndim_param) for ndim_param in ndims_params],
759+
core_outputs=[pt.tensor(shape=(None,) * ndim_supp)],
760+
)
761+
715762
return super().dist(
716763
dist_params,
717764
class_name=class_name,
718765
logp=logp,
719766
logcdf=logcdf,
720767
dist=dist,
721768
support_point=support_point,
722-
ndim_supp=ndim_supp,
769+
signature=signature,
723770
**kwargs,
724771
)
725772

@@ -732,7 +779,7 @@ def rv_op(
732779
logcdf: Optional[Callable],
733780
support_point: Optional[Callable],
734781
size=None,
735-
ndim_supp: int,
782+
signature: str,
736783
class_name: str,
737784
):
738785
size = normalize_size_param(size)
@@ -745,6 +792,10 @@ def rv_op(
745792
dummy_params = [dummy_size_param, *dummy_dist_params]
746793
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
747794

795+
signature = cls._infer_final_signature(
796+
signature, len(dummy_params), len(dummy_updates_dict)
797+
)
798+
748799
rv_type = type(
749800
class_name,
750801
(CustomSymbolicDistRV,),
@@ -802,7 +853,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
802853
new_rv_op = rv_type(
803854
inputs=dummy_params,
804855
outputs=[*dummy_updates_dict.values(), dummy_rv],
805-
ndim_supp=ndim_supp,
856+
signature=signature,
806857
)
807858
new_rv = new_rv_op(new_size, *dist_params)
808859

@@ -811,10 +862,30 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
811862
rv_op = rv_type(
812863
inputs=dummy_params,
813864
outputs=[*dummy_updates_dict.values(), dummy_rv],
814-
ndim_supp=ndim_supp,
865+
signature=signature,
815866
)
816867
return rv_op(size, *dist_params)
817868

869+
@staticmethod
870+
def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
871+
"""Add size and updates to user provided gufunc signature if they are missing."""
872+
input_sig, output_sig = signature.split("->")
873+
# Numpy parser does not accept (constant) functions without inputs like "->()"
874+
# We work around as this makes sense for distributions like Flat that have no inputs
875+
if input_sig.strip() == "":
876+
inputs = ()
877+
_, outputs = _parse_gufunc_signature("()" + signature)
878+
else:
879+
inputs, outputs = _parse_gufunc_signature(signature)
880+
if len(inputs) == n_inputs - 1:
881+
# Assume size is missing
882+
input_sig = ("()," if input_sig else "()") + input_sig
883+
if len(outputs) == 1:
884+
# Assume updates are missing
885+
output_sig = "()," * n_updates + output_sig
886+
signature = "->".join((input_sig, output_sig))
887+
return signature
888+
818889

819890
class CustomDist:
820891
"""A helper class to create custom distributions
@@ -828,12 +899,12 @@ class CustomDist:
828899
when not provided by the user.
829900
830901
Alternatively, a user can provide a `random` function that returns numerical
831-
draws (e.g., via NumPy routines), and a `logp` function that must return an
832-
Python graph that represents the logp graph when evaluated. This is used for
902+
draws (e.g., via NumPy routines), and a `logp` function that must return a
903+
PyTensor graph that represents the logp graph when evaluated. This is used for
833904
mcmc sampling.
834905
835906
Additionally, a user can provide a `logcdf` and `support_point` functions that must return
836-
an PyTensor graph that computes those quantities. These may be used by other PyMC
907+
PyTensor graphs that computes those quantities. These may be used by other PyMC
837908
routines.
838909
839910
Parameters
@@ -894,14 +965,18 @@ class CustomDist:
894965
distribution parameters, in the same order as they were supplied when the
895966
CustomDist was created. If ``None``, a default ``support_point`` function will be
896967
assigned that will always return 0, or an array of zeros.
897-
ndim_supp : int
898-
The number of dimensions in the support of the distribution. Defaults to assuming
899-
a scalar distribution, i.e. ``ndim_supp = 0``.
968+
ndim_supp : Optional[int]
969+
The number of dimensions in the support of the distribution.
970+
Inferred from signature, if provided. Defaults to assuming
971+
a scalar distribution, i.e. ``ndim_supp = 0``
900972
ndims_params : Optional[Sequence[int]]
901973
The list of number of dimensions in the support of each of the distribution's
902-
parameters. If ``None``, it is assumed that all parameters are scalars, hence
903-
the number of dimensions of their support will be 0. This is not needed if an
904-
PyTensor dist function is provided.
974+
parameters. Inferred from signature, if provided. Defaults to assuming
975+
all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
976+
signature : Optional[str]
977+
A numpy vectorize-like signature that indicates the number and core dimensionality
978+
of the input parameters and sample outputs of the CustomDist.
979+
When specified, `ndim_supp` and `ndims_params` are not needed. See examples below.
905980
dtype : str
906981
The dtype of the distribution. All draws and observations passed into the
907982
distribution will be cast onto this dtype. This is not needed if an PyTensor
@@ -939,6 +1014,7 @@ def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable:
9391014
9401015
Provide a random function that return numerical draws. This allows one to use a
9411016
CustomDist in prior and posterior predictive sampling.
1017+
A gufunc signature was also provided, which may be used by other routines.
9421018
9431019
.. code-block:: python
9441020
@@ -965,6 +1041,7 @@ def random(
9651041
mu,
9661042
logp=logp,
9671043
random=random,
1044+
signature="()->()",
9681045
observed=np.random.randn(100, 3),
9691046
size=(100, 3),
9701047
)
@@ -973,6 +1050,7 @@ def random(
9731050
Provide a dist function that creates a PyTensor graph built from other
9741051
PyMC distributions. PyMC can automatically infer that the logp of this
9751052
variable corresponds to a shifted Exponential distribution.
1053+
A gufunc signature was also provided, which may be used by other routines.
9761054
9771055
.. code-block:: python
9781056
@@ -994,6 +1072,7 @@ def dist(
9941072
lam,
9951073
shift,
9961074
dist=dist,
1075+
signature="(),()->()",
9971076
observed=[-1, -1, 0],
9981077
)
9991078
@@ -1040,10 +1119,11 @@ def __new__(
10401119
random: Optional[Callable] = None,
10411120
logp: Optional[Callable] = None,
10421121
logcdf: Optional[Callable] = None,
1043-
moment: Optional[Callable] = None,
10441122
support_point: Optional[Callable] = None,
1045-
ndim_supp: int = 0,
1123+
# TODO: Deprecate ndim_supp / ndims_params in favor of signature?
1124+
ndim_supp: Optional[int] = None,
10461125
ndims_params: Optional[Sequence[int]] = None,
1126+
signature: Optional[str] = None,
10471127
dtype: str = "floatX",
10481128
**kwargs,
10491129
):
@@ -1057,6 +1137,7 @@ def __new__(
10571137
)
10581138
dist_params = cls.parse_dist_params(dist_params)
10591139
cls.check_valid_dist_random(dist, random, dist_params)
1140+
moment = kwargs.pop("moment", None)
10601141
if moment is not None:
10611142
warnings.warn(
10621143
"`moment` argument is deprecated. Use `support_point` instead.",
@@ -1073,6 +1154,8 @@ def __new__(
10731154
logcdf=logcdf,
10741155
support_point=support_point,
10751156
ndim_supp=ndim_supp,
1157+
ndims_params=ndims_params,
1158+
signature=signature,
10761159
**kwargs,
10771160
)
10781161
else:
@@ -1086,6 +1169,7 @@ def __new__(
10861169
support_point=support_point,
10871170
ndim_supp=ndim_supp,
10881171
ndims_params=ndims_params,
1172+
signature=signature,
10891173
dtype=dtype,
10901174
**kwargs,
10911175
)
@@ -1099,8 +1183,9 @@ def dist(
10991183
logp: Optional[Callable] = None,
11001184
logcdf: Optional[Callable] = None,
11011185
support_point: Optional[Callable] = None,
1102-
ndim_supp: int = 0,
1186+
ndim_supp: Optional[int] = None,
11031187
ndims_params: Optional[Sequence[int]] = None,
1188+
signature: Optional[str] = None,
11041189
dtype: str = "floatX",
11051190
**kwargs,
11061191
):
@@ -1114,6 +1199,8 @@ def dist(
11141199
logcdf=logcdf,
11151200
support_point=support_point,
11161201
ndim_supp=ndim_supp,
1202+
ndims_params=ndims_params,
1203+
signature=signature,
11171204
**kwargs,
11181205
)
11191206
else:
@@ -1125,6 +1212,7 @@ def dist(
11251212
support_point=support_point,
11261213
ndim_supp=ndim_supp,
11271214
ndims_params=ndims_params,
1215+
signature=signature,
11281216
dtype=dtype,
11291217
**kwargs,
11301218
)

pymc/distributions/mixture.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,17 @@ def rv_op(cls, weights, *components, size=None):
296296
# Output mix_indexes rng update so that it can be updated in place
297297
mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]
298298

299+
s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
300+
if len(components) == 1:
301+
comp_s = ",".join((*s, "w"))
302+
signature = f"(),(w),({comp_s})->({s})"
303+
else:
304+
comps_s = ",".join(f"({s})" for _ in components)
305+
signature = f"(),(w),{comps_s}->({s})"
299306
mix_op = MarginalMixtureRV(
300307
inputs=[mix_indexes_rng_, weights_, *components_],
301308
outputs=[mix_indexes_rng_next_, mix_out_],
302-
ndim_supp=components[0].owner.op.ndim_supp,
309+
signature=signature,
303310
)
304311

305312
# Create the actual MarginalMixture variable

0 commit comments

Comments
 (0)