Skip to content

Commit 43dc4be

Browse files
committed
Rename MeasurableVariable to MeasurableOp
Also: * Introduce MeasurableOpMixin for string representation * Subclass directly instead of registering manually
1 parent a3e2261 commit 43dc4be

19 files changed

+78
-121
lines changed

pymc/distributions/distribution.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
rv_size_is_none,
5151
shape_from_dims,
5252
)
53-
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
53+
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
5454
from pymc.logprob.basic import logp
5555
from pymc.logprob.rewriting import logprob_rewrites_db
5656
from pymc.printing import str_for_dist
@@ -228,7 +228,7 @@ def __get__(self, instance, type_):
228228
return descr_get(instance, type_)
229229

230230

231-
class SymbolicRandomVariable(OpFromGraph):
231+
class SymbolicRandomVariable(MeasurableOp, OpFromGraph):
232232
"""Symbolic Random Variable
233233
234234
This is a subclasse of `OpFromGraph` which is used to encapsulate the symbolic
@@ -624,10 +624,6 @@ def dist(
624624
return rv_out
625625

626626

627-
# Let PyMC know that the SymbolicRandomVariable has a logprob.
628-
MeasurableVariable.register(SymbolicRandomVariable)
629-
630-
631627
@node_rewriter([SymbolicRandomVariable])
632628
def inline_symbolic_random_variable(fgraph, node):
633629
"""

pymc/logprob/abstract.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
# SOFTWARE.
3636

3737
import abc
38+
import warnings
3839

3940
from collections.abc import Sequence
4041
from functools import singledispatch
@@ -46,6 +47,17 @@
4647
from pytensor.tensor.random.op import RandomVariable
4748

4849

50+
def __getattr__(name):
51+
if name == "MeasurableVariable":
52+
warnings.warn(
53+
f"{name} has been deprecated in favor of MeasurableOp. Importing will fail in a future release.",
54+
FutureWarning,
55+
)
56+
return MeasurableOpMixin
57+
58+
raise AttributeError(f"module {__name__} has no attribute {name}")
59+
60+
4961
@singledispatch
5062
def _logprob(
5163
op: Op,
@@ -131,14 +143,21 @@ def _icdf_helper(rv, value, **kwargs):
131143
return rv_icdf
132144

133145

134-
class MeasurableVariable(abc.ABC):
135-
"""A variable that can be assigned a measure/log-probability"""
146+
class MeasurableOp(abc.ABC):
147+
"""An operation whose outputs can be assigned a measure/log-probability"""
148+
136149

150+
MeasurableOp.register(RandomVariable)
137151

138-
MeasurableVariable.register(RandomVariable)
139152

153+
class MeasurableOpMixin(MeasurableOp):
154+
"""MeasurableOp Mixin with a distinctive string representation"""
140155

141-
class MeasurableElemwise(Elemwise):
156+
def __str__(self):
157+
return f"Measurable{super().__str__()}"
158+
159+
160+
class MeasurableElemwise(MeasurableOpMixin, Elemwise):
142161
"""Base class for Measurable Elemwise variables"""
143162

144163
valid_scalar_types: tuple[MetaType, ...] = ()
@@ -150,6 +169,3 @@ def __init__(self, scalar_op, *args, **kwargs):
150169
f"Acceptable types are {self.valid_scalar_types}"
151170
)
152171
super().__init__(scalar_op, *args, **kwargs)
153-
154-
155-
MeasurableVariable.register(MeasurableElemwise)

pymc/logprob/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from pytensor.tensor.variable import TensorVariable
5757

5858
from pymc.logprob.abstract import (
59-
MeasurableVariable,
59+
MeasurableOp,
6060
_icdf_helper,
6161
_logcdf_helper,
6262
_logprob,
@@ -522,7 +522,7 @@ def conditional_logp(
522522
while q:
523523
node = q.popleft()
524524

525-
if not isinstance(node.op, MeasurableVariable):
525+
if not isinstance(node.op, MeasurableOp):
526526
continue
527527

528528
q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values]

pymc/logprob/checks.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,15 @@
4242
from pytensor.tensor import TensorVariable
4343
from pytensor.tensor.shape import SpecifyShape
4444

45-
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
45+
from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob, _logprob_helper
4646
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
4747
from pymc.logprob.utils import replace_rvs_by_values
4848

4949

50-
class MeasurableSpecifyShape(SpecifyShape):
50+
class MeasurableSpecifyShape(MeasurableOpMixin, SpecifyShape):
5151
"""A placeholder used to specify a log-likelihood for a specify-shape sub-graph."""
5252

5353

54-
MeasurableVariable.register(MeasurableSpecifyShape)
55-
56-
5754
@_logprob.register(MeasurableSpecifyShape)
5855
def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):
5956
(value,) = values
@@ -80,7 +77,7 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None:
8077

