Skip to content

Commit 2cdb87a

Browse files
ricardoV94twiecki
authored andcommitted
Use fast_compile in model_graph
1 parent f6f1a8e commit 2cdb87a

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

pymc/model_graph.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from collections import defaultdict, deque
1717
from typing import Dict, Iterator, NewType, Optional, Set
1818

19+
from aesara import function
1920
from aesara.compile.sharedvalue import SharedVariable
2021
from aesara.graph.basic import walk
2122
from aesara.tensor.random.op import RandomVariable
@@ -159,6 +160,9 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
159160

160161
graph.node(var_name.replace(":", "&"), **kwargs)
161162

163+
def _eval(self, var):
164+
return function([], var, mode="FAST_COMPILE")()
165+
162166
def get_plates(self):
163167
"""Rough but surprisingly accurate plate detection.
164168
@@ -174,11 +178,11 @@ def get_plates(self):
174178
v = self.model[var_name]
175179
if var_name in self.model.RV_dims:
176180
plate_label = " x ".join(
177-
f"{d} ({self.model.dim_lengths[d].eval()})"
181+
f"{d} ({self._eval(self.model.dim_lengths[d])})"
178182
for d in self.model.RV_dims[var_name]
179183
)
180184
else:
181-
plate_label = " x ".join(map(str, v.shape.eval()))
185+
plate_label = " x ".join(map(str, self._eval(v.shape)))
182186
plates[plate_label].add(var_name)
183187
return plates
184188

0 commit comments

Comments
 (0)