@@ -1747,7 +1747,7 @@ def eval_rv_shapes(self) -> dict[str, tuple[int, ...]]:
1747
1747
)
1748
1748
return {name : tuple (shape ) for name , shape in zip (names , f ())}
1749
1749
1750
- def check_start_vals (self , start ):
1750
+ def check_start_vals (self , start , ** kwargs ):
1751
1751
r"""Check that the starting values for MCMC do not cause the relevant log probability
1752
1752
to evaluate to something invalid (e.g. Inf or NaN)
1753
1753
@@ -1758,6 +1758,8 @@ def check_start_vals(self, start):
1758
1758
Defaults to ``trace.point(-1))`` if there is a trace provided and
1759
1759
``model.initial_point`` if not (defaults to empty dict). Initialization
1760
1760
methods for NUTS (see ``init`` keyword) can overwrite the default.
1761
+ Other keyword arguments :
1762
+ Any other keyword argument is sent to :py:meth:`~pymc.model.core.Model.point_logps`.
1761
1763
1762
1764
Raises
1763
1765
------
@@ -1787,7 +1789,7 @@ def check_start_vals(self, start):
1787
1789
f"Valid keys are: { valid_keys } , but { extra_keys } was supplied"
1788
1790
)
1789
1791
1790
- initial_eval = self .point_logps (point = elem )
1792
+ initial_eval = self .point_logps (point = elem , ** kwargs )
1791
1793
1792
1794
if not all (np .isfinite (v ) for v in initial_eval .values ()):
1793
1795
raise SamplingError (
@@ -1797,7 +1799,7 @@ def check_start_vals(self, start):
1797
1799
"You can call `model.debug()` for more details."
1798
1800
)
1799
1801
1800
- def point_logps (self , point = None , round_vals = 2 ):
1802
+ def point_logps (self , point = None , round_vals = 2 , ** kwargs ):
1801
1803
"""Computes the log probability of `point` for all random variables in the model.
1802
1804
1803
1805
Parameters
@@ -1807,6 +1809,8 @@ def point_logps(self, point=None, round_vals=2):
1807
1809
is used.
1808
1810
round_vals : int, default 2
1809
1811
Number of decimals to round log-probabilities.
1812
+ Other keyword arguments :
1813
+ Any other keyword argument are sent provided to :py:meth:`~pymc.model.core.Model.compile_fn`
1810
1814
1811
1815
Returns
1812
1816
-------
@@ -1822,7 +1826,7 @@ def point_logps(self, point=None, round_vals=2):
1822
1826
factor .name : np .round (np .asarray (factor_logp ), round_vals )
1823
1827
for factor , factor_logp in zip (
1824
1828
factors ,
1825
- self .compile_fn (factor_logps_fn )(point ),
1829
+ self .compile_fn (factor_logps_fn , ** kwargs )(point ),
1826
1830
)
1827
1831
}
1828
1832
0 commit comments