8178
if not (
8279
base_rv.owner
83-
and isinstance(base_rv.owner.op, MeasurableVariable)
80+
and isinstance(base_rv.owner.op, MeasurableOp)
8481
and base_rv not in rv_map_feature.rv_values
8582
):
8683
return None # pragma: no cover
@@ -99,13 +96,10 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None:
9996
)
10097

10198

102-
class MeasurableCheckAndRaise(CheckAndRaise):
99+
class MeasurableCheckAndRaise(MeasurableOpMixin, CheckAndRaise):
103100
"""A placeholder used to specify a log-likelihood for an assert sub-graph."""
104101

105102

106-
MeasurableVariable.register(MeasurableCheckAndRaise)
107-
108-
109103
@_logprob.register(MeasurableCheckAndRaise)
110104
def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
111105
(value,) = values

pymc/logprob/cumsum.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,14 @@
4141
from pytensor.tensor import TensorVariable
4242
from pytensor.tensor.extra_ops import CumOp
4343

44-
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
44+
from pymc.logprob.abstract import MeasurableOpMixin, _logprob, _logprob_helper
4545
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
4646

4747

48-
class MeasurableCumsum(CumOp):
48+
class MeasurableCumsum(MeasurableOpMixin, CumOp):
4949
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""
5050

5151

52-
MeasurableVariable.register(MeasurableCumsum)
53-
54-
5552
@_logprob.register(MeasurableCumsum)
5653
def logprob_cumsum(op, values, base_rv, **kwargs):
5754
"""Compute the log-likelihood graph for a `Cumsum`."""

pymc/logprob/mixture.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@
6767

6868
from pymc.logprob.abstract import (
6969
MeasurableElemwise,
70-
MeasurableVariable,
70+
MeasurableOp,
71+
MeasurableOpMixin,
7172
_logprob,
7273
_logprob_helper,
7374
)
@@ -217,7 +218,7 @@ def rv_pull_down(x: TensorVariable) -> TensorVariable:
217218
return fgraph.outputs[0]
218219

219220

220-
class MixtureRV(Op):
221+
class MixtureRV(MeasurableOpMixin, Op):
221222
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""
222223

223224
__props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")
@@ -235,9 +236,6 @@ def perform(self, node, inputs, outputs):
235236
raise NotImplementedError("This is a stand-in Op.") # pragma: no cover
236237

237238

238-
MeasurableVariable.register(MixtureRV)
239-
240-
241239
def get_stack_mixture_vars(
242240
node: Apply,
243241
) -> tuple[list[TensorVariable] | None, int | None]:
@@ -457,13 +455,10 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa
457455
)
458456

459457

460-
class MeasurableIfElse(IfElse):
458+
class MeasurableIfElse(MeasurableOpMixin, IfElse):
461459
"""Measurable subclass of IfElse operator."""
462460

463461

464-
MeasurableVariable.register(MeasurableIfElse)
465-
466-
467462
@node_rewriter([IfElse])
468463
def useless_ifelse_outputs(fgraph, node):
469464
"""Remove outputs that are shared across the IfElse branches."""
@@ -512,7 +507,7 @@ def find_measurable_ifelse_mixture(fgraph, node):
512507
base_rvs = assume_measured_ir_outputs(valued_rvs, base_rvs)
513508
if len(base_rvs) != op.n_outs * 2:
514509
return None
515-
if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs):
510+
if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_rvs):
516511
return None
517512

518513
return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs

pymc/logprob/order.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from pytensor.tensor.variable import TensorVariable
4949

