Skip to content

Commit e71d1cb

Browse files
committed
Make DiracDelta a SymbolicRV
1 parent 057a6d2 commit e71d1cb

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

pymc/distributions/distribution.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,11 @@ def change_custom_dist_size(op, rv, new_size, expand):
10481048

10491049
return new_rv
10501050

1051-
rngs, rngs_updates = zip(*dummy_updates_dict.items())
1051+
if dummy_updates_dict:
1052+
rngs, rngs_updates = zip(*dummy_updates_dict.items())
1053+
else:
1054+
rngs, rngs_updates = (), ()
1055+
10521056
inputs = [*dummy_params, *rngs]
10531057
outputs = [dummy_rv, *rngs_updates]
10541058
signature = cls._infer_final_signature(
@@ -1497,19 +1501,26 @@ def default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False
14971501
)
14981502

14991503

1500-
class DiracDeltaRV(RandomVariable):
1504+
class DiracDeltaRV(SymbolicRandomVariable):
15011505
name = "diracdelta"
1502-
signature = "()->()"
1506+
extended_signature = "[size],()->()"
15031507
_print_name = ("DiracDelta", "\\operatorname{DiracDelta}")
15041508

1509+
def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
1510+
# Because the distribution does not have RNGs we have to prevent constant-folding
1511+
return False
1512+
15051513
@classmethod
1506-
def rng_fn(cls, rng, c, size=None):
1507-
if size is None:
1508-
return c.copy()
1509-
return np.full(size, c)
1514+
def rv_op(cls, c, *, size=None, rng=None):
1515+
size = normalize_size_param(size)
1516+
c = pt.as_tensor(c)
15101517

1518+
if rv_size_is_none(size):
1519+
out = c.copy()
1520+
else:
1521+
out = pt.full(size, c)
15111522

1512-
diracdelta = DiracDeltaRV()
1523+
return cls(inputs=[size, c], outputs=[out])(size, c)
15131524

15141525

15151526
class DiracDelta(Discrete):
@@ -1524,14 +1535,15 @@ class DiracDelta(Discrete):
15241535
that use DiracDelta, such as Mixtures.
15251536
"""
15261537

1527-
rv_op = diracdelta
1538+
rv_type = DiracDeltaRV
1539+
rv_op = DiracDeltaRV.rv_op
15281540

15291541
@classmethod
15301542
def dist(cls, c, *args, **kwargs):
15311543
c = pt.as_tensor_variable(c)
15321544
if c.dtype in continuous_types:
15331545
c = floatX(c)
1534-
return super().dist([c], dtype=c.dtype, **kwargs)
1546+
return super().dist([c], **kwargs)
15351547

15361548
def support_point(rv, size, c):
15371549
if not rv_size_is_none(size):

0 commit comments

Comments
 (0)