|
33 | 33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
34 | 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
35 | 35 | # SOFTWARE.
|
| 36 | +import itertools |
36 | 37 |
|
37 | 38 | import numpy as np
|
38 | 39 | import pytensor
|
@@ -502,3 +503,51 @@ def ref_logp(values, rho, sigma):
|
502 | 503 | logp_expr.eval({ma2_vv: ma2_test, rho: rho_test, sigma: sigma_test}),
|
503 | 504 | ref_logp(ma2_test, rho_test, sigma_test),
|
504 | 505 | )
|
| 506 | + |
| 507 | + |
| 508 | +@pytest.mark.xfail(reason="Not implemented yet") |
| 509 | +def test_scan_multiple_output_types(): |
| 510 | + """Test we can derive the logp for a scan that contains recurring and non-recurring measurable outputs.""" |
| 511 | + [xs, ys, zs], _ = pytensor.scan( |
| 512 | + fn=lambda x_mu, y_tm1, z_tm2, z_tm1: ( |
| 513 | + pt.random.normal(x_mu), |
| 514 | + pt.random.normal(y_tm1), |
| 515 | + pt.random.normal(z_tm1) + z_tm2, |
| 516 | + ), |
| 517 | + sequences=[pt.arange(10)], |
| 518 | + outputs_info=[ |
| 519 | + None, |
| 520 | + pt.zeros(()), |
| 521 | + dict(initial=pt.ones(2), taps=[-2, -1]), |
| 522 | + ], |
| 523 | + ) |
| 524 | + |
| 525 | + xs.name = "xs" |
| 526 | + xs_value = xs.clone() |
| 527 | + ys.name = "ys" |
| 528 | + ys_value = ys.clone() |
| 529 | + zs.name = "zs" |
| 530 | + zs_value = zs.clone() |
| 531 | + |
| 532 | + logp_dict = conditional_logp({xs: xs_value, ys: ys_value, zs: zs_value}) |
| 533 | + xs_logp = logp_dict[xs_value] |
| 534 | + ys_logp = logp_dict[ys_value] |
| 535 | + zs_logp = logp_dict[zs_value] |
| 536 | + |
| 537 | + assert_no_rvs([xs_logp, ys_logp, zs_logp]) |
| 538 | + fn = pytensor.function( |
| 539 | + [xs_value, ys_value, zs_value], |
| 540 | + [xs_logp, ys_logp, zs_logp], |
| 541 | + ) |
| 542 | + |
| 543 | + rng = np.random.default_rng(577) |
| 544 | + test_value = rng.uniform(size=(10,)) |
| 545 | + (xs_logp_eval, ys_logp_eval, zs_logp_eval) = fn(test_value, test_value, test_value) |
| 546 | + np.testing.assert_allclose(xs_logp_eval, stats.norm.logpdf(test_value, np.arange(10))) |
| 547 | + np.testing.assert_allclose(ys_logp_eval, stats.norm.logpdf(test_value, [0, *test_value[:-1]])) |
| 548 | + np.testing.assert_allclose( |
| 549 | + zs_logp_eval, |
| 550 | + stats.norm.logpdf( |
| 551 | + test_value, [a + b for a, b in itertools.pairwise([1, 1, *test_value[:-1]])] |
| 552 | + ), |
| 553 | + ) |
0 commit comments