Skip to content

Tweak Model.profile docstring and type hint #7795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import pytensor.tensor as pt
import scipy.sparse as sps

from pytensor.compile import DeepCopyOp, Function, get_mode
from pytensor.compile import DeepCopyOp, Function, ProfileStats, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -1657,7 +1657,15 @@ def compile_fn(
return PointFunc(fn)
return fn

def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
def profile(
self,
outs,
*,
n=1000,
point=None,
profile=True,
**compile_fn_kwargs,
) -> ProfileStats:
"""Compile and profile a PyTensor function which returns ``outs`` and takes values of model vars as a dict as an argument.

Parameters
Expand All @@ -1668,16 +1676,22 @@ def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
point : Point
Point to pass to the function
profile : True or ProfileStats
args, kwargs
Compilation args
compile_fn_kwargs
Compilation kwargs for :func:`pymc.model.core.Model.compile_fn`

Returns
-------
ProfileStats
pytensor.compile.profiling.ProfileStats
Use .summary() to print stats.
"""
kwargs.setdefault("on_unused_input", "ignore")
f = self.compile_fn(outs, inputs=self.value_vars, point_fn=False, profile=profile, **kwargs)
compile_fn_kwargs.setdefault("on_unused_input", "ignore")
f = self.compile_fn(
outs,
inputs=self.value_vars,
point_fn=False,
profile=profile,
**compile_fn_kwargs,
)
if point is None:
point = self.initial_point()

Expand Down