Skip to content

Commit 60c3407

Browse files
working on weight representation for mixtures
1 parent c35c676 commit 60c3407

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

pymc/printing.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import itertools
16+
import warnings
1617

1718
from typing import Union
1819

@@ -24,8 +25,6 @@
2425

2526
from pymc.model import Model
2627

27-
# from pymc.distributions.discrete import UnmeasurableConstantRV
28-
2928
__all__ = [
3029
"str_for_dist",
3130
"str_for_symbolic_dist",
@@ -65,7 +64,15 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params)
6564
if var.name:
6665
return var.name
6766
if isinstance(var, TensorConstant):
68-
return _str_for_constant(var, formatting, print_vector=True)
67+
if len(var.data.shape) > 1:
68+
raise NotImplementedError
69+
try:
70+
if var.data.shape[0] > 1:
71+
# weights in mixture model
72+
return "[" + ",".join([str(weight) for weight in var.data]) + "]"
73+
except IndexError:
74+
# just a scalar
75+
return _str_for_constant(var, formatting)
6976
if isinstance(var.owner.op, MakeVector):
7077
# psi in some zero inflated distribution
7178
return dispatch_comp_str(var.owner.inputs[1])
@@ -102,11 +109,17 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params)
102109
dist_parameters = rv.owner.inputs[1:]
103110

104111
elif "Censored" in rv.owner.op._print_name[0]:
105-
dist_parameters = rv.owner.inputs[2:]
112+
dist_parameters = rv.owner.inputs
106113
else:
107114
# Latex representation for the SymbolicDistribution has not been implemented.
108115
# Hoping for the best here!
109116
dist_parameters = rv.owner.inputs[2:]
117+
warnings.warn(
118+
"Latex representation for this SymbolicDistribution has not been implemented. "
119+
"Please have a look at str_for_symbolic_dist in pymc/printing.py",
120+
FutureWarning,
121+
stacklevel=2,
122+
)
110123

111124
dist_args = [
112125
dispatch_comp_str(dist_para, formatting=formatting, include_params=include_params)
@@ -222,13 +235,11 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
222235
return _str
223236

224237

225-
def _str_for_constant(var: TensorConstant, formatting: str, print_vector: bool = False) -> str:
238+
def _str_for_constant(var: TensorConstant, formatting: str) -> str:
226239
if len(var.data.shape) == 0:
227240
return f"{var.data:.3g}"
228241
elif len(var.data.shape) == 1 and var.data.shape[0] == 1:
229242
return f"{var.data[0]:.3g}"
230-
elif len(var.data.shape) == 1 and print_vector:
231-
return "[" + ", ".join([f"{const:.3g}" for const in var.data]) + "]"
232243
elif "latex" in formatting:
233244
return r"\text{<constant>}"
234245
else:

0 commit comments

Comments
 (0)