Skip to content

Commit 1abaca8

Browse files
committed
Allow inlining of Deterministics and Data in fgraph IR
1 parent a2ced28 commit 1abaca8

File tree

2 files changed

+83
-36
lines changed

2 files changed

+83
-36
lines changed

pymc_experimental/tests/utils/test_model_fgraph.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,28 +76,41 @@ def test_basic():
7676
)
7777

7878

79-
def test_data():
79+
@pytest.mark.parametrize("inline_views", (False, True))
80+
def test_data(inline_views):
8081
"""Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.
8182
8283
Everything should be preserved across new and old models, except for shared RNGs
8384
"""
8485
with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
8586
x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
8687
y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
87-
b0 = pm.ConstantData("b0", 0.0)
88+
b0 = pm.ConstantData("b0", np.zeros(3))
8889
b1 = pm.Normal("b1")
8990
mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
9091
obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
9192

92-
m_fgraph, memo = fgraph_from_model(m_old)
93+
m_fgraph, memo = fgraph_from_model(m_old, inlined_views=inline_views)
9394
assert isinstance(memo[x].owner.op, ModelNamed)
9495
assert isinstance(memo[y].owner.op, ModelNamed)
9596
assert isinstance(memo[b0].owner.op, ModelNamed)
97+
mu_inp = memo[mu].owner.inputs[0]
98+
obs = memo[obs]
99+
if not inline_views:
100+
# Add(b0, Mul(FreeRV(b1), x)) not Add(Named(b0), Mul(FreeRV(b1), Named(x)))
101+
assert mu_inp.owner.inputs[0] is memo[b0].owner.inputs[0]
102+
assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x].owner.inputs[0]
103+
# ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims)
104+
assert obs.owner.inputs[1] is memo[y].owner.inputs[0]
105+
else:
106+
assert mu_inp.owner.inputs[0] is memo[b0]
107+
assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x]
108+
assert obs.owner.inputs[1] is memo[y]
96109

97110
m_new = model_from_fgraph(m_fgraph)
98111

99112
# ConstantData is preserved
100-
assert m_new["b0"].data == m_old["b0"].data
113+
assert np.all(m_new["b0"].data == m_old["b0"].data)
101114

102115
# Shared non-rng shared variables are preserved
103116
assert m_new["x"].container is x.container
@@ -114,7 +127,8 @@ def test_data():
114127
np.testing.assert_array_almost_equal(pm.draw(m_new["x"]), [100.0, 200.0])
115128

116129

117-
def test_deterministics():
130+
@pytest.mark.parametrize("inline_views", (False, True))
131+
def test_deterministics(inline_views):
118132
"""Test handling of deterministics.
119133
120134
We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome
@@ -140,22 +154,27 @@ def test_deterministics():
140154
assert m["y"].owner.inputs[3] is m["mu"]
141155
assert m["y"].owner.inputs[4] is not m["sigma"]
142156

143-
fg, _ = fgraph_from_model(m)
157+
fg, _ = fgraph_from_model(m, inlined_views=inline_views)
144158

145159
# Check that no Deterministics are in graph of x to y and y to z (unless inline_views is True)
146160
x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
147161
# [Det(mu), Det(sigma)]
148162
mu = det_mu.owner.inputs[0]
149163
sigma = det_sigma.owner.inputs[0]
150-
# [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))]
151-
assert y.owner.inputs[0].owner.inputs[3] is mu
152164
assert y.owner.inputs[0].owner.inputs[4] is sigma
153-
# [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))]
154-
assert z.owner.inputs[0].owner.inputs[3] is y
155-
# [Det(y), Det(y)], not [Det(y), Det(Det(y))]
156-
assert det_y_.owner.inputs[0] is y
157-
assert det_y__.owner.inputs[0] is y
158165
assert det_y_ is not det_y__
166+
assert det_y_.owner.inputs[0] is y
167+
if not inline_views:
168+
# FreeRV(y(mu, sigma)) not FreeRV(y(Det(mu), Det(sigma)))
169+
assert y.owner.inputs[0].owner.inputs[3] is mu
170+
# FreeRV(z(y)) not FreeRV(z(Det(Det(y))))
171+
assert z.owner.inputs[0].owner.inputs[3] is y
172+
# Det(y), not Det(Det(y))
173+
assert det_y__.owner.inputs[0] is y
174+
else:
175+
assert y.owner.inputs[0].owner.inputs[3] is det_mu
176+
assert z.owner.inputs[0].owner.inputs[3] is det_y__
177+
assert det_y__.owner.inputs[0] is det_y_
159178

160179
# Both mu and sigma deterministics are now in the graph of x to y
161180
m = model_from_fgraph(fg)

pymc_experimental/utils/model_fgraph.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,16 @@ def model_free_rv(rv, value, transform, *dims):
9090

9191

9292
def toposort_replace(
93-
fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]]
93+
fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
9494
) -> None:
9595
"""Replace multiple variables in topological order."""
9696
toposort = fgraph.toposort()
9797
sorted_replacements = sorted(
98-
replacements, key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1
98+
replacements,
99+
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
100+
reverse=reverse,
99101
)
100-
fgraph.replace_all(tuple(sorted_replacements), import_missing=True)
102+
fgraph.replace_all(sorted_replacements, import_missing=True)
101103

102104

103105
@node_rewriter([Elemwise])
@@ -109,11 +111,20 @@ def local_remove_identity(fgraph, node):
109111
remove_identity_rewrite = out2in(local_remove_identity)
110112

111113

112-
def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
114+
def fgraph_from_model(
115+
model: Model, inlined_views=False
116+
) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
113117
"""Convert Model to FunctionGraph.
114118
115119
See: model_from_fgraph
116120
121+
Parameters
122+
----------
123+
model: PyMC model
124+
inlined_views: bool, default False
125+
Whether "view" variables (Deterministics and Data) should be inlined among RVs in the fgraph,
126+
or show up as separate branches.
127+
117128
Returns
118129
-------
119130
fgraph: FunctionGraph
@@ -138,19 +149,36 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
138149
free_rvs = model.free_RVs
139150
observed_rvs = model.observed_RVs
140151
potentials = model.potentials
152+
named_vars = model.named_vars.values()
141153
# We copy Deterministics (Identity Op) so that they don't show in between "main" variables
142154
# We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
143155
old_deterministics = model.deterministics
144-
deterministics = [det.copy(det.name) for det in old_deterministics]
145-
# Other variables that are in model.named_vars but are not any of the categories above
156+
deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics]
157+
# Value variables (we also have to decide whether to inline named ones)
158+
old_value_vars = list(rvs_to_values.values())
159+
unnamed_value_vars = [val for val in old_value_vars if val not in named_vars]
160+
named_value_vars = [
161+
val if inlined_views else val.copy(val.name) for val in old_value_vars if val in named_vars
162+
]
163+
value_vars = old_value_vars.copy()
164+
if inlined_views:
165+
# In this case we want to use the named_value_vars as the value_vars in RVs
166+
for named_val in named_value_vars:
167+
idx = value_vars.index(named_val)
168+
value_vars[idx] = named_val
169+
# Other variables that are in named_vars but are not any of the categories above
146170
# E.g., MutableData, ConstantData, _dim_lengths
147171
# We use the same trick as deterministics!
148-
accounted_for = free_rvs + observed_rvs + potentials + old_deterministics
149-
old_other_named_vars = [var for var in model.named_vars.values() if var not in accounted_for]
150-
other_named_vars = [var.copy(var.name) for var in old_other_named_vars]
151-
value_vars = [val for val in rvs_to_values.values() if val not in old_other_named_vars]
172+
accounted_for = set(free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars)
173+
other_named_vars = [
174+
var if inlined_views else var.copy(var.name)
175+
for var in named_vars
176+
if var not in accounted_for
177+
]
152178

