Skip to content

Commit b06d6c3

Browse files
committed
Add test for scan logp with multiple valued output types
1 parent db0b218 commit b06d6c3

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

tests/logprob/test_scan.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
3434
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3535
# SOFTWARE.
36+
import itertools
3637

3738
import numpy as np
3839
import pytensor
@@ -502,3 +503,51 @@ def ref_logp(values, rho, sigma):
502503
logp_expr.eval({ma2_vv: ma2_test, rho: rho_test, sigma: sigma_test}),
503504
ref_logp(ma2_test, rho_test, sigma_test),
504505
)
506+
507+
508+
@pytest.mark.xfail(reason="Not implemented yet")
509+
def test_scan_multiple_output_types():
510+
"""Test we can derive the logp for a scan that contains recurring and non-recurring measurable outputs."""
511+
[xs, ys, zs], _ = pytensor.scan(
512+
fn=lambda x_mu, y_tm1, z_tm2, z_tm1: (
513+
pt.random.normal(x_mu),
514+
pt.random.normal(y_tm1),
515+
pt.random.normal(z_tm1) + z_tm2,
516+
),
517+
sequences=[pt.arange(10)],
518+
outputs_info=[
519+
None,
520+
pt.zeros(()),
521+
dict(initial=pt.ones(2), taps=[-2, -1]),
522+
],
523+
)
524+
525+
xs.name = "xs"
526+
xs_value = xs.clone()
527+
ys.name = "ys"
528+
ys_value = ys.clone()
529+
zs.name = "zs"
530+
zs_value = zs.clone()
531+
532+
logp_dict = conditional_logp({xs: xs_value, ys: ys_value, zs: zs_value})
533+
xs_logp = logp_dict[xs_value]
534+
ys_logp = logp_dict[ys_value]
535+
zs_logp = logp_dict[zs_value]
536+
537+
assert_no_rvs([xs_logp, ys_logp, zs_logp])
538+
fn = pytensor.function(
539+
[xs_value, ys_value, zs_value],
540+
[xs_logp, ys_logp, zs_logp],
541+
)
542+
543+
rng = np.random.default_rng(577)
544+
test_value = rng.uniform(size=(10,))
545+
(xs_logp_eval, ys_logp_eval, zs_logp_eval) = fn(test_value, test_value, test_value)
546+
np.testing.assert_allclose(xs_logp_eval, stats.norm.logpdf(test_value, np.arange(10)))
547+
np.testing.assert_allclose(ys_logp_eval, stats.norm.logpdf(test_value, [0, *test_value[:-1]]))
548+
np.testing.assert_allclose(
549+
zs_logp_eval,
550+
stats.norm.logpdf(
551+
test_value, [a + b for a, b in itertools.pairwise([1, 1, *test_value[:-1]])]
552+
),
553+
)

0 commit comments

Comments
 (0)