Skip to content

Commit 8d047fc

Browse files
committed
use simplex transform for the test case
1 parent 8ea35db commit 8d047fc

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pymc_experimental/tests/test_prior_from_trace.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,17 @@ def test_parsing_arguments(case):
2828

2929
@pytest.fixture
3030
def coords():
31-
return dict(test=range(3))
31+
return dict(test=range(3), simplex=range(4))
3232

3333

3434
@pytest.fixture
3535
def param_cfg():
3636
return dict(
3737
a=pmx.utils.prior.arg_to_param_cfg("a"),
38-
b=pmx.utils.prior.arg_to_param_cfg(
39-
"b", dict(transform=transforms.sum_to_1, dims=("test",))
38+
b=pmx.utils.prior.arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))),
39+
c=pmx.utils.prior.arg_to_param_cfg(
40+
"c", dict(transform=transforms.simplex, dims=("simplex",))
4041
),
41-
c=pmx.utils.prior.arg_to_param_cfg("c", dict(transform=transforms.log, dims=("test",))),
4242
)
4343

4444

@@ -55,6 +55,7 @@ def idata(param_cfg, coords):
5555
var = cfg["transform"].backward(orig).eval()
5656
else:
5757
var = orig
58+
assert not np.isnan(var).any()
5859
vars[k] = var
5960
return az.convert_to_inference_data(vars, coords=coords)
6061

0 commit comments

Comments
 (0)