33
33
from pytensor .graph .utils import MetaType
34
34
from pytensor .scan .op import Scan
35
35
from pytensor .tensor .basic import as_tensor_variable
36
+ from pytensor .tensor .blockwise import safe_signature
36
37
from pytensor .tensor .random .op import RandomVariable
37
38
from pytensor .tensor .random .rewriting import local_subtensor_rv_lift
38
39
from pytensor .tensor .random .type import RandomGeneratorType , RandomType
39
40
from pytensor .tensor .random .utils import normalize_size_param
40
41
from pytensor .tensor .rewriting .shape import ShapeFeature
42
+ from pytensor .tensor .utils import _parse_gufunc_signature
41
43
from pytensor .tensor .variable import TensorVariable
42
44
from typing_extensions import TypeAlias
43
45
@@ -261,6 +263,12 @@ class SymbolicRandomVariable(OpFromGraph):
261
263
(0 for scalar, 1 for vector, ...)
262
264
"""
263
265
266
+ ndims_params : Optional [Sequence [int ]] = None
267
+ """Number of core dimensions of the distribution's parameters."""
268
+
269
+ signature : str = None
270
+ """Numpy-like vectorized signature of the distribution."""
271
+
264
272
inline_logprob : bool = False
265
273
"""Specifies whether the logprob function is derived automatically by introspection
266
274
of the inner graph.
@@ -271,9 +279,25 @@ class SymbolicRandomVariable(OpFromGraph):
271
279
_print_name : tuple [str , str ] = ("Unknown" , "\\ operatorname{Unknown}" )
272
280
"""Tuple of (name, latex name) used for for pretty-printing variables of this type"""
273
281
274
- def __init__ (self , * args , ndim_supp , ** kwargs ):
275
- """Initialitze a SymbolicRandomVariable class."""
276
- self .ndim_supp = ndim_supp
282
+ def __init__ (
283
+ self ,
284
+ * args ,
285
+ ** kwargs ,
286
+ ):
287
+ """Initialize a SymbolicRandomVariable class."""
288
+ if self .signature is None :
289
+ self .signature = kwargs .get ("signature" , None )
290
+
291
+ if self .signature is not None :
292
+ inputs_sig , outputs_sig = _parse_gufunc_signature (self .signature )
293
+ self .ndims_params = [len (sig ) for sig in inputs_sig ]
294
+ self .ndim_supp = max (len (out_sig ) for out_sig in outputs_sig )
295
+
296
+ if self .ndim_supp is None :
297
+ self .ndim_supp = kwargs .get ("ndim_supp" , None )
298
+ if self .ndim_supp is None :
299
+ raise ValueError ("ndim_supp or gufunc_signature must be provided" )
300
+
277
301
kwargs .setdefault ("inline" , True )
278
302
super ().__init__ (* args , ** kwargs )
279
303
@@ -286,6 +310,11 @@ def update(self, node: Node):
286
310
"""
287
311
return {}
288
312
313
+ def batch_ndim (self , node : Node ) -> int :
314
+ """Number of dimensions of the distribution's batch shape."""
315
+ out_ndim = max (getattr (out .type , "ndim" , 0 ) for out in node .outputs )
316
+ return out_ndim - self .ndim_supp
317
+
289
318
290
319
class Distribution (metaclass = DistributionMeta ):
291
320
"""Statistical distribution"""
@@ -558,23 +587,29 @@ def dist(
558
587
logcdf : Optional [Callable ] = None ,
559
588
random : Optional [Callable ] = None ,
560
589
support_point : Optional [Callable ] = None ,
561
- ndim_supp : int = 0 ,
590
+ ndim_supp : Optional [ int ] = None ,
562
591
ndims_params : Optional [Sequence [int ]] = None ,
592
+ signature : Optional [str ] = None ,
563
593
dtype : str = "floatX" ,
564
594
class_name : str = "CustomDist" ,
565
595
** kwargs ,
566
596
):
597
+ if ndim_supp is None or ndims_params is None :
598
+ if signature is None :
599
+ ndim_supp = 0
600
+ ndims_params = [0 ] * len (dist_params )
601
+ else :
602
+ inputs , outputs = _parse_gufunc_signature (signature )
603
+ ndim_supp = max (len (out ) for out in outputs )
604
+ ndims_params = [len (inp ) for inp in inputs ]
605
+
567
606
if ndim_supp > 0 :
568
607
raise NotImplementedError (
569
608
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
570
609
)
571
610
572
611
dist_params = [as_tensor_variable (param ) for param in dist_params ]
573
612
574
- # Assume scalar ndims_params
575
- if ndims_params is None :
576
- ndims_params = [0 ] * len (dist_params )
577
-
578
613
if logp is None :
579
614
logp = default_not_implemented (class_name , "logp" )
580
615
@@ -614,7 +649,7 @@ def rv_op(
614
649
random : Optional [Callable ],
615
650
support_point : Optional [Callable ],
616
651
ndim_supp : int ,
617
- ndims_params : Optional [ Sequence [int ] ],
652
+ ndims_params : Sequence [int ],
618
653
dtype : str ,
619
654
class_name : str ,
620
655
** kwargs ,
@@ -702,7 +737,9 @@ def dist(
702
737
logp : Optional [Callable ] = None ,
703
738
logcdf : Optional [Callable ] = None ,
704
739
support_point : Optional [Callable ] = None ,
705
- ndim_supp : int = 0 ,
740
+ ndim_supp : Optional [int ] = None ,
741
+ ndims_params : Optional [Sequence [int ]] = None ,
742
+ signature : Optional [str ] = None ,
706
743
dtype : str = "floatX" ,
707
744
class_name : str = "CustomDist" ,
708
745
** kwargs ,
@@ -712,14 +749,24 @@ def dist(
712
749
if logcdf is None :
713
750
logcdf = default_not_implemented (class_name , "logcdf" )
714
751
752
+ if signature is None :
753
+ if ndim_supp is None :
754
+ ndim_supp = 0
755
+ if ndims_params is None :
756
+ ndims_params = [0 ] * len (dist_params )
757
+ signature = safe_signature (
758
+ core_inputs = [pt .tensor (shape = (None ,) * ndim_param ) for ndim_param in ndims_params ],
759
+ core_outputs = [pt .tensor (shape = (None ,) * ndim_supp )],
760
+ )
761
+
715
762
return super ().dist (
716
763
dist_params ,
717
764
class_name = class_name ,
718
765
logp = logp ,
719
766
logcdf = logcdf ,
720
767
dist = dist ,
721
768
support_point = support_point ,
722
- ndim_supp = ndim_supp ,
769
+ signature = signature ,
723
770
** kwargs ,
724
771
)
725
772
@@ -732,7 +779,7 @@ def rv_op(
732
779
logcdf : Optional [Callable ],
733
780
support_point : Optional [Callable ],
734
781
size = None ,
735
- ndim_supp : int ,
782
+ signature : str ,
736
783
class_name : str ,
737
784
):
738
785
size = normalize_size_param (size )
@@ -745,6 +792,10 @@ def rv_op(
745
792
dummy_params = [dummy_size_param , * dummy_dist_params ]
746
793
dummy_updates_dict = collect_default_updates (inputs = dummy_params , outputs = (dummy_rv ,))
747
794
795
+ signature = cls ._infer_final_signature (
796
+ signature , len (dummy_params ), len (dummy_updates_dict )
797
+ )
798
+
748
799
rv_type = type (
749
800
class_name ,
750
801
(CustomSymbolicDistRV ,),
@@ -802,7 +853,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
802
853
new_rv_op = rv_type (
803
854
inputs = dummy_params ,
804
855
outputs = [* dummy_updates_dict .values (), dummy_rv ],
805
- ndim_supp = ndim_supp ,
856
+ signature = signature ,
806
857
)
807
858
new_rv = new_rv_op (new_size , * dist_params )
808
859
@@ -811,10 +862,30 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
811
862
rv_op = rv_type (
812
863
inputs = dummy_params ,
813
864
outputs = [* dummy_updates_dict .values (), dummy_rv ],
814
- ndim_supp = ndim_supp ,
865
+ signature = signature ,
815
866
)
816
867
return rv_op (size , * dist_params )
817
868
869
+ @staticmethod
870
+ def _infer_final_signature (signature : str , n_inputs , n_updates ) -> str :
871
+ """Add size and updates to user provided gufunc signature if they are missing."""
872
+ input_sig , output_sig = signature .split ("->" )
873
+ # Numpy parser does not accept (constant) functions without inputs like "->()"
874
+ # We work around as this makes sense for distributions like Flat that have no inputs
875
+ if input_sig .strip () == "" :
876
+ inputs = ()
877
+ _ , outputs = _parse_gufunc_signature ("()" + signature )
878
+ else :
879
+ inputs , outputs = _parse_gufunc_signature (signature )
880
+ if len (inputs ) == n_inputs - 1 :
881
+ # Assume size is missing
882
+ input_sig = ("()," if input_sig else "()" ) + input_sig
883
+ if len (outputs ) == 1 :
884
+ # Assume updates are missing
885
+ output_sig = "()," * n_updates + output_sig
886
+ signature = "->" .join ((input_sig , output_sig ))
887
+ return signature
888
+
818
889
819
890
class CustomDist :
820
891
"""A helper class to create custom distributions
@@ -828,12 +899,12 @@ class CustomDist:
828
899
when not provided by the user.
829
900
830
901
Alternatively, a user can provide a `random` function that returns numerical
831
- draws (e.g., via NumPy routines), and a `logp` function that must return an
832
- Python graph that represents the logp graph when evaluated. This is used for
902
+ draws (e.g., via NumPy routines), and a `logp` function that must return a
903
+ PyTensor graph that represents the logp graph when evaluated. This is used for
833
904
mcmc sampling.
834
905
835
906
Additionally, a user can provide a `logcdf` and `support_point` functions that must return
836
- an PyTensor graph that computes those quantities. These may be used by other PyMC
907
+ PyTensor graphs that computes those quantities. These may be used by other PyMC
837
908
routines.
838
909
839
910
Parameters
@@ -894,14 +965,18 @@ class CustomDist:
894
965
distribution parameters, in the same order as they were supplied when the
895
966
CustomDist was created. If ``None``, a default ``support_point`` function will be
896
967
assigned that will always return 0, or an array of zeros.
897
- ndim_supp : int
898
- The number of dimensions in the support of the distribution. Defaults to assuming
899
- a scalar distribution, i.e. ``ndim_supp = 0``.
968
+ ndim_supp : Optional[int]
969
+ The number of dimensions in the support of the distribution.
970
+ Inferred from signature, if provided. Defaults to assuming
971
+ a scalar distribution, i.e. ``ndim_supp = 0``
900
972
ndims_params : Optional[Sequence[int]]
901
973
The list of number of dimensions in the support of each of the distribution's
902
- parameters. If ``None``, it is assumed that all parameters are scalars, hence
903
- the number of dimensions of their support will be 0. This is not needed if an
904
- PyTensor dist function is provided.
974
+ parameters. Inferred from signature, if provided. Defaults to assuming
975
+ all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
976
+ signature : Optional[str]
977
+ A numpy vectorize-like signature that indicates the number and core dimensionality
978
+ of the input parameters and sample outputs of the CustomDist.
979
+ When specified, `ndim_supp` and `ndims_params` are not needed. See examples below.
905
980
dtype : str
906
981
The dtype of the distribution. All draws and observations passed into the
907
982
distribution will be cast onto this dtype. This is not needed if an PyTensor
@@ -939,6 +1014,7 @@ def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable:
939
1014
940
1015
Provide a random function that return numerical draws. This allows one to use a
941
1016
CustomDist in prior and posterior predictive sampling.
1017
+ A gufunc signature was also provided, which may be used by other routines.
942
1018
943
1019
.. code-block:: python
944
1020
@@ -965,6 +1041,7 @@ def random(
965
1041
mu,
966
1042
logp=logp,
967
1043
random=random,
1044
+ signature="()->()",
968
1045
observed=np.random.randn(100, 3),
969
1046
size=(100, 3),
970
1047
)
@@ -973,6 +1050,7 @@ def random(
973
1050
Provide a dist function that creates a PyTensor graph built from other
974
1051
PyMC distributions. PyMC can automatically infer that the logp of this
975
1052
variable corresponds to a shifted Exponential distribution.
1053
+ A gufunc signature was also provided, which may be used by other routines.
976
1054
977
1055
.. code-block:: python
978
1056
@@ -994,6 +1072,7 @@ def dist(
994
1072
lam,
995
1073
shift,
996
1074
dist=dist,
1075
+ signature="(),()->()",
997
1076
observed=[-1, -1, 0],
998
1077
)
999
1078
@@ -1040,10 +1119,11 @@ def __new__(
1040
1119
random : Optional [Callable ] = None ,
1041
1120
logp : Optional [Callable ] = None ,
1042
1121
logcdf : Optional [Callable ] = None ,
1043
- moment : Optional [Callable ] = None ,
1044
1122
support_point : Optional [Callable ] = None ,
1045
- ndim_supp : int = 0 ,
1123
+ # TODO: Deprecate ndim_supp / ndims_params in favor of signature?
1124
+ ndim_supp : Optional [int ] = None ,
1046
1125
ndims_params : Optional [Sequence [int ]] = None ,
1126
+ signature : Optional [str ] = None ,
1047
1127
dtype : str = "floatX" ,
1048
1128
** kwargs ,
1049
1129
):
@@ -1057,6 +1137,7 @@ def __new__(
1057
1137
)
1058
1138
dist_params = cls .parse_dist_params (dist_params )
1059
1139
cls .check_valid_dist_random (dist , random , dist_params )
1140
+ moment = kwargs .pop ("moment" , None )
1060
1141
if moment is not None :
1061
1142
warnings .warn (
1062
1143
"`moment` argument is deprecated. Use `support_point` instead." ,
@@ -1073,6 +1154,8 @@ def __new__(
1073
1154
logcdf = logcdf ,
1074
1155
support_point = support_point ,
1075
1156
ndim_supp = ndim_supp ,
1157
+ ndims_params = ndims_params ,
1158
+ signature = signature ,
1076
1159
** kwargs ,
1077
1160
)
1078
1161
else :
@@ -1086,6 +1169,7 @@ def __new__(
1086
1169
support_point = support_point ,
1087
1170
ndim_supp = ndim_supp ,
1088
1171
ndims_params = ndims_params ,
1172
+ signature = signature ,
1089
1173
dtype = dtype ,
1090
1174
** kwargs ,
1091
1175
)
@@ -1099,8 +1183,9 @@ def dist(
1099
1183
logp : Optional [Callable ] = None ,
1100
1184
logcdf : Optional [Callable ] = None ,
1101
1185
support_point : Optional [Callable ] = None ,
1102
- ndim_supp : int = 0 ,
1186
+ ndim_supp : Optional [ int ] = None ,
1103
1187
ndims_params : Optional [Sequence [int ]] = None ,
1188
+ signature : Optional [str ] = None ,
1104
1189
dtype : str = "floatX" ,
1105
1190
** kwargs ,
1106
1191
):
@@ -1114,6 +1199,8 @@ def dist(
1114
1199
logcdf = logcdf ,
1115
1200
support_point = support_point ,
1116
1201
ndim_supp = ndim_supp ,
1202
+ ndims_params = ndims_params ,
1203
+ signature = signature ,
1117
1204
** kwargs ,
1118
1205
)
1119
1206
else :
@@ -1125,6 +1212,7 @@ def dist(
1125
1212
support_point = support_point ,
1126
1213
ndim_supp = ndim_supp ,
1127
1214
ndims_params = ndims_params ,
1215
+ signature = signature ,
1128
1216
dtype = dtype ,
1129
1217
** kwargs ,
1130
1218
)
0 commit comments