5050
from pymc.logprob.abstract import (
51-
MeasurableVariable,
51+
MeasurableOpMixin,
5252
_logcdf_helper,
5353
_logprob,
5454
_logprob_helper,
@@ -59,20 +59,14 @@
5959
from pymc.pytensorf import constant_fold
6060

6161

62-
class MeasurableMax(Max):
62+
class MeasurableMax(MeasurableOpMixin, Max):
6363
"""A placeholder used to specify a log-likelihood for a max sub-graph."""
6464

6565

66-
MeasurableVariable.register(MeasurableMax)
67-
68-
69-
class MeasurableMaxDiscrete(Max):
66+
class MeasurableMaxDiscrete(MeasurableOpMixin, Max):
7067
"""A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""
7168

7269

73-
MeasurableVariable.register(MeasurableMaxDiscrete)
74-
75-
7670
@node_rewriter([Max])
7771
def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
7872
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -162,21 +156,15 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
162156
return logprob
163157

164158

165-
class MeasurableMaxNeg(Max):
159+
class MeasurableMaxNeg(MeasurableOpMixin, Max):
166160
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
167161
This shows up in the graph of min, which is (neg(max(neg(x)))."""
168162

169163

170-
MeasurableVariable.register(MeasurableMaxNeg)
171-
172-
173-
class MeasurableDiscreteMaxNeg(Max):
164+
class MeasurableDiscreteMaxNeg(MeasurableOpMixin, Max):
174165
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""
175166

176167

177-
MeasurableVariable.register(MeasurableDiscreteMaxNeg)
178-
179-
180168
@node_rewriter(tracks=[Max])
181169
def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
182170
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

pymc/logprob/rewriting.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
)
8282
from pytensor.tensor.variable import TensorVariable
8383

84-
from pymc.logprob.abstract import MeasurableVariable
84+
from pymc.logprob.abstract import MeasurableOp
8585
from pymc.logprob.utils import DiracDelta
8686

8787
inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
@@ -139,7 +139,7 @@ def apply(self, fgraph):
139139
continue
140140
# This is where we filter only those nodes we care about:
141141
# Nodes that have variables that we want to measure and are not yet measurable
142-
if isinstance(node.op, MeasurableVariable):
142+
if isinstance(node.op, MeasurableOp):
143143
continue
144144
if not any(out in rv_map_feature.needs_measuring for out in node.outputs):
145145
continue
@@ -155,7 +155,7 @@ def apply(self, fgraph):
155155
node_rewriter, "__name__", ""
156156
)
157157
# If we converted to a MeasurableVariable we're done here!
158-
if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableVariable):
158+
if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableOp):
159159
# go to next node
160160
break
161161

@@ -274,7 +274,7 @@ def request_measurable(self, vars: Sequence[Variable]) -> list[Variable]:
274274
# Input vars or valued vars can't be measured for derived expressions
275275
if not var.owner or var in self.rv_values:
276276
continue
277-
if isinstance(var.owner.op, MeasurableVariable):
277+
if isinstance(var.owner.op, MeasurableOp):
278278
measurable.append(var)
279279
else:
280280
self.needs_measuring.add(var)

pymc/logprob/scan.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from pytensor.tensor.variable import TensorVariable
5555
from pytensor.updates import OrderedUpdates
5656

57-
from pymc.logprob.abstract import MeasurableVariable, _logprob
57+
from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob
5858
from pymc.logprob.basic import conditional_logp
5959
from pymc.logprob.rewriting import (
6060
PreserveRVMappings,
@@ -66,16 +66,13 @@
6666
from pymc.logprob.utils import replace_rvs_by_values
6767

6868

69-
class MeasurableScan(Scan):
69+
class MeasurableScan(MeasurableOpMixin, Scan):
7070
"""A placeholder used to specify a log-likelihood for a scan sub-graph."""
7171

7272
def __str__(self):
7373
return f"Measurable({super().__str__()})"
7474

7575

76-
MeasurableVariable.register(MeasurableScan)
77-
78-
7976
def convert_outer_out_to_in(
8077
input_scan_args: ScanArgs,
8178
outer_out_vars: Iterable[TensorVariable],
@@ -288,7 +285,7 @@ def get_random_outer_outputs(
288285
io_type = oo_info.name[(oo_info.name.index("_", 6) + 1) :]
289286
inner_out_type = f"inner_out_{io_type}"
290287
io_var = getattr(scan_args, inner_out_type)[oo_info.index]
291-
if io_var.owner and isinstance(io_var.owner.op, MeasurableVariable):
288+
if io_var.owner and isinstance(io_var.owner.op, MeasurableOp):
292289
rv_vars.append((n, oo_var, io_var))
293290
return rv_vars
294291

0 commit comments

Comments
 (0)