Skip to content

Commit 66f61d0

Browse files
committed
Transform to remove Minibatch from model
1 parent 2705a5e commit 66f61d0

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

pymc/model/transform/basic.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
from collections.abc import Sequence
1515

1616
from pytensor import Variable
17-
from pytensor.graph import ancestors
17+
from pytensor.graph import ancestors, node_rewriter
18+
from pytensor.graph.rewriting.basic import out2in
1819

20+
from pymc.data import MinibatchOp
1921
from pymc.model.core import Model
2022
from pymc.model.fgraph import (
2123
ModelObservedRV,
@@ -58,3 +60,16 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
5860
else:
5961
vars_seq = (vars,)
6062
return [model[var] if isinstance(var, str) else var for var in vars_seq]
63+
64+
65+
def remove_minibatched_nodes(model: Model):
66+
"""Remove all uses of pm.Minibatch in the Model."""
67+
68+
@node_rewriter([MinibatchOp])
69+
def local_remove_minibatch(fgraph, node):
70+
return node.inputs
71+
72+
remove_minibatch = out2in(local_remove_minibatch)
73+
fgraph, _ = fgraph_from_model(model)
74+
remove_minibatch.apply(fgraph)
75+
return model_from_fgraph(fgraph, mutate_fgraph=True)

tests/model/transform/test_basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
15+
1416
import pymc as pm
1517

16-
from pymc.model.transform.basic import prune_vars_detached_from_observed
18+
from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes
1719

1820

1921
def test_prune_vars_detached_from_observed():
@@ -30,3 +32,17 @@ def test_prune_vars_detached_from_observed():
3032
assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"}
3133
pruned_m = prune_vars_detached_from_observed(m)
3234
assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"}
35+
36+
37+
def test_remove_minibatches():
38+
data_size = 100
39+
data = np.zeros((data_size,))
40+
batch_size = 10
41+
with pm.Model() as m1:
42+
mb = pm.Minibatch(data, batch_size=batch_size)
43+
x = pm.Normal("x")
44+
y = pm.Normal("y", x, observed=mb, total_size=100)
45+
46+
m2 = remove_minibatched_nodes(m1)
47+
assert m1.y.shape[0].eval() == batch_size
48+
assert m2.y.shape[0].eval() == data_size

0 commit comments

Comments
 (0)