Skip to content

Commit a3e2261

Browse files
committed
Do not instantiate all transforms in find_measurable_transforms
1 parent d4714d2 commit a3e2261

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

pymc/logprob/transforms.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -463,21 +463,6 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node]
463463
transform_inputs: tuple[TensorVariable, ...] = (measurable_input,)
464464
transform: Transform
465465

466-
transform_dict = {
467-
Exp: ExpTransform(),
468-
Log: LogTransform(),
469-
Abs: AbsTransform(),
470-
Sinh: SinhTransform(),
471-
Cosh: CoshTransform(),
472-
Tanh: TanhTransform(),
473-
ArcSinh: ArcsinhTransform(),
474-
ArcCosh: ArccoshTransform(),
475-
ArcTanh: ArctanhTransform(),
476-
Erf: ErfTransform(),
477-
Erfc: ErfcTransform(),
478-
Erfcx: ErfcxTransform(),
479-
}
480-
transform = transform_dict.get(type(scalar_op), None)
481466
if isinstance(scalar_op, Pow):
482467
# We only allow for the base to be measurable
483468
if measurable_input_idx != 0:
@@ -495,11 +480,27 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node]
495480
transform = LocTransform(
496481
transform_args_fn=lambda *inputs: inputs[-1],
497482
)
498-
elif transform is None:
483+
elif isinstance(scalar_op, Mul):
499484
transform_inputs = (measurable_input, pt.mul(*other_inputs))
500485
transform = ScaleTransform(
501486
transform_args_fn=lambda *inputs: inputs[-1],
502487
)
488+
else:
489+
transform = {
490+
Exp: ExpTransform,
491+
Log: LogTransform,
492+
Abs: AbsTransform,
493+
Sinh: SinhTransform,
494+
Cosh: CoshTransform,
495+
Tanh: TanhTransform,
496+
ArcSinh: ArcsinhTransform,
497+
ArcCosh: ArccoshTransform,
498+
ArcTanh: ArctanhTransform,
499+
Erf: ErfTransform,
500+
Erfc: ErfcTransform,
501+
Erfcx: ErfcxTransform,
502+
}[type(scalar_op)]()
503+
503504
transform_op = MeasurableTransform(
504505
scalar_op=scalar_op,
505506
transform=transform,

0 commit comments

Comments
 (0)