|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import itertools
|
| 16 | +import warnings |
16 | 17 |
|
17 | 18 | from typing import Union
|
18 | 19 |
|
|
24 | 25 |
|
25 | 26 | from pymc.model import Model
|
26 | 27 |
|
27 |
| -# from pymc.distributions.discrete import UnmeasurableConstantRV |
28 |
| - |
29 | 28 | __all__ = [
|
30 | 29 | "str_for_dist",
|
31 | 30 | "str_for_symbolic_dist",
|
@@ -65,7 +64,15 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params)
|
65 | 64 | if var.name:
|
66 | 65 | return var.name
|
67 | 66 | if isinstance(var, TensorConstant):
|
68 |
| - return _str_for_constant(var, formatting, print_vector=True) |
| 67 | + if len(var.data.shape) > 1: |
| 68 | + raise NotImplementedError |
| 69 | + try: |
| 70 | + if var.data.shape[0] > 1: |
| 71 | + # weights in mixture model |
| 72 | + return "[" + ",".join([str(weight) for weight in var.data]) + "]" |
| 73 | + except IndexError: |
| 74 | + # just a scalar |
| 75 | + return _str_for_constant(var, formatting) |
69 | 76 | if isinstance(var.owner.op, MakeVector):
|
70 | 77 | # psi in some zero inflated distribution
|
71 | 78 | return dispatch_comp_str(var.owner.inputs[1])
|
@@ -102,11 +109,17 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params)
|
102 | 109 | dist_parameters = rv.owner.inputs[1:]
|
103 | 110 |
|
104 | 111 | elif "Censored" in rv.owner.op._print_name[0]:
|
105 |
| - dist_parameters = rv.owner.inputs[2:] |
| 112 | + dist_parameters = rv.owner.inputs |
106 | 113 | else:
|
107 | 114 | # Latex representation for the SymbolicDistribution has not been implemented.
|
108 | 115 | # Hoping for the best here!
|
109 | 116 | dist_parameters = rv.owner.inputs[2:]
|
| 117 | + warnings.warn( |
| 118 | + "Latex representation for this SymbolicDistribution has not been implemented. " |
| 119 | + "Please have a look at str_for_symbolic_dist in pymc/printing.py", |
| 120 | + FutureWarning, |
| 121 | + stacklevel=2, |
| 122 | + ) |
110 | 123 |
|
111 | 124 | dist_args = [
|
112 | 125 | dispatch_comp_str(dist_para, formatting=formatting, include_params=include_params)
|
@@ -222,13 +235,11 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
|
222 | 235 | return _str
|
223 | 236 |
|
224 | 237 |
|
225 |
| -def _str_for_constant(var: TensorConstant, formatting: str, print_vector: bool = False) -> str: |
| 238 | +def _str_for_constant(var: TensorConstant, formatting: str) -> str: |
226 | 239 | if len(var.data.shape) == 0:
|
227 | 240 | return f"{var.data:.3g}"
|
228 | 241 | elif len(var.data.shape) == 1 and var.data.shape[0] == 1:
|
229 | 242 | return f"{var.data[0]:.3g}"
|
230 |
| - elif len(var.data.shape) == 1 and print_vector: |
231 |
| - return "[" + ", ".join([f"{const:.3g}" for const in var.data]) + "]" |
232 | 243 | elif "latex" in formatting:
|
233 | 244 | return r"\text{<constant>}"
|
234 | 245 | else:
|
|
0 commit comments