@@ -463,21 +463,6 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node]
463
463
transform_inputs : tuple [TensorVariable , ...] = (measurable_input ,)
464
464
transform : Transform
465
465
466
- transform_dict = {
467
- Exp : ExpTransform (),
468
- Log : LogTransform (),
469
- Abs : AbsTransform (),
470
- Sinh : SinhTransform (),
471
- Cosh : CoshTransform (),
472
- Tanh : TanhTransform (),
473
- ArcSinh : ArcsinhTransform (),
474
- ArcCosh : ArccoshTransform (),
475
- ArcTanh : ArctanhTransform (),
476
- Erf : ErfTransform (),
477
- Erfc : ErfcTransform (),
478
- Erfcx : ErfcxTransform (),
479
- }
480
- transform = transform_dict .get (type (scalar_op ), None )
481
466
if isinstance (scalar_op , Pow ):
482
467
# We only allow for the base to be measurable
483
468
if measurable_input_idx != 0 :
@@ -495,11 +480,27 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node]
495
480
transform = LocTransform (
496
481
transform_args_fn = lambda * inputs : inputs [- 1 ],
497
482
)
498
- elif transform is None :
483
+ elif isinstance ( scalar_op , Mul ) :
499
484
transform_inputs = (measurable_input , pt .mul (* other_inputs ))
500
485
transform = ScaleTransform (
501
486
transform_args_fn = lambda * inputs : inputs [- 1 ],
502
487
)
488
+ else :
489
+ transform = {
490
+ Exp : ExpTransform ,
491
+ Log : LogTransform ,
492
+ Abs : AbsTransform ,
493
+ Sinh : SinhTransform ,
494
+ Cosh : CoshTransform ,
495
+ Tanh : TanhTransform ,
496
+ ArcSinh : ArcsinhTransform ,
497
+ ArcCosh : ArccoshTransform ,
498
+ ArcTanh : ArctanhTransform ,
499
+ Erf : ErfTransform ,
500
+ Erfc : ErfcTransform ,
501
+ Erfcx : ErfcxTransform ,
502
+ }[type (scalar_op )]()
503
+
503
504
transform_op = MeasurableTransform (
504
505
scalar_op = scalar_op ,
505
506
transform = transform ,
0 commit comments