diff --git a/pymc/model/core.py b/pymc/model/core.py index b85cc802f..d04011407 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -32,7 +32,7 @@ import pytensor.tensor as pt import scipy.sparse as sps -from pytensor.compile import DeepCopyOp, Function, get_mode +from pytensor.compile import DeepCopyOp, Function, ProfileStats, get_mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs from pytensor.tensor.random.op import RandomVariable @@ -1657,7 +1657,15 @@ def compile_fn( return PointFunc(fn) return fn - def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs): + def profile( + self, + outs, + *, + n=1000, + point=None, + profile=True, + **compile_fn_kwargs, + ) -> ProfileStats: """Compile and profile a PyTensor function which returns ``outs`` and takes values of model vars as a dict as an argument. Parameters @@ -1668,16 +1676,22 @@ def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs): point : Point Point to pass to the function profile : True or ProfileStats - args, kwargs - Compilation args + compile_fn_kwargs + Compilation kwargs for :func:`pymc.model.core.Model.compile_fn` Returns ------- - ProfileStats + pytensor.compile.profiling.ProfileStats Use .summary() to print stats. """ - kwargs.setdefault("on_unused_input", "ignore") - f = self.compile_fn(outs, inputs=self.value_vars, point_fn=False, profile=profile, **kwargs) + compile_fn_kwargs.setdefault("on_unused_input", "ignore") + f = self.compile_fn( + outs, + inputs=self.value_vars, + point_fn=False, + profile=profile, + **compile_fn_kwargs, + ) if point is None: point = self.initial_point()