Skip to content

Commit c22ea96

Browse files
authored
Scope separator for netcdf (#5663)
* add a failing test * fix a failing test with changes in model.py * better test * change docs in model.py * fix tests * use :: saparator * fix typo
1 parent ab93967 commit c22ea96

File tree

4 files changed

+54
-29
lines changed

4 files changed

+54
-29
lines changed

pymc/model.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def __init__(self, mean=0, sigma=1, name=''):
474474
475475
# 3) you can create variables with Var method
476476
self.Var('v1', Normal.dist(mu=mean, sigma=sd))
477-
# this will create variable named like '{prefix/}v1'
477+
# this will create variable named like '{prefix::}v1'
478478
# and assign attribute 'v1' to instance created
479479
# variable can be accessed with self.v1 or self['v1']
480480
@@ -516,7 +516,7 @@ def __init__(self, mean=0, sigma=1, name=''):
516516
CustomModel(mean=1, name='first')
517517
CustomModel(mean=2, name='second')
518518
519-
# variables inside both scopes will be named like `first/*`, `second/*`
519+
# variables inside both scopes will be named like `first::*`, `second::*`
520520
521521
"""
522522

@@ -538,14 +538,20 @@ def __new__(cls, *args, **kwargs):
538538
instance._aesara_config = kwargs.get("aesara_config", {})
539539
return instance
540540

541+
@staticmethod
542+
def _validate_name(name):
543+
if name.endswith(":"):
544+
raise KeyError("name should not end with `:`")
545+
return name
546+
541547
def __init__(
542548
self,
543549
name="",
544550
coords=None,
545551
check_bounds=True,
546552
rng_seeder: Optional[Union[int, np.random.RandomState]] = None,
547553
):
548-
self.name = name
554+
self.name = self._validate_name(name)
549555
self.check_bounds = check_bounds
550556

551557
if rng_seeder is None:
@@ -1462,25 +1468,27 @@ def prefix(self) -> str:
14621468
if self.isroot or not self.parent.prefix:
14631469
name = self.name
14641470
else:
1465-
name = f"{self.parent.prefix}/{self.name}"
1466-
return name.strip("/")
1471+
name = f"{self.parent.prefix}::{self.name}"
1472+
return name
14671473

14681474
def name_for(self, name):
14691475
"""Checks if name has prefix and adds if needed"""
1476+
name = self._validate_name(name)
14701477
if self.prefix:
14711478
if not name.startswith(self.prefix):
1472-
return f"{self.prefix}/{name}"
1479+
return f"{self.prefix}::{name}"
14731480
else:
14741481
return name
14751482
else:
14761483
return name
14771484

14781485
def name_of(self, name):
14791486
"""Checks if name has prefix and deletes if needed"""
1487+
name = self._validate_name(name)
14801488
if not self.prefix or not name:
14811489
return name
1482-
elif name.startswith(self.prefix + "/"):
1483-
return name[len(self.prefix) + 1 :]
1490+
elif name.startswith(self.prefix + "::"):
1491+
return name[len(self.prefix) + 2 :]
14841492
else:
14851493
return name
14861494

pymc/tests/test_data_container.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,8 @@ def test_data_naming():
399399
with pm.Model("named_model") as model:
400400
x = pm.ConstantData("x", [1.0, 2.0, 3.0])
401401
y = pm.Normal("y")
402-
assert y.name == "named_model/y"
403-
assert x.name == "named_model/x"
402+
assert y.name == "named_model::y"
403+
assert x.name == "named_model::x"
404404

405405

406406
def test_get_data():

pymc/tests/test_model.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import aesara
1919
import aesara.sparse as sparse
2020
import aesara.tensor as at
21+
import arviz as az
2122
import cloudpickle
2223
import numpy as np
2324
import numpy.ma as ma
@@ -94,20 +95,20 @@ def test_context_passes_vars_to_parent_model(self):
9495
usermodel2.register_rv(pm.Normal.dist(), "v3")
9596
pm.Normal("v4")
9697
# this variable is created in parent model too
97-
assert "another/v2" in model.named_vars
98-
assert "another/v3" in model.named_vars
99-
assert "another/v3" in usermodel2.named_vars
100-
assert "another/v4" in model.named_vars
101-
assert "another/v4" in usermodel2.named_vars
98+
assert "another::v2" in model.named_vars
99+
assert "another::v3" in model.named_vars
100+
assert "another::v3" in usermodel2.named_vars
101+
assert "another::v4" in model.named_vars
102+
assert "another::v4" in usermodel2.named_vars
102103
assert hasattr(usermodel2, "v3")
103104
assert hasattr(usermodel2, "v2")
104105
assert hasattr(usermodel2, "v4")
105106
# When you create a class based model you should follow some rules
106107
with model:
107108
m = NewModel("one_more")
108-
assert m.d is model["one_more/d"]
109-
assert m["d"] is model["one_more/d"]
110-
assert m["one_more/d"] is model["one_more/d"]
109+
assert m.d is model["one_more::d"]
110+
assert m["d"] is model["one_more::d"]
111+
assert m["one_more::d"] is model["one_more::d"]
111112

112113

113114
class TestNested:
@@ -123,8 +124,8 @@ def test_nest_context_works(self):
123124
def test_named_context(self):
124125
with pm.Model() as m:
125126
NewModel(name="new")
126-
assert "new/v1" in m.named_vars
127-
assert "new/v2" in m.named_vars
127+
assert "new::v1" in m.named_vars
128+
assert "new::v2" in m.named_vars
128129

129130
def test_docstring_example1(self):
130131
usage1 = DocstringModel()
@@ -137,10 +138,10 @@ def test_docstring_example1(self):
137138
def test_docstring_example2(self):
138139
with pm.Model() as model:
139140
DocstringModel(name="prefix")
140-
assert "prefix/v1" in model.named_vars
141-
assert "prefix/v2" in model.named_vars
142-
assert "prefix/v3" in model.named_vars
143-
assert "prefix/v3_sq" in model.named_vars
141+
assert "prefix::v1" in model.named_vars
142+
assert "prefix::v2" in model.named_vars
143+
assert "prefix::v3" in model.named_vars
144+
assert "prefix::v3_sq" in model.named_vars
144145
assert len(model.potentials), 1
145146

146147
def test_duplicates_detection(self):
@@ -160,14 +161,30 @@ def test_nested_named_model_repeated(self):
160161
b = pm.Normal("var")
161162
with pm.Model("sub"):
162163
b = pm.Normal("var")
163-
assert {"sub/var", "sub/sub/var"} == set(model.named_vars.keys())
164+
assert {"sub::var", "sub::sub::var"} == set(model.named_vars.keys())
164165

165166
def test_nested_named_model(self):
166167
with pm.Model("sub1") as model:
167168
b = pm.Normal("var")
168169
with pm.Model("sub2"):
169170
b = pm.Normal("var")
170-
assert {"sub1/var", "sub1/sub2/var"} == set(model.named_vars.keys())
171+
assert {"sub1::var", "sub1::sub2::var"} == set(model.named_vars.keys())
172+
173+
def test_nested_model_to_netcdf(self, tmp_path):
174+
with pm.Model("scope") as model:
175+
b = pm.Normal("var")
176+
trace = pm.sample(100, tune=0)
177+
az.to_netcdf(trace, tmp_path / "trace.nc")
178+
trace1 = az.from_netcdf(tmp_path / "trace.nc")
179+
assert "scope::var" in trace1.posterior
180+
181+
def test_bad_name(self):
182+
with pm.Model() as model:
183+
with pytest.raises(KeyError):
184+
b = pm.Normal("var::")
185+
with pytest.raises(KeyError):
186+
with pm.Model("scope::") as model:
187+
b = pm.Normal("v")
171188

172189

173190
class TestObserved:

pymc/tests/test_smc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,9 @@ def test_named_model(self):
534534
s = pm.Simulator("s", self.normal_sim, a, b, observed=self.data)
535535

536536
trace = pm.sample_smc(draws=10, chains=2, return_inferencedata=False)
537-
assert f"{name}/a" in trace.varnames
538-
assert f"{name}/b" in trace.varnames
539-
assert f"{name}/b_log__" in trace.varnames
537+
assert f"{name}::a" in trace.varnames
538+
assert f"{name}::b" in trace.varnames
539+
assert f"{name}::b_log__" in trace.varnames
540540

541541

542542
class TestMHKernel(SeededTest):

0 commit comments

Comments
 (0)