|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import warnings |
| 15 | + |
14 | 16 | from unittest import mock
|
15 | 17 |
|
16 | 18 | import numpy as np
|
|
38 | 40 | from pymc.exceptions import NotConstantValueError
|
39 | 41 | from pymc.logprob.utils import ParameterValueError
|
40 | 42 | from pymc.pytensorf import (
|
| 43 | + collect_default_updates, |
41 | 44 | compile_pymc,
|
42 | 45 | constant_fold,
|
43 | 46 | convert_observed_data,
|
@@ -406,28 +409,63 @@ def test_compile_pymc_updates_inputs(self):
|
406 | 409 | # Each RV adds a shared output for its rng
|
407 | 410 | assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
|
408 | 411 |
|
409 |
| - # Disable `reseed_rngs` so that we can test with simpler update rule |
410 |
| - @mock.patch("pymc.pytensorf.reseed_rngs") |
411 |
| - def test_compile_pymc_custom_update_op(self, _): |
412 |
| - """Test that custom MeasurableVariable Op updates are used by compile_pymc""" |
| 412 | + def test_compile_pymc_symbolic_rv_update(self): |
| 413 | + """Test that SymbolicRandomVariable Op update methods are used by compile_pymc""" |
413 | 414 |
|
414 | 415 | class NonSymbolicRV(OpFromGraph):
|
415 | 416 | def update(self, node):
|
416 |
| - return {node.inputs[0]: node.inputs[0] + 1} |
| 417 | + return {node.inputs[0]: node.outputs[0]} |
417 | 418 |
|
418 |
| - dummy_inputs = [pt.scalar(), pt.scalar()] |
419 |
| - dummy_outputs = [pt.add(*dummy_inputs)] |
420 |
| - dummy_x = NonSymbolicRV(dummy_inputs, dummy_outputs)(pytensor.shared(1.0), 1.0) |
| 419 | + rng = pytensor.shared(np.random.default_rng()) |
| 420 | + dummy_rng = rng.type() |
| 421 | + dummy_next_rng, dummy_x = NonSymbolicRV( |
| 422 | + [dummy_rng], pt.random.normal(rng=dummy_rng).owner.outputs |
| 423 | + )(rng) |
421 | 424 |
|
422 | 425 | # Check that there are no updates at first
|
423 | 426 | fn = compile_pymc(inputs=[], outputs=dummy_x)
|
424 |
| - assert fn() == fn() == 2.0 |
| 427 | + assert fn() == fn() |
425 | 428 |
|
426 | 429 | # And they are enabled once the Op is registered as a SymbolicRV
|
427 | 430 | SymbolicRandomVariable.register(NonSymbolicRV)
|
428 |
| - fn = compile_pymc(inputs=[], outputs=dummy_x) |
429 |
| - assert fn() == 2.0 |
430 |
| - assert fn() == 3.0 |
| 431 | + fn = compile_pymc(inputs=[], outputs=dummy_x, random_seed=431) |
| 432 | + assert fn() != fn() |
| 433 | + |
| 434 | + def test_compile_pymc_symbolic_rv_missing_update(self): |
| 435 | + """Test that error is raised if SymbolicRandomVariable Op does not |
| 436 | + provide rule for updating RNG""" |
| 437 | + |
| 438 | + class SymbolicRV(OpFromGraph): |
| 439 | + def update(self, node): |
| 440 | + # Update is provided for rng1 but not rng2 |
| 441 | + return {node.inputs[0]: node.outputs[0]} |
| 442 | + |
| 443 | + SymbolicRandomVariable.register(SymbolicRV) |
| 444 | + |
| 445 | + # No problems at first, as the one RNG is given the update rule |
| 446 | + rng1 = pytensor.shared(np.random.default_rng()) |
| 447 | + dummy_rng1 = rng1.type() |
| 448 | + dummy_next_rng1, dummy_x1 = SymbolicRV( |
| 449 | + [dummy_rng1], |
| 450 | + pt.random.normal(rng=dummy_rng1).owner.outputs, |
| 451 | + )(rng1) |
| 452 | + fn = compile_pymc(inputs=[], outputs=dummy_x1, random_seed=433) |
| 453 | + assert fn() != fn() |
| 454 | + |
| 455 | + # Now there's a problem as there is no update rule for rng2 |
| 456 | + rng2 = pytensor.shared(np.random.default_rng()) |
| 457 | + dummy_rng2 = rng2.type() |
| 458 | + dummy_next_rng1, dummy_x1, dummy_next_rng2, dummy_x2 = SymbolicRV( |
| 459 | + [dummy_rng1, dummy_rng2], |
| 460 | + [ |
| 461 | + *pt.random.normal(rng=dummy_rng1).owner.outputs, |
| 462 | + *pt.random.normal(rng=dummy_rng2).owner.outputs, |
| 463 | + ], |
| 464 | + )(rng1, rng2) |
| 465 | + with pytest.raises( |
| 466 | + ValueError, match="No update mapping found for RNG used in SymbolicRandomVariable" |
| 467 | + ): |
| 468 | + compile_pymc(inputs=[], outputs=[dummy_x1, dummy_x2]) |
431 | 469 |
|
432 | 470 | def test_random_seed(self):
|
433 | 471 | seedx = pytensor.shared(np.random.default_rng(1))
|
@@ -457,15 +495,62 @@ def test_random_seed(self):
|
457 | 495 | assert y3_eval == y2_eval
|
458 | 496 |
|
459 | 497 | def test_multiple_updates_same_variable(self):
|
460 |
| - rng = pytensor.shared(np.random.default_rng(), name="rng") |
461 |
| - x = pt.random.normal(rng=rng) |
462 |
| - y = pt.random.normal(rng=rng) |
463 |
| - |
464 |
| - assert compile_pymc([], [x]) |
465 |
| - assert compile_pymc([], [y]) |
466 |
| - msg = "Multiple update expressions found for the variable rng" |
467 |
| - with pytest.raises(ValueError, match=msg): |
468 |
| - compile_pymc([], [x, y]) |
| 498 | + # Raise if unexpected warning is issued |
| 499 | + with warnings.catch_warnings(): |
| 500 | + warnings.simplefilter("error") |
| 501 | + |
| 502 | + rng = pytensor.shared(np.random.default_rng(), name="rng") |
| 503 | + x = pt.random.normal(rng=rng) |
| 504 | + y = pt.random.normal(rng=rng) |
| 505 | + |
| 506 | + # No warnings if only one variable is used |
| 507 | + assert compile_pymc([], [x]) |
| 508 | + assert compile_pymc([], [y]) |
| 509 | + |
| 510 | + user_warn_msg = "RNG Variable rng has multiple clients" |
| 511 | + with pytest.warns(UserWarning, match=user_warn_msg): |
| 512 | + f = compile_pymc([], [x, y], random_seed=456) |
| 513 | + assert f() == f() |
| 514 | + |
| 515 | + # The user can provide an explicit update, but we will still issue a warning |
| 516 | + with pytest.warns(UserWarning, match=user_warn_msg): |
| 517 | + f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456) |
| 518 | + assert f() != f() |
| 519 | + |
| 520 | + # Same with default update |
| 521 | + rng.default_update = x.owner.outputs[0] |
| 522 | + with pytest.warns(UserWarning, match=user_warn_msg): |
| 523 | + f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456) |
| 524 | + assert f() != f() |
| 525 | + |
| 526 | + def test_nested_updates(self): |
| 527 | + rng = pytensor.shared(np.random.default_rng()) |
| 528 | + next_rng1, x = pt.random.normal(rng=rng).owner.outputs |
| 529 | + next_rng2, y = pt.random.normal(rng=next_rng1).owner.outputs |
| 530 | + next_rng3, z = pt.random.normal(rng=next_rng2).owner.outputs |
| 531 | + |
| 532 | + collect_default_updates([], [x, y, z]) == {rng: next_rng3} |
| 533 | + |
| 534 | + fn = compile_pymc([], [x, y, z], random_seed=514) |
| 535 | + assert not set(list(np.array(fn()))) & set(list(np.array(fn()))) |
| 536 | + |
| 537 | + # A local myopic rule (as PyMC used before, would not work properly) |
| 538 | + fn = pytensor.function([], [x, y, z], updates={rng: next_rng1}) |
| 539 | + assert set(list(np.array(fn()))) & set(list(np.array(fn()))) |
| 540 | + |
| 541 | + |
| 542 | +def test_collect_default_updates_must_be_shared(): |
| 543 | + shared_rng = pytensor.shared(np.random.default_rng()) |
| 544 | + nonshared_rng = shared_rng.type() |
| 545 | + |
| 546 | + next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs |
| 547 | + next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs |
| 548 | + |
| 549 | + res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y]) |
| 550 | + assert res == {shared_rng: next_rng_of_shared} |
| 551 | + |
| 552 | + res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False) |
| 553 | + assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared} |
469 | 554 |
|
470 | 555 |
|
471 | 556 | def test_replace_rng_nodes():
|
|
0 commit comments