|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| -import warnings |
15 | 14 |
|
16 | 15 | import numpy as np
|
17 | 16 | import numpy.ma as ma
|
@@ -486,34 +485,45 @@ def test_random_seed(self):
|
486 | 485 | assert x3_eval == x2_eval
|
487 | 486 | assert y3_eval == y2_eval
|
488 | 487 |
|
| 488 | + @pytest.mark.filterwarnings("error") # This is part of the test |
489 | 489 | def test_multiple_updates_same_variable(self):
|
490 |
| - # Raise if unexpected warning is issued |
491 |
| - with warnings.catch_warnings(): |
492 |
| - warnings.simplefilter("error") |
493 |
| - |
494 |
| - rng = pytensor.shared(np.random.default_rng(), name="rng") |
495 |
| - x = pt.random.normal(rng=rng) |
496 |
| - y = pt.random.normal(rng=rng) |
497 |
| - |
498 |
| - # No warnings if only one variable is used |
499 |
| - assert compile_pymc([], [x]) |
500 |
| - assert compile_pymc([], [y]) |
501 |
| - |
502 |
| - user_warn_msg = "RNG Variable rng has multiple clients" |
503 |
| - with pytest.warns(UserWarning, match=user_warn_msg): |
504 |
| - f = compile_pymc([], [x, y], random_seed=456) |
505 |
| - assert f() == f() |
506 |
| - |
507 |
| - # The user can provide an explicit update, but we will still issue a warning |
508 |
| - with pytest.warns(UserWarning, match=user_warn_msg): |
509 |
| - f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456) |
510 |
| - assert f() != f() |
511 |
| - |
512 |
| - # Same with default update |
513 |
| - rng.default_update = x.owner.outputs[0] |
514 |
| - with pytest.warns(UserWarning, match=user_warn_msg): |
515 |
| - f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456) |
516 |
| - assert f() != f() |
| 490 | + rng = pytensor.shared(np.random.default_rng(), name="rng") |
| 491 | + x = pt.random.normal(0, rng=rng) |
| 492 | + y = pt.random.normal(1, rng=rng) |
| 493 | + |
| 494 | + # No warnings if only one variable is used |
| 495 | + assert compile_pymc([], [x]) |
| 496 | + assert compile_pymc([], [y]) |
| 497 | + |
| 498 | + user_warn_msg = "RNG Variable rng has multiple distinct clients" |
| 499 | + with pytest.warns(UserWarning, match=user_warn_msg): |
| 500 | + f = compile_pymc([], [x, y], random_seed=456) |
| 501 | + assert f() == f() |
| 502 | + |
| 503 | + # The user can provide an explicit update, but we will still issue a warning |
| 504 | + with pytest.warns(UserWarning, match=user_warn_msg): |
| 505 | + f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456) |
| 506 | + assert f() != f() |
| 507 | + |
| 508 | + # Same with default update |
| 509 | + rng.default_update = x.owner.outputs[0] |
| 510 | + with pytest.warns(UserWarning, match=user_warn_msg): |
| 511 | + f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456) |
| 512 | + assert f() != f() |
| 513 | + |
| 514 | + @pytest.mark.filterwarnings("error") # This is part of the test |
| 515 | + def test_duplicated_client_nodes(self): |
| 516 | + """Test compile_pymc can handle duplicated (mergeable) RV updates.""" |
| 517 | + rng = pytensor.shared(np.random.default_rng(1)) |
| 518 | + x = pt.random.normal(rng=rng) |
| 519 | + y = x.owner.clone().default_output() |
| 520 | + |
| 521 | + fn = compile_pymc([], [x, y], random_seed=1) |
| 522 | + res_x1, res_y1 = fn() |
| 523 | + assert res_x1 == res_y1 |
| 524 | + res_x2, res_y2 = fn() |
| 525 | + assert res_x2 == res_y2 |
| 526 | + assert res_x1 != res_x2 |
517 | 527 |
|
518 | 528 | def test_nested_updates(self):
|
519 | 529 | rng = pytensor.shared(np.random.default_rng())
|
|
0 commit comments