Skip to content

Commit b5cf8ad

Browse files
authored
fix unitary case in prior from trace (#67)
* fix unitary case * fix pre commit
1 parent 6494930 commit b5cf8ad

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ pytestdebug.log
4444

4545
# Codespaces
4646
pythonenv*
47-
env/
47+
env/

pymc_experimental/tests/test_prior_from_trace.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,21 @@ def coords():
3333
return dict(test=range(3), simplex=range(4))
3434

3535

36-
@pytest.fixture
37-
def user_param_cfg():
38-
return ("t",), dict(
39-
a="d",
40-
b=dict(transform=transforms.log, dims=("test",)),
41-
c=dict(transform=transforms.simplex, dims=("simplex",)),
42-
)
36+
@pytest.fixture(
37+
params=[
38+
[
39+
("t",),
40+
dict(
41+
a="d",
42+
b=dict(transform=transforms.log, dims=("test",)),
43+
c=dict(transform=transforms.simplex, dims=("simplex",)),
44+
),
45+
],
46+
[("t",), dict()],
47+
]
48+
)
49+
def user_param_cfg(request):
50+
return request.param
4351

4452

4553
@pytest.fixture

pymc_experimental/utils/prior.py

+1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo:
9494
def _mean_chol(flat_array: np.ndarray):
9595
mean = flat_array.mean(0)
9696
cov = np.cov(flat_array, rowvar=False)
97+
cov = np.atleast_2d(cov)
9798
chol = np.linalg.cholesky(cov)
9899
return mean, chol
99100

0 commit comments

Comments
 (0)