Implement Model.debug() helper #6634

Merged: 2 commits, Mar 31, 2023
10 changes: 6 additions & 4 deletions pymc/distributions/continuous.py
@@ -234,24 +234,26 @@ def get_tau_sigma(tau=None, sigma=None):
             tau = 1.0
         else:
             if isinstance(sigma, Variable):
-                sigma_ = check_parameters(sigma, sigma > 0, msg="sigma > 0")
+                # Keep tau negative, if sigma was negative, so that it will fail when used
+                tau = (sigma**-2.0) * pt.sgn(sigma)
             else:
                 sigma_ = np.asarray(sigma)
                 if np.any(sigma_ <= 0):
                     raise ValueError("sigma must be positive")
-            tau = sigma_**-2.0
+                tau = sigma_**-2.0

     else:
         if sigma is not None:
             raise ValueError("Can't pass both tau and sigma")
         else:
             if isinstance(tau, Variable):
-                tau_ = check_parameters(tau, tau > 0, msg="tau > 0")
+                # Keep sigma negative, if tau was negative, so that it will fail when used
+                sigma = pt.abs(tau) ** (-0.5) * pt.sgn(tau)
             else:
                 tau_ = np.asarray(tau)
                 if np.any(tau_ <= 0):
                     raise ValueError("tau must be positive")
-            sigma = tau_**-0.5
+                sigma = tau_**-0.5

     return floatX(tau), floatX(sigma)

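For context, a minimal sketch of the changed get_tau_sigma behavior (a standalone example, not part of the diff):

import pytensor.tensor as pt

from pymc.distributions.continuous import get_tau_sigma

# A symbolic negative sigma no longer has to fail inside get_tau_sigma itself:
# the sign is carried over into tau, so the invalid value surfaces later,
# when the distribution's own `sigma > 0` constraint is evaluated.
tau, _ = get_tau_sigma(sigma=pt.constant(-2.0))
print(tau.eval())  # -0.25 == (-2.0)**-2.0 * sign(-2.0)

# Concrete (non-symbolic) values are still validated eagerly:
try:
    get_tau_sigma(sigma=-2.0)
except ValueError as exc:
    print(exc)  # sigma must be positive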
154 changes: 153 additions & 1 deletion pymc/model.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import functools
+import sys
 import threading
 import types
 import warnings
@@ -24,6 +25,7 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
     Tuple,
@@ -39,13 +41,15 @@
 import pytensor.tensor as pt
 import scipy.sparse as sps

+from pytensor.compile import DeepCopyOp, get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant, Variable, graph_inputs
 from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar import Cast
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.sharedvar import ScalarSharedVariable
 from pytensor.tensor.var import TensorConstant, TensorVariable
@@ -61,6 +65,7 @@
 )
 from pymc.initial_point import make_initial_point_fn
 from pymc.logprob.basic import joint_logp
+from pymc.logprob.utils import ParameterValueError
 from pymc.pytensorf import (
     PointFunc,
     SeedSequenceSeed,
@@ -1779,7 +1784,8 @@ def check_start_vals(self, start):
             raise SamplingError(
                 "Initial evaluation of model at starting point failed!\n"
                 f"Starting values:\n{elem}\n\n"
-                f"Logp initial evaluation results:\n{initial_eval}"
+                f"Logp initial evaluation results:\n{initial_eval}\n"
+                "You can call `model.debug()` for more details."
             )

     def point_logps(self, point=None, round_vals=2):
@@ -1811,6 +1817,152 @@ def point_logps(self, point=None, round_vals=2):
             )
         }

+    def debug(
+        self,
+        point: Optional[Dict[str, np.ndarray]] = None,
+        fn: Literal["logp", "dlogp", "random"] = "logp",
+        verbose: bool = False,
+    ):
+        """Debug model function at point.
+
+        The method evaluates `fn` for each variable, one at a time.
+        When an evaluation fails or produces a non-finite value, we print:
+        1. The graph of the parameters
+        2. The value of the parameters (if those can be evaluated)
+        3. The output of `fn` (if it can be evaluated)
+
+        This function should help to quickly narrow down invalid parametrizations.
+
+        Parameters
+        ----------
+        point : Point
+            Point at which model function should be evaluated
+        fn : str, default "logp"
+            Function to be used for debugging. Can be one of [logp, dlogp, random].
+        verbose : bool, default False
+            Whether to show a more verbose PyTensor output when the function cannot be evaluated
+        """
+        print_ = functools.partial(print, file=sys.stdout)
+
+        def first_line(exc):
+            return exc.args[0].split("\n")[0]
+
+        def debug_parameters(rv):
+            if isinstance(rv.owner.op, RandomVariable):
+                inputs = rv.owner.inputs[3:]
+            else:
+                inputs = [inp for inp in rv.owner.inputs if not isinstance(inp.type, RandomType)]
+            rv_inputs = pytensor.function(
+                self.value_vars,
+                self.replace_rvs_by_values(inputs),
+                on_unused_input="ignore",
+                mode=get_mode(None).excluding("inplace", "fusion"),
+            )
+
+            print_(f"The variable {rv} has the following parameters:")
+            # done and used_ids are used to keep the same ids across distinct dprint calls
+            done = {}
+            used_ids = {}
+            for i, out in enumerate(rv_inputs.maker.fgraph.outputs):
+                print_(f"{i}: ", end="")
+                # Don't print useless deepcopys
+                if out.owner and isinstance(out.owner.op, DeepCopyOp):
+                    out = out.owner.inputs[0]
+                pytensor.dprint(out, print_type=True, done=done, used_ids=used_ids)
+
+            try:
+                print_("The parameters evaluate to:")
+                for i, rv_input_eval in enumerate(rv_inputs(**point)):
+                    print_(f"{i}: {rv_input_eval}")
+            except Exception as exc:
+                print_(
+                    f"The parameters of the variable {rv} cannot be evaluated: {first_line(exc)}"
+                )
+                if verbose:
+                    print_(exc, "\n")
+
+        if fn not in ("logp", "dlogp", "random"):
+            raise ValueError(f"fn must be one of [logp, dlogp, random], got {fn}")
+
+        if point is None:
+            point = self.initial_point()
+        print_(f"point={point}\n")
+
+        rvs_to_check = list(self.basic_RVs)
+        if fn in ("logp", "dlogp"):
+            rvs_to_check += [self.replace_rvs_by_values(p) for p in self.potentials]
+
+        found_problem = False
+        for rv in rvs_to_check:
+            if fn == "logp":
+                rv_fn = pytensor.function(
+                    self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
+                )
+            elif fn == "dlogp":
+                rv_fn = pytensor.function(
+                    self.value_vars, self.dlogp(vars=rv), on_unused_input="ignore"
+                )
+            else:
+                [rv_inputs_replaced] = replace_rvs_by_values(
+                    [rv],
+                    # Don't include itself, or the function will just return the value variable
+                    rvs_to_values={
+                        rv_key: value
+                        for rv_key, value in self.rvs_to_values.items()
+                        if rv_key is not rv
+                    },
+                    rvs_to_transforms=self.rvs_to_transforms,
+                )
+                rv_fn = pytensor.function(
+                    self.value_vars, rv_inputs_replaced, on_unused_input="ignore"
+                )

+            try:
+                rv_fn_eval = rv_fn(**point)
+            except ParameterValueError as exc:
+                found_problem = True
+                debug_parameters(rv)
+                print_(
+                    f"This does not respect one of the following constraints: {first_line(exc)}\n"
+                )
+                if verbose:
+                    print_(exc)
+            except Exception as exc:
+                found_problem = True
+                debug_parameters(rv)
+                print_(
+                    f"The variable {rv} {fn} method raised the following exception: {first_line(exc)}\n"
+                )
+                if verbose:
+                    print_(exc)
+            else:
+                if not np.all(np.isfinite(rv_fn_eval)):
+                    found_problem = True
+                    debug_parameters(rv)
+                    if fn == "random" or rv in self.potentials:
+                        print_("This combination seems able to generate non-finite values")
+                    else:
+                        # Find which values are associated with non-finite evaluation
+                        values = self.rvs_to_values[rv]
+                        if rv in self.observed_RVs:
+                            values = values.eval()
+                        else:
+                            values = point[values.name]
+
+                        observed = " observed " if rv in self.observed_RVs else " "
+                        print_(
+                            f"Some of the{observed}values of variable {rv} are associated with a non-finite {fn}:"
+                        )
+                        mask = ~np.isfinite(rv_fn_eval)
+                        for value, fn_eval in zip(values[mask], rv_fn_eval[mask]):
+                            print_(f" value = {value} -> {fn} = {fn_eval}")
+                        print_()
+
+        if not found_problem:
+            print_("No problems found")
+        elif not verbose:
+            print_("You can set `verbose=True` for more details")
+

 # this is really disgusting, but it breaks a self-loop: I can't pass Model
 # itself as context class init arg.
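A usage sketch for the new helper (hypothetical model, not part of the diff):

import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=[1, -1, 1])
    # tau can be negative here, which makes the implied sigma invalid
    y = pm.HalfNormal("y", tau=x)

# Instead of only a SamplingError at sampling time, inspect variables one by one:
m.debug()              # default fn="logp"
m.debug(fn="random")   # check the forward/random path instead
m.debug(verbose=True)  # print full PyTensor errors for failing evaluations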
31 changes: 17 additions & 14 deletions tests/distributions/test_continuous.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import functools as ft
+import warnings

 import numpy as np
 import numpy.testing as npt
@@ -890,24 +891,26 @@ def scipy_logp(value, mu, sigma, lower, upper):
         assert np.isinf(logp[2])

     def test_get_tau_sigma(self):
-        sigma = np.array(2)
-        npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])
+        # Fail on warnings
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+
+            sigma = np.array(2)
+            npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])

-        tau = np.array(2)
-        npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5])
+            tau = np.array(2)
+            npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5])

-        tau, _ = get_tau_sigma(sigma=pt.constant(-2))
-        with pytest.raises(ParameterValueError):
-            tau.eval()
+            tau, _ = get_tau_sigma(sigma=pt.constant(-2))
+            npt.assert_almost_equal(tau.eval(), -0.25)

-        _, sigma = get_tau_sigma(tau=pt.constant(-2))
-        with pytest.raises(ParameterValueError):
-            sigma.eval()
+            _, sigma = get_tau_sigma(tau=pt.constant(-2))
+            npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2))

-        sigma = [1, 2]
-        npt.assert_almost_equal(
-            get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)]
-        )
+            sigma = [1, 2]
+            npt.assert_almost_equal(
+                get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)]
+            )

     @pytest.mark.parametrize(
         "value,mu,sigma,nu,logp",
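The rewritten test runs its body under a warnings-as-errors guard. A minimal sketch of that pattern (a generic example, not tied to pymc):

import warnings

def test_is_warning_free():
    # Any warning raised inside the block is escalated to an exception,
    # so the test fails if the code under test starts emitting warnings.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        value = int("42")  # stand-in for the code under test
    assert value == 42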
72 changes: 72 additions & 0 deletions tests/test_model.py
@@ -30,6 +30,7 @@
 import scipy.stats as st

 from pytensor.graph import graph_inputs
+from pytensor.raise_op import Assert, assert_op
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.sharedvar import ScalarSharedVariable
@@ -1553,3 +1554,74 @@ def test_tag_future_warning_model():
         assert y_value.eval() == 5

     assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad)


+class TestModelDebug:
+    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
+    def test_no_problems(self, fn, capfd):
+        with pm.Model() as m:
+            x = pm.Normal("x", [1, -1, 1])
+        m.debug(fn=fn)
+
+        out, _ = capfd.readouterr()
+        assert out == "point={'x': array([ 1., -1.,  1.])}\n\nNo problems found\n"
+
+    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
+    def test_invalid_parameter(self, fn, capfd):
+        with pm.Model() as m:
+            x = pm.Normal("x", [1, -1, 1])
+            y = pm.HalfNormal("y", tau=x)
+        m.debug(fn=fn)
+
+        out, _ = capfd.readouterr()
+        if fn == "dlogp":
+            # var dlogp is 0 or 1 without a likelihood
+            assert "No problems found" in out
+        else:
+            assert "The parameters evaluate to:\n0: 0.0\n1: [ 1. -1.  1.]" in out
+            if fn == "logp":
+                assert "This does not respect one of the following constraints: sigma > 0" in out
+            else:
+                assert (
+                    "The variable y random method raised the following exception: Domain error in arguments."
+                    in out
+                )
+
+    @pytest.mark.parametrize("verbose", (True, False))
+    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
+    def test_invalid_parameter_cant_be_evaluated(self, fn, verbose, capfd):
+        with pm.Model() as m:
+            x = pm.Normal("x", [1, 1, 1])
+            sigma = Assert(msg="x > 0")(pm.math.abs(x), (x > 0).all())
+            y = pm.HalfNormal("y", sigma=sigma)
+        m.debug(point={"x": [-1, -1, -1], "y_log__": [0, 0, 0]}, fn=fn, verbose=verbose)
+
+        out, _ = capfd.readouterr()
+        assert "{'x': [-1, -1, -1], 'y_log__': [0, 0, 0]}" in out
+        assert "The parameters of the variable y cannot be evaluated: x > 0" in out
+        verbose_str = "Apply node that caused the error:" in out
+        assert verbose_str if verbose else not verbose_str
+
+    def test_invalid_value(self, capfd):
+        with pm.Model() as m:
+            x = pm.Normal("x", [1, -1, 1])
+            y = pm.HalfNormal("y", tau=pm.math.abs(x), initval=[-1, 1, -1], transform=None)
+        m.debug()
+
+        out, _ = capfd.readouterr()
+        assert "The parameters evaluate to:\n0: 0.0\n1: [1. 1. 1.]" in out
+        assert "Some of the values of variable y are associated with a non-finite logp" in out
+        assert "value = -1.0 -> logp = -inf" in out
+
+    def test_invalid_observed_value(self, capfd):
+        with pm.Model() as m:
+            theta = pm.Uniform("theta", lower=0, upper=1)
+            y = pm.Uniform("y", lower=0, upper=theta, observed=[0.49, 0.27, 0.53, 0.19])
+        m.debug()
+
+        out, _ = capfd.readouterr()
+        assert "The parameters evaluate to:\n0: 0.0\n1: 0.5" in out
+        assert (
+            "Some of the observed values of variable y are associated with a non-finite logp" in out
+        )
+        assert "value = 0.53 -> logp = -inf" in out