Skip to content

Commit 04fe3cd

Browse files
committed
Remove MeasurableOpMixin
1 parent e415124 commit 04fe3cd

File tree

7 files changed

+21
-27
lines changed

7 files changed

+21
-27
lines changed

pymc/logprob/abstract.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __getattr__(name):
5353
f"{name} has been deprecated in favor of MeasurableOp. Importing will fail in a future release.",
5454
FutureWarning,
5555
)
56-
return MeasurableOpMixin
56+
return MeasurableOp
5757

5858
raise AttributeError(f"module {__name__} has no attribute {name}")
5959

@@ -150,14 +150,7 @@ class MeasurableOp(abc.ABC):
150150
MeasurableOp.register(RandomVariable)
151151

152152

153-
class MeasurableOpMixin(MeasurableOp):
154-
"""MeasurableOp Mixin with a distinctive string representation"""
155-
156-
def __str__(self):
157-
return f"Measurable{super().__str__()}"
158-
159-
160-
class MeasurableElemwise(MeasurableOpMixin, Elemwise):
153+
class MeasurableElemwise(MeasurableOp, Elemwise):
161154
"""Base class for Measurable Elemwise variables"""
162155

163156
valid_scalar_types: tuple[MetaType, ...] = ()
@@ -169,3 +162,6 @@ def __init__(self, scalar_op, *args, **kwargs):
169162
f"Acceptable types are {self.valid_scalar_types}"
170163
)
171164
super().__init__(scalar_op, *args, **kwargs)
165+
166+
def __str__(self):
167+
return f"Measurable{super().__str__()}"

pymc/logprob/checks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@
4242
from pytensor.tensor import TensorVariable
4343
from pytensor.tensor.shape import SpecifyShape
4444

45-
from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob, _logprob_helper
45+
from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
4646
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
4747
from pymc.logprob.utils import replace_rvs_by_values
4848

4949

50-
class MeasurableSpecifyShape(MeasurableOpMixin, SpecifyShape):
50+
class MeasurableSpecifyShape(MeasurableOp, SpecifyShape):
5151
"""A placeholder used to specify a log-likelihood for a specify-shape sub-graph."""
5252

5353

@@ -96,7 +96,7 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None:
9696
)
9797

9898

99-
class MeasurableCheckAndRaise(MeasurableOpMixin, CheckAndRaise):
99+
class MeasurableCheckAndRaise(MeasurableOp, CheckAndRaise):
100100
"""A placeholder used to specify a log-likelihood for an assert sub-graph."""
101101

102102

pymc/logprob/cumsum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@
4141
from pytensor.tensor import TensorVariable
4242
from pytensor.tensor.extra_ops import CumOp
4343

44-
from pymc.logprob.abstract import MeasurableOpMixin, _logprob, _logprob_helper
44+
from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
4545
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
4646

4747

48-
class MeasurableCumsum(MeasurableOpMixin, CumOp):
48+
class MeasurableCumsum(MeasurableOp, CumOp):
4949
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""
5050

5151

pymc/logprob/mixture.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
6868
from pymc.logprob.abstract import (
6969
MeasurableElemwise,
7070
MeasurableOp,
71-
MeasurableOpMixin,
7271
_logprob,
7372
_logprob_helper,
7473
)
@@ -218,7 +217,7 @@ def rv_pull_down(x: TensorVariable) -> TensorVariable:
218217
return fgraph.outputs[0]
219218

220219

221-
class MixtureRV(MeasurableOpMixin, Op):
220+
class MixtureRV(MeasurableOp, Op):
222221
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""
223222

224223
__props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")
@@ -455,7 +454,7 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa
455454
)
456455

457456

458-
class MeasurableIfElse(MeasurableOpMixin, IfElse):
457+
class MeasurableIfElse(MeasurableOp, IfElse):
459458
"""Measurable subclass of IfElse operator."""
460459

461460

pymc/logprob/order.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from pymc.logprob.abstract import (
4747
MeasurableElemwise,
4848
MeasurableOp,
49-
MeasurableOpMixin,
5049
_logcdf_helper,
5150
_logprob,
5251
_logprob_helper,
@@ -56,11 +55,11 @@
5655
from pymc.pytensorf import constant_fold
5756

5857

59-
class MeasurableMax(MeasurableOpMixin, Max):
58+
class MeasurableMax(MeasurableOp, Max):
6059
"""A placeholder used to specify a log-likelihood for a max sub-graph."""
6160

6261

63-
class MeasurableMaxDiscrete(MeasurableOpMixin, Max):
62+
class MeasurableMaxDiscrete(MeasurableOp, Max):
6463
"""A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""
6564

6665

pymc/logprob/scan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from pytensor.tensor.variable import TensorVariable
5555
from pytensor.updates import OrderedUpdates
5656

57-
from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob
57+
from pymc.logprob.abstract import MeasurableOp, _logprob
5858
from pymc.logprob.basic import conditional_logp
5959
from pymc.logprob.rewriting import (
6060
PreserveRVMappings,
@@ -66,11 +66,11 @@
6666
from pymc.logprob.utils import replace_rvs_by_values
6767

6868

69-
class MeasurableScan(MeasurableOpMixin, Scan):
69+
class MeasurableScan(MeasurableOp, Scan):
7070
"""A placeholder used to specify a log-likelihood for a scan sub-graph."""
7171

7272
def __str__(self):
73-
return f"Measurable({super().__str__()})"
73+
return f"Measurable{super().__str__()}"
7474

7575

7676
def convert_outer_out_to_in(

pymc/logprob/tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
local_rv_size_lift,
5353
)
5454

55-
from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob, _logprob_helper
55+
from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
5656
from pymc.logprob.rewriting import (
5757
PreserveRVMappings,
5858
assume_measured_ir_outputs,
@@ -124,7 +124,7 @@ def naive_bcast_rv_lift(fgraph: FunctionGraph, node):
124124
return [bcasted_node.outputs[1]]
125125

126126

127-
class MeasurableMakeVector(MeasurableOpMixin, MakeVector):
127+
class MeasurableMakeVector(MeasurableOp, MakeVector):
128128
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""
129129

130130

@@ -148,7 +148,7 @@ def logprob_make_vector(op, values, *base_rvs, **kwargs):
148148
return pt.stack(logps)
149149

150150

151-
class MeasurableJoin(MeasurableOpMixin, Join):
151+
class MeasurableJoin(MeasurableOp, Join):
152152
"""A placeholder used to specify a log-likelihood for a join sub-graph."""
153153

154154

@@ -228,7 +228,7 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
228228
return [measurable_stack]
229229

230230

231-
class MeasurableDimShuffle(MeasurableOpMixin, DimShuffle):
231+
class MeasurableDimShuffle(MeasurableOp, DimShuffle):
232232
"""A placeholder used to specify a log-likelihood for a dimshuffle sub-graph."""
233233

234234
# Need to get the absolute path of `c_func_file`, otherwise it tries to

0 commit comments

Comments
 (0)