|
21 | 21 | import numpy as np
|
22 | 22 | import pytensor
|
23 | 23 | import pytensor.tensor as pt
|
24 |
| -import pytest |
25 | 24 |
|
26 | 25 | from numpy import random as nr
|
27 | 26 | from numpy import testing as npt
|
@@ -342,6 +341,8 @@ def check_logp(
|
342 | 341 | scipy_args : Dictionary with extra arguments needed to call scipy logp method
|
343 | 342 | Usually the same as extra_args
|
344 | 343 | """
|
| 344 | + import pytest |
| 345 | + |
345 | 346 | if decimal is None:
|
346 | 347 | decimal = select_by_precision(float64=6, float32=3)
|
347 | 348 |
|
@@ -388,6 +389,7 @@ def scipy_logp_with_scipy_args(**args):
|
388 | 389 | point[invalid_param] = np.asarray(
|
389 | 390 | invalid_edge, dtype=paramdomains[invalid_param].dtype
|
390 | 391 | )
|
| 392 | + |
391 | 393 | with pytest.raises(ParameterValueError):
|
392 | 394 | pymc_logp(**point)
|
393 | 395 | pytest.fail(f"test_params={point}")
|
@@ -459,6 +461,8 @@ def check_logcdf(
|
459 | 461 | returns -inf for invalid parameter values outside the supported domain edge
|
460 | 462 |
|
461 | 463 | """
|
| 464 | + import pytest |
| 465 | + |
462 | 466 | if decimal is None:
|
463 | 467 | decimal = select_by_precision(float64=6, float32=3)
|
464 | 468 |
|
@@ -498,6 +502,7 @@ def check_logcdf(
|
498 | 502 |
|
499 | 503 | point = valid_params.copy()
|
500 | 504 | point[invalid_param] = invalid_edge
|
| 505 | + |
501 | 506 | with pytest.raises(ParameterValueError):
|
502 | 507 | pymc_logcdf(**point)
|
503 | 508 | pytest.fail(f"test_params={point}")
|
@@ -563,6 +568,8 @@ def check_icdf(
|
563 | 568 | returns nan for invalid parameter values outside the supported domain edge
|
564 | 569 |
|
565 | 570 | """
|
| 571 | + import pytest |
| 572 | + |
566 | 573 | if decimal is None:
|
567 | 574 | decimal = select_by_precision(float64=6, float32=3)
|
568 | 575 |
|
@@ -601,6 +608,7 @@ def check_icdf(
|
601 | 608 |
|
602 | 609 | point = valid_params.copy()
|
603 | 610 | point[invalid_param] = invalid_edge
|
| 611 | + |
604 | 612 | with pytest.raises(ParameterValueError):
|
605 | 613 | pymc_icdf(**point)
|
606 | 614 | pytest.fail(f"test_params={point}")
|
@@ -860,6 +868,8 @@ class BaseTestDistributionRandom:
|
860 | 868 | random_state = None
|
861 | 869 |
|
862 | 870 | def test_distribution(self):
|
| 871 | + import pytest |
| 872 | + |
863 | 873 | self.validate_tests_list()
|
864 | 874 | if self.pymc_dist == pm.Wishart:
|
865 | 875 | with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
|
|
0 commit comments