Skip to content

Commit ff7c2ff

Browse files
committed
Do not propagate dims to observed component of imputed variable
1 parent 95a0ef1 commit ff7c2ff

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

pymc/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1450,7 +1450,7 @@ def make_obs_var(
14501450
observed_rv_var.tag.observations = nonmissing_data
14511451

14521452
self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
1453-
self.add_random_variable(observed_rv_var, dims)
1453+
self.add_random_variable(observed_rv_var)
14541454
self.observed_RVs.append(observed_rv_var)
14551455

14561456
# Create deterministic that combines observed and missing

pymc/tests/test_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,18 @@ def test_missing_symmetric(self):
13651365
assert x_obs_vv in logp_inputs
13661366
assert x_unobs_vv in logp_inputs
13671367

1368+
def test_dims(self):
1369+
"""Test that we don't propagate dims to the subcomponents of a partially
1370+
observed RV
1371+
1372+
See https://github.com/pymc-devs/pymc/issues/6177
1373+
"""
1374+
data = np.array([np.nan] * 3 + [0] * 7)
1375+
with pm.Model(coords={"observed": range(10)}) as model:
1376+
with pytest.warns(ImputationWarning):
1377+
x = pm.Normal("x", observed=data, dims=("observed",))
1378+
assert model.RV_dims == {"x": ("observed",)}
1379+
13681380

13691381
class TestShared(SeededTest):
13701382
def test_deterministic(self):

0 commit comments

Comments
 (0)