
Commit 7c32d3c

Remove default transformed RV Ops
This functionality was not being used by PyMC. In theory, it could be used to cache one type of transformed RV Op instead of recreating it for every instance of the same RV + transform, resulting in a faster logp rewrite (?). We can always revisit this idea later, but for now it unnecessarily complicates our codebase.
1 parent acb1a54 commit 7c32d3c
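
For orientation, here is a minimal sketch of the caller-facing change, pieced together from the test updates in this commit; the half-normal variable is an arbitrary example and not part of the diff. Instead of passing the removed DEFAULT_TRANSFORM sentinel, callers now look up a concrete RVTransform themselves (e.g. via _default_transform) and put it in the value-variable mapping.

import pytensor.tensor as at

from pymc.distributions.transforms import _default_transform
from pymc.logprob.joint_logprob import joint_logprob
from pymc.logprob.transforms import TransformValuesRewrite

x_rv = at.random.halfnormal(name="x")
x_vv = x_rv.clone()

# Previously: TransformValuesRewrite({x_vv: DEFAULT_TRANSFORM})
# Now the caller resolves the default transform explicitly and passes the RVTransform:
transform_rewrite = TransformValuesRewrite({x_vv: _default_transform(x_rv.owner.op, x_rv)})
logp = joint_logprob({x_rv: x_vv}, extra_rewrites=transform_rewrite)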

File tree

2 files changed: +24 additions, -157 deletions

pymc/logprob/transforms.py
pymc/tests/logprob/test_transforms.py

pymc/logprob/transforms.py

Lines changed: 13 additions & 124 deletions

@@ -37,7 +37,6 @@
 import abc
 
 from copy import copy
-from functools import partial, singledispatch
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import pytensor.tensor as at
@@ -69,21 +68,6 @@
 from pymc.logprob.utils import walk_model
 
 
-@singledispatch
-def _default_transformed_rv(
-    op: Op,
-    node: Node,
-) -> Optional[Apply]:
-    """Create a node for a transformed log-probability of a `MeasurableVariable`.
-
-    This function dispatches on the type of `op`. If you want to implement
-    new transforms for a `MeasurableVariable`, register a function on this
-    dispatcher.
-
-    """
-    return None
-
-
 class TransformedVariable(Op):
     """A no-op that identifies a transform and its un-transformed input."""
 
@@ -136,13 +120,6 @@ def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
         return at.log(at.abs(at.nlinalg.det(at.atleast_2d(jacobian(phi_inv, [value])[0]))))
 
 
-class DefaultTransformSentinel:
-    pass
-
-
-DEFAULT_TRANSFORM = DefaultTransformSentinel()
-
-
 @node_rewriter(tracks=None)
 def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
     """Apply transforms to value variables.
@@ -176,17 +153,12 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
 
     if transform is None:
         return None
-    elif transform is DEFAULT_TRANSFORM:
-        trans_node = _default_transformed_rv(node.op, node)
-        if trans_node is None:
-            return None
-        transform = trans_node.op.transform
-    else:
-        new_op = _create_transformed_rv_op(node.op, transform)
-        # Create a new `Apply` node and outputs
-        trans_node = node.clone()
-        trans_node.op = new_op
-        trans_node.outputs[rv_var_out_idx].name = node.outputs[rv_var_out_idx].name
+
+    new_op = _create_transformed_rv_op(node.op, transform)
+    # Create a new `Apply` node and outputs
+    trans_node = node.clone()
+    trans_node.op = new_op
+    trans_node.outputs[rv_var_out_idx].name = node.outputs[rv_var_out_idx].name
 
     # We now assume that the old value variable represents the *transformed space*.
     # This means that we need to replace all instance of the old value variable
@@ -216,24 +188,22 @@ def on_attach(self, fgraph):
 
 
 class TransformValuesRewrite(GraphRewriter):
-    r"""Transforms value variables according to a map and/or per-`RandomVariable` defaults."""
+    r"""Transforms value variables according to a map."""
 
-    default_transform_rewrite = in2out(transform_values, ignore_newtrees=True)
+    transform_rewrite = in2out(transform_values, ignore_newtrees=True)
 
     def __init__(
         self,
-        values_to_transforms: Dict[
-            TensorVariable, Union[RVTransform, DefaultTransformSentinel, None]
-        ],
+        values_to_transforms: Dict[TensorVariable, Union[RVTransform, None]],
     ):
         """
         Parameters
        ==========
         values_to_transforms
             Mapping between value variables and their transformations. Each
-            value variable can be assigned one of `RVTransform`,
-            ``DEFAULT_TRANSFORM``, or ``None``. If a transform is not specified
-            for a specific value variable it will not be transformed.
+            value variable can be assigned one of `RVTransform`, or ``None``.
+            If a transform is not specified for a specific value variable it will
+            not be transformed.
 
         """
 
@@ -244,7 +214,7 @@ def add_requirements(self, fgraph):
         fgraph.attach_feature(values_transforms_feature)
 
     def apply(self, fgraph: FunctionGraph):
-        return self.default_transform_rewrite.rewrite(fgraph)
+        return self.transform_rewrite.rewrite(fgraph)
 
 
 class MeasurableTransform(MeasurableElemwise):
@@ -583,7 +553,6 @@ def _create_transformed_rv_op(
     rv_op: Op,
     transform: RVTransform,
     *,
-    default: bool = False,
     cls_dict_extra: Optional[Dict] = None,
 ) -> Op:
     """Create a new transformed variable instance given a base `RandomVariable` `Op`.
@@ -600,8 +569,6 @@
         The `RandomVariable` for which we want to construct a `TransformedRV`.
     transform
         The `RVTransform` for `rv_op`.
-    default
-        If ``False`` do not make `transform` the default transform for `rv_op`.
     cls_dict_extra
         Additional class members to add to the constructed `TransformedRV`.
 
@@ -642,85 +609,7 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
 
         return logprob
 
-    transform_op = rv_op_type if default else new_op_type
-
-    @_default_transformed_rv.register(transform_op)
-    def class_transformed_rv(op, node):
-        new_op = new_op_type()
-        res = new_op.make_node(*node.inputs)
-        res.outputs[1].name = node.outputs[1].name
-        return res
-
     new_op = copy(rv_op)
     new_op.__class__ = new_op_type
 
     return new_op
-
-
-create_default_transformed_rv_op = partial(_create_transformed_rv_op, default=True)
-
-
-TransformedUniformRV = create_default_transformed_rv_op(
-    at.random.uniform,
-    # inputs[3] = lower; inputs[4] = upper
-    IntervalTransform(lambda *inputs: (inputs[3], inputs[4])),
-)
-TransformedParetoRV = create_default_transformed_rv_op(
-    at.random.pareto,
-    # inputs[3] = alpha
-    IntervalTransform(lambda *inputs: (inputs[3], None)),
-)
-TransformedTriangularRV = create_default_transformed_rv_op(
-    at.random.triangular,
-    # inputs[3] = lower; inputs[5] = upper
-    IntervalTransform(lambda *inputs: (inputs[3], inputs[5])),
-)
-TransformedHalfNormalRV = create_default_transformed_rv_op(
-    at.random.halfnormal,
-    # inputs[3] = loc
-    IntervalTransform(lambda *inputs: (inputs[3], None)),
-)
-TransformedWaldRV = create_default_transformed_rv_op(
-    at.random.wald,
-    LogTransform(),
-)
-TransformedExponentialRV = create_default_transformed_rv_op(
-    at.random.exponential,
-    LogTransform(),
-)
-TransformedLognormalRV = create_default_transformed_rv_op(
-    at.random.lognormal,
-    LogTransform(),
-)
-TransformedHalfCauchyRV = create_default_transformed_rv_op(
-    at.random.halfcauchy,
-    LogTransform(),
-)
-TransformedGammaRV = create_default_transformed_rv_op(
-    at.random.gamma,
-    LogTransform(),
-)
-TransformedInvGammaRV = create_default_transformed_rv_op(
-    at.random.invgamma,
-    LogTransform(),
-)
-TransformedChiSquareRV = create_default_transformed_rv_op(
-    at.random.chisquare,
-    LogTransform(),
-)
-TransformedWeibullRV = create_default_transformed_rv_op(
-    at.random.weibull,
-    LogTransform(),
-)
-TransformedBetaRV = create_default_transformed_rv_op(
-    at.random.beta,
-    LogOddsTransform(),
-)
-TransformedVonMisesRV = create_default_transformed_rv_op(
-    at.random.vonmises,
-    CircularTransform(),
-)
-TransformedDirichletRV = create_default_transformed_rv_op(
-    at.random.dirichlet,
-    SimplexTransform(),
-)
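
For context on what was dropped above: the deleted _default_transformed_rv was a functools.singledispatch registry keyed on the RV's Op type, and create_default_transformed_rv_op registered a pre-built transformed Op (e.g. TransformedBetaRV) for each supported distribution, so the logp rewrite could reuse that Op instead of rebuilding it per node. A simplified, illustrative sketch of that dispatch pattern follows; the register_default_transformed_rv helper is hypothetical and not the removed implementation.

from functools import singledispatch


@singledispatch
def _default_transformed_rv(op, node):
    # Fallback: no default transformed Op is registered for this Op type.
    return None


def register_default_transformed_rv(rv_op_type, new_op_type):
    # Registering once per Op type lets the rewrite reuse a cached transformed Op
    # class instead of creating a new one for every matching node.
    @_default_transformed_rv.register(rv_op_type)
    def class_transformed_rv(op, node):
        return new_op_type().make_node(*node.inputs)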

pymc/tests/logprob/test_transforms.py

Lines changed: 11 additions & 33 deletions

@@ -45,9 +45,9 @@
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.fg import FunctionGraph
 
+from pymc.distributions.transforms import _default_transform, log, logodds
 from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logprob
 from pymc.logprob.transforms import (
-    DEFAULT_TRANSFORM,
     ChainedTransform,
     ExpTransform,
     IntervalTransform,
@@ -58,7 +58,6 @@
     ScaleTransform,
     TransformValuesMapping,
     TransformValuesRewrite,
-    _default_transformed_rv,
     transformed_variable,
 )
 from pymc.tests.helpers import assert_no_rvs
@@ -216,6 +215,8 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
     elsewhere in the graph (i.e. in ``b``), by comparing the graph for the
     transformed log-probability with the SciPy-derived log-probability--using a
     numeric approximation to the Jacobian term.
+
+    TODO: This test is rather redundant with those in tess/distributions/test_transform.py
     """
 
     a = at_dist(*dist_params, size=size)
@@ -228,16 +229,14 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
     b_value_var = b.clone()
     b_value_var.name = "b_value"
 
-    transform_rewrite = TransformValuesRewrite({a_value_var: DEFAULT_TRANSFORM})
+    transform = _default_transform(a.owner.op, a)
+    transform_rewrite = TransformValuesRewrite({a_value_var: transform})
     res = joint_logprob({a: a_value_var, b: b_value_var}, extra_rewrites=transform_rewrite)
 
     test_val_rng = np.random.RandomState(3238)
 
     logp_vals_fn = pytensor.function([a_value_var, b_value_var], res)
 
-    a_trans_op = _default_transformed_rv(a.owner.op, a.owner).op
-    transform = a_trans_op.transform
-
     a_forward_fn = pytensor.function([a_value_var], transform.forward(a_value_var, *a.owner.inputs))
     a_backward_fn = pytensor.function(
         [a_value_var], transform.backward(a_value_var, *a.owner.inputs)
@@ -305,7 +304,7 @@ def test_simple_transformed_logprob_nojac(use_jacobian):
     x_vv = X_rv.clone()
     x_vv.name = "x"
 
-    transform_rewrite = TransformValuesRewrite({x_vv: DEFAULT_TRANSFORM})
+    transform_rewrite = TransformValuesRewrite({x_vv: log})
     tr_logp = joint_logprob(
         {X_rv: x_vv}, extra_rewrites=transform_rewrite, use_jacobian=use_jacobian
     )
@@ -359,9 +358,9 @@ def test_hierarchical_uniform_transform():
 
     transform_rewrite = TransformValuesRewrite(
         {
-            lower: DEFAULT_TRANSFORM,
-            upper: DEFAULT_TRANSFORM,
-            x: DEFAULT_TRANSFORM,
+            lower: _default_transform(lower_rv.owner.op, lower_rv),
+            upper: _default_transform(upper_rv.owner.op, upper_rv),
+            x: _default_transform(x_rv.owner.op, x_rv),
         }
     )
     logp = joint_logprob(
@@ -425,28 +424,7 @@ def test_default_transform_multiout():
     x_rv = at.random.normal(0, sd, name="x")
     x = x_rv.clone()
 
-    transform_rewrite = TransformValuesRewrite({x: DEFAULT_TRANSFORM})
-
-    logp = joint_logprob(
-        {x_rv: x},
-        extra_rewrites=transform_rewrite,
-    )
-
-    assert np.isclose(
-        logp.eval({x: 1}),
-        sp.stats.norm(0, 1).logpdf(1),
-    )
-
-
-def test_nonexistent_default_transform():
-    """
-    Test that setting `DEFAULT_TRANSFORM` to a variable that has no default
-    transform does not fail
-    """
-    x_rv = at.random.normal(name="x")
-    x = x_rv.clone()
-
-    transform_rewrite = TransformValuesRewrite({x: DEFAULT_TRANSFORM})
+    transform_rewrite = TransformValuesRewrite({x: None})
 
     logp = joint_logprob(
         {x_rv: x},
@@ -480,7 +458,7 @@ def test_original_values_output_dict():
     p_rv = at.random.beta(1, 1, name="p")
     p_vv = p_rv.clone()
 
-    tr = TransformValuesRewrite({p_vv: DEFAULT_TRANSFORM})
+    tr = TransformValuesRewrite({p_vv: logodds})
     logp_dict = factorized_joint_logprob({p_rv: p_vv}, extra_rewrites=tr)
 
     assert p_vv in logp_dict
