Skip to content

Commit 059ce45

Browse files
michaelosthegericardoV94
authored andcommitted
Return VarName type from get_var_name function
1 parent 41f0181 commit 059ce45

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

pymc/model/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from pymc.initial_point import make_initial_point_fn
6767
from pymc.logprob.basic import transformed_conditional_logp
6868
from pymc.logprob.utils import ParameterValueError
69-
from pymc.model_graph import VarName, model_to_graphviz
69+
from pymc.model_graph import model_to_graphviz
7070
from pymc.pytensorf import (
7171
PointFunc,
7272
SeedSequenceSeed,
@@ -80,6 +80,7 @@
8080
)
8181
from pymc.util import (
8282
UNSET,
83+
VarName,
8384
WithMemoization,
8485
_add_future_warning_tag,
8586
get_transformed_name,
@@ -2061,7 +2062,7 @@ def compile_fn(
20612062
)
20622063

20632064

2064-
def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
2065+
def Point(*args, filter_model_vars=False, **kwargs) -> Dict[VarName, np.ndarray]:
20652066
"""Build a point. Uses same args as dict() does.
20662067
Filters out variables not in the model. All keys are strings.
20672068

pymc/model_graph.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import warnings
1515

1616
from collections import defaultdict
17-
from typing import Dict, Iterable, List, NewType, Optional, Sequence, Set
17+
from typing import Dict, Iterable, List, Optional, Sequence, Set
1818

1919
from pytensor import function
2020
from pytensor.compile.sharedvalue import SharedVariable
@@ -28,10 +28,7 @@
2828

2929
import pymc as pm
3030

31-
from pymc.util import get_default_varnames, get_var_name
32-
33-
VarName = NewType("VarName", str)
34-
31+
from pymc.util import VarName, get_default_varnames, get_var_name
3532

3633
__all__ = (
3734
"ModelGraph",
@@ -113,7 +110,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va
113110
selected_ancestors.add(self.model.rvs_to_values[var])
114111

115112
# ordering of self._all_var_names is important
116-
return [VarName(get_var_name(var)) for var in selected_ancestors]
113+
return [get_var_name(var) for var in selected_ancestors]
117114

118115
def make_compute_graph(
119116
self, var_names: Optional[Iterable[VarName]] = None

pymc/util.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import functools
1616
import warnings
1717

18-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
18+
from typing import Any, Dict, List, NewType, Optional, Sequence, Tuple, Union, cast
1919

2020
import arviz
2121
import cloudpickle
@@ -29,6 +29,8 @@
2929

3030
from pymc.exceptions import BlockModelAccessError
3131

32+
VarName = NewType("VarName", str)
33+
3234

3335
class _UnsetType:
3436
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
@@ -207,9 +209,9 @@ def get_default_varnames(var_iterator, include_transformed):
207209
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]
208210

209211

210-
def get_var_name(var) -> str:
212+
def get_var_name(var) -> VarName:
211213
"""Get an appropriate, plain variable name for a variable."""
212-
return str(getattr(var, "name", var))
214+
return VarName(str(getattr(var, "name", var)))
213215

214216

215217
def get_transformed(z):

0 commit comments

Comments
 (0)