Skip to content

Commit 5e82394

Browse files
bomtalltwiecki
bomtall
authored andcommitted
move pytest imports to be local
1 parent 33b24cc commit 5e82394

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

pymc/testing.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import numpy as np
2222
import pytensor
2323
import pytensor.tensor as pt
24-
import pytest
2524

2625
from numpy import random as nr
2726
from numpy import testing as npt
@@ -342,6 +341,8 @@ def check_logp(
342341
scipy_args : Dictionary with extra arguments needed to call scipy logp method
343342
Usually the same as extra_args
344343
"""
344+
import pytest
345+
345346
if decimal is None:
346347
decimal = select_by_precision(float64=6, float32=3)
347348

@@ -388,6 +389,7 @@ def scipy_logp_with_scipy_args(**args):
388389
point[invalid_param] = np.asarray(
389390
invalid_edge, dtype=paramdomains[invalid_param].dtype
390391
)
392+
391393
with pytest.raises(ParameterValueError):
392394
pymc_logp(**point)
393395
pytest.fail(f"test_params={point}")
@@ -459,6 +461,8 @@ def check_logcdf(
459461
returns -inf for invalid parameter values outside the supported domain edge
460462
461463
"""
464+
import pytest
465+
462466
if decimal is None:
463467
decimal = select_by_precision(float64=6, float32=3)
464468

@@ -498,6 +502,7 @@ def check_logcdf(
498502

499503
point = valid_params.copy()
500504
point[invalid_param] = invalid_edge
505+
501506
with pytest.raises(ParameterValueError):
502507
pymc_logcdf(**point)
503508
pytest.fail(f"test_params={point}")
@@ -563,6 +568,8 @@ def check_icdf(
563568
returns nan for invalid parameter values outside the supported domain edge
564569
565570
"""
571+
import pytest
572+
566573
if decimal is None:
567574
decimal = select_by_precision(float64=6, float32=3)
568575

@@ -601,6 +608,7 @@ def check_icdf(
601608

602609
point = valid_params.copy()
603610
point[invalid_param] = invalid_edge
611+
604612
with pytest.raises(ParameterValueError):
605613
pymc_icdf(**point)
606614
pytest.fail(f"test_params={point}")
@@ -860,6 +868,8 @@ class BaseTestDistributionRandom:
860868
random_state = None
861869

862870
def test_distribution(self):
871+
import pytest
872+
863873
self.validate_tests_list()
864874
if self.pymc_dist == pm.Wishart:
865875
with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):

0 commit comments

Comments
 (0)