153-
model_vars = rvs + potentials + deterministics + other_named_vars + value_vars
179+
model_vars = (
180+
rvs + potentials + deterministics + other_named_vars + named_value_vars + unnamed_value_vars
181+
)
154182

155183
memo = {}
156184

@@ -176,13 +204,13 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
176204

177205
# Introduce dummy `ModelVar` Ops
178206
free_rvs_to_transforms = {memo[k]: tr for k, tr in rvs_to_transforms.items()}
179-
free_rvs_to_values = {memo[k]: memo[v] for k, v in rvs_to_values.items() if k in free_rvs}
207+
free_rvs_to_values = {memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in free_rvs}
180208
observed_rvs_to_values = {
181-
memo[k]: memo[v] for k, v in rvs_to_values.items() if k in observed_rvs
209+
memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in observed_rvs
182210
}
183211
potentials = [memo[k] for k in potentials]
184212
deterministics = [memo[k] for k in deterministics]
185-
other_named_vars = [memo[k] for k in other_named_vars]
213+
named_vars = [memo[k] for k in other_named_vars + named_value_vars]
186214

187215
vars = fgraph.outputs
188216
new_vars = []
@@ -198,31 +226,31 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
198226
new_var = model_potential(var, *dims)
199227
elif var in deterministics:
200228
new_var = model_deterministic(var, *dims)
201-
elif var in other_named_vars:
229+
elif var in named_vars:
202230
new_var = model_named(var, *dims)
203231
else:
204-
# Value variables
232+
# Unnamed value variables
205233
new_var = var
206234
new_vars.append(new_var)
207235

208236
replacements = tuple(zip(vars, new_vars))
209-
toposort_replace(fgraph, replacements)
237+
toposort_replace(fgraph, replacements, reverse=True)
210238

211239
# Reference model vars in memo
212240
inverse_memo = {v: k for k, v in memo.items()}
213241
for var, model_var in replacements:
214-
if isinstance(
215-
model_var.owner is not None and model_var.owner.op, (ModelDeterministic, ModelNamed)
242+
if not inlined_views and (
243+
model_var.owner and isinstance(model_var.owner.op, (ModelDeterministic, ModelNamed))
216244
):
217245
# Ignore extra identity that will be removed at the end
218246
var = var.owner.inputs[0]
219247
original_var = inverse_memo[var]
220248
memo[original_var] = model_var
221249

222-
# Remove value variable as outputs, now that they are graph inputs
223-
first_value_idx = len(fgraph.outputs) - len(value_vars)
224-
for _ in value_vars:
225-
fgraph.remove_output(first_value_idx)
250+
# Remove the last outputs corresponding to unnamed value variables, now that they are graph inputs
251+
first_idx_to_remove = len(fgraph.outputs) - len(unnamed_value_vars)
252+
for _ in unnamed_value_vars:
253+
fgraph.remove_output(first_idx_to_remove)
226254

227255
# Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph
228256
remove_identity_rewrite.apply(fgraph)

0 commit comments

Comments
 (0)