@@ -1048,7 +1048,11 @@ def change_custom_dist_size(op, rv, new_size, expand):
1048
1048
1049
1049
return new_rv
1050
1050
1051
- rngs , rngs_updates = zip (* dummy_updates_dict .items ())
1051
+ if dummy_updates_dict :
1052
+ rngs , rngs_updates = zip (* dummy_updates_dict .items ())
1053
+ else :
1054
+ rngs , rngs_updates = (), ()
1055
+
1052
1056
inputs = [* dummy_params , * rngs ]
1053
1057
outputs = [dummy_rv , * rngs_updates ]
1054
1058
signature = cls ._infer_final_signature (
@@ -1497,19 +1501,26 @@ def default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False
1497
1501
)
1498
1502
1499
1503
1500
- class DiracDeltaRV (RandomVariable ):
1504
+ class DiracDeltaRV (SymbolicRandomVariable ):
1501
1505
name = "diracdelta"
1502
- signature = "()->()"
1506
+ extended_signature = "[size], ()->()"
1503
1507
_print_name = ("DiracDelta" , "\\ operatorname{DiracDelta}" )
1504
1508
1509
+ def do_constant_folding (self , fgraph : "FunctionGraph" , node : Apply ) -> bool :
1510
+ # Because the distribution does not have RNGs we have to prevent constant-folding
1511
+ return False
1512
+
1505
1513
@classmethod
1506
- def rng_fn (cls , rng , c , size = None ):
1507
- if size is None :
1508
- return c .copy ()
1509
- return np .full (size , c )
1514
+ def rv_op (cls , c , * , size = None , rng = None ):
1515
+ size = normalize_size_param (size )
1516
+ c = pt .as_tensor (c )
1510
1517
1518
+ if rv_size_is_none (size ):
1519
+ out = c .copy ()
1520
+ else :
1521
+ out = pt .full (size , c )
1511
1522
1512
- diracdelta = DiracDeltaRV ( )
1523
+ return cls ( inputs = [ size , c ], outputs = [ out ])( size , c )
1513
1524
1514
1525
1515
1526
class DiracDelta (Discrete ):
@@ -1524,14 +1535,15 @@ class DiracDelta(Discrete):
1524
1535
that use DiracDelta, such as Mixtures.
1525
1536
"""
1526
1537
1527
- rv_op = diracdelta
1538
+ rv_type = DiracDeltaRV
1539
+ rv_op = DiracDeltaRV .rv_op
1528
1540
1529
1541
@classmethod
1530
1542
def dist (cls , c , * args , ** kwargs ):
1531
1543
c = pt .as_tensor_variable (c )
1532
1544
if c .dtype in continuous_types :
1533
1545
c = floatX (c )
1534
- return super ().dist ([c ], dtype = c . dtype , ** kwargs )
1546
+ return super ().dist ([c ], ** kwargs )
1535
1547
1536
1548
def support_point (rv , size , c ):
1537
1549
if not rv_size_is_none (size ):
0 commit comments