49
49
from pytensor .tensor .random .op import RandomVariable
50
50
from pytensor .tensor .variable import TensorVariable
51
51
52
+ import pymc as pm
53
+
52
54
from pymc .logprob .abstract import (
53
55
MeasurableVariable ,
54
56
_logcdf_helper ,
@@ -67,7 +69,7 @@ class MeasurableMax(Max):
67
69
68
70
69
71
class MeasurableMaxDiscrete (Max ):
70
- """A placeholder used to specify a log-likelihood for a cmax sub-graph. """
72
+ """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables."""
71
73
72
74
73
75
MeasurableVariable .register (MeasurableMaxDiscrete )
@@ -105,14 +107,14 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
105
107
if axis != base_var_dims :
106
108
return None
107
109
108
- # logprob for discrete distribution
109
- if isinstance (base_var .owner .op , RandomVariable ) and base_var .owner .op .dtype .startswith ("int" ):
110
- measurable_max = MeasurableMaxDiscrete (list (axis ))
111
- max_rv_node = measurable_max .make_node (base_var )
112
- max_rv = max_rv_node .outputs
110
+ # distinguish measurable discrete and continuous (because logprob is different)
111
+ if base_var .owner .op .dtype .startswith ("int" ):
112
+ if isinstance (base_var .owner .op , RandomVariable ):
113
+ measurable_max = MeasurableMaxDiscrete (list (axis ))
114
+ max_rv_node = measurable_max .make_node (base_var )
115
+ max_rv = max_rv_node .outputs
113
116
114
- return max_rv
115
- # logprob for continuous distribution
117
+ return max_rv
116
118
else :
117
119
measurable_max = MeasurableMax (list (axis ))
118
120
max_rv_node = measurable_max .make_node (base_var )
@@ -148,17 +150,17 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
148
150
r"""Compute the log-likelihood graph for the `Max` operation.
149
151
150
152
The formula that we use here is :
151
- \ln(f_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
152
- where f(x) represents the p.d.f and F(x) represents the c.d.f of the distrivution respectively.
153
+ .. math::
154
+ \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
155
+ where :math:`P_{(n)}(x)` represents the p.m.f. of the maximum statistic and :math:`F(x)` represents the c.d.f. of the i.i.d. variables.
153
156
"""
154
157
(value ,) = values
155
- logprob = _logprob_helper (base_rv , value )
156
158
logcdf = _logcdf_helper (base_rv , value )
157
159
logcdf_prev = _logcdf_helper (base_rv , value - 1 )
158
160
159
- n = base_rv .size
161
+ [ n ] = constant_fold ([ base_rv .size ])
160
162
161
- logprob = pt . log (( pt . exp ( logcdf )) ** n - ( pt . exp ( logcdf_prev )) ** n )
163
+ logprob = pm . math . logdiffexp ( n * logcdf , n * logcdf_prev )
162
164
163
165
return logprob
164
166
0 commit comments