35
35
# SOFTWARE.
36
36
37
37
import abc
38
+ import warnings
38
39
39
40
from copy import copy
40
41
from typing import Callable , Dict , List , Optional , Sequence , Tuple , Union
111
112
cleanup_ir_rewrites_db ,
112
113
measurable_ir_rewrites_db ,
113
114
)
115
+ from pymc .logprob .shape import measurable_broadcast
114
116
from pymc .logprob .utils import check_potential_measurability
115
117
116
118
@@ -564,10 +566,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
564
566
565
567
scalar_op = node .op .scalar_op
566
568
measurable_input_idx = 0
569
+ measurable_input_broadcast = (
570
+ measurable_input .type .broadcastable != node .default_output ().type .broadcastable
571
+ )
567
572
transform_inputs : Tuple [TensorVariable , ...] = (measurable_input ,)
568
573
transform : RVTransform
569
574
570
- transform_dict = {
575
+ unary_transforms_dict = {
571
576
Exp : ExpTransform (),
572
577
Log : LogTransform (),
573
578
Abs : AbsTransform (),
@@ -581,29 +586,49 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
581
586
Erfc : ErfcTransform (),
582
587
Erfcx : ErfcxTransform (),
583
588
}
584
- transform = transform_dict .get (type (scalar_op ), None )
585
- if isinstance (scalar_op , Pow ):
586
- # We only allow for the base to be measurable
587
- if measurable_input_idx != 0 :
588
- return None
589
- try :
590
- (power ,) = other_inputs
591
- power = pt .get_underlying_scalar_constant_value (power ).item ()
592
- # Power needs to be a constant
593
- except NotScalarConstantError :
589
+ transform = unary_transforms_dict .get (type (scalar_op ), None )
590
+ if transform is None :
591
+ if isinstance (scalar_op , Pow ):
592
+ # We only allow for the base to be measurable
593
+ if measurable_input_idx != 0 :
594
+ return None
595
+ try :
596
+ (power ,) = other_inputs
597
+ base_power = pt .get_underlying_scalar_constant_value (power ).item ()
598
+ # Power needs to be a constant
599
+ except NotScalarConstantError :
600
+ return None
601
+ transform_inputs = (measurable_input , power )
602
+ transform = PowerTransform (power = base_power )
603
+ elif isinstance (scalar_op , Add ):
604
+ transform_inputs = (measurable_input , pt .add (* other_inputs ))
605
+ transform = LocTransform (
606
+ transform_args_fn = lambda * inputs : inputs [- 1 ],
607
+ )
608
+ elif isinstance (scalar_op , Mul ):
609
+ transform_inputs = (measurable_input , pt .mul (* other_inputs ))
610
+ transform = ScaleTransform (
611
+ transform_args_fn = lambda * inputs : inputs [- 1 ],
612
+ )
613
+ else :
614
+ raise TypeError (
615
+ f"Scalar Op not supported: { scalar_op } . Rewrite should not have been triggered"
616
+ ) # pragma: no cover
617
+
618
+ if measurable_input_broadcast :
619
+ # This rewrite logic only supports broadcasting for transforms with two inputs, where the first is measurable.
620
+ # This covers all current cases, update if other cases are supported in the future.
621
+ if len (transform_inputs ) != 2 or measurable_input_idx != 0 :
594
622
return None
595
- transform_inputs = (measurable_input , power )
596
- transform = PowerTransform (power = power )
597
- elif isinstance (scalar_op , Add ):
598
- transform_inputs = (measurable_input , pt .add (* other_inputs ))
599
- transform = LocTransform (
600
- transform_args_fn = lambda * inputs : inputs [- 1 ],
601
- )
602
- elif transform is None :
603
- transform_inputs = (measurable_input , pt .mul (* other_inputs ))
604
- transform = ScaleTransform (
605
- transform_args_fn = lambda * inputs : inputs [- 1 ],
623
+ warnings .warn (
624
+ "MeasurableTransform with implicit broadcasting detected. This corresponds to a potentially degenerate probability graph.\n "
625
+ "If you did not intend this, make sure the base measurable variable is created with all the dimensions from the start."
626
+ "Otherwise, an explicit `broadcast_to` operation can be used to silence this warning.\n " ,
627
+ UserWarning ,
606
628
)
629
+ measurable_input , other_input = transform_inputs
630
+ measurable_input = measurable_broadcast (measurable_input , other_input .shape )
631
+ transform_inputs = (measurable_input , other_input )
607
632
608
633
transform_op = MeasurableTransform (
609
634
scalar_op = scalar_op ,
0 commit comments