Skip to content

Commit 352dc7d

Browse files
Fix and add tests for Theano printing
1 parent 3a88917 commit 352dc7d

File tree

4 files changed

+236
-28
lines changed

4 files changed

+236
-28
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
-e ./
2+
sympy>=1.3
23
coveralls
34
pydocstyle>=3.0.0
45
pytest>=5.0.0

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def get_long_description():
4545
"etuples>=0.3.1",
4646
"cons>=0.4.0",
4747
"toolz>=0.9.0",
48-
"sympy>=1.3",
4948
"cachetools",
5049
"pymc3>=3.6",
5150
"pymc4 @ git+https://github.com/pymc-devs/pymc4.git@master#egg=pymc4-0.0.1",

symbolic_pymc/theano/printing.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,19 @@
1010

1111
from theano import gof
1212

13-
from sympy import Array as SympyArray
14-
from sympy.printing import latex as sympy_latex
13+
try:
14+
from sympy import Array as SympyArray
15+
from sympy.printing import latex as sympy_latex
16+
17+
def latex_print_array(data): # pragma: no cover
18+
return sympy_latex(SympyArray(data))
19+
20+
21+
except ImportError: # pragma: no cover
22+
23+
def latex_print_array(data):
24+
return data
25+
1526

1627
from .opt import FunctionGraph
1728
from .ops import RandomVariable
@@ -60,16 +71,16 @@ def process_param(self, idx, sform, pstate):
6071
The printer state.
6172
6273
"""
63-
return sform
74+
return sform # pragma: no cover
6475

6576
def process(self, output, pstate):
6677
if output in pstate.memo:
6778
return pstate.memo[output]
6879

6980
pprinter = pstate.pprinter
70-
node = output.owner
81+
node = getattr(output, "owner", None)
7182

72-
if node is None or not isinstance(node.op, RandomVariable):
83+
if node is None or not isinstance(node.op, RandomVariable): # pragma: no cover
7384
raise TypeError(
7485
"Function %s cannot represent a variable that is "
7586
"not the result of a RandomVariable operation" % self.name
@@ -78,7 +89,7 @@ def process(self, output, pstate):
7889
op_name = self.name or getattr(node.op, "print_name", None)
7990
op_name = op_name or getattr(node.op, "name", None)
8091

81-
if op_name is None:
92+
if op_name is None: # pragma: no cover
8293
raise ValueError(f"Could not find a name for {node.op}")
8394

8495
# Allow `Op`s to specify their ascii and LaTeX formats (in a tuple/list
@@ -144,7 +155,7 @@ def process(self, output, pstate):
144155

145156
class GenericSubtensorPrinter(object):
146157
def process(self, r, pstate):
147-
if r.owner is None:
158+
if getattr(r, "owner", None) is None: # pragma: no cover
148159
raise TypeError("Can only print Subtensor.")
149160

150161
output_latex = getattr(pstate, "latex", False)
@@ -161,13 +172,13 @@ def process(self, r, pstate):
161172
if isinstance(entry, slice):
162173
s_parts = [""] * 2
163174
if entry.start is not None:
164-
s_parts[0] = entry.start
175+
s_parts[0] = pstate.pprinter.process(inputs.pop())
165176

166177
if entry.stop is not None:
167-
s_parts[1] = entry.stop
178+
s_parts[1] = pstate.pprinter.process(inputs.pop())
168179

169180
if entry.step is not None:
170-
s_parts.append(entry.stop)
181+
s_parts.append(pstate.pprinter.process(inputs.pop()))
171182

172183
sidxs.append(":".join(s_parts))
173184
else:
@@ -215,16 +226,22 @@ def process(cls, output, pstate):
215226
using_latex = getattr(pstate, "latex", False)
216227
# Crude--but effective--means of stopping print-outs for large
217228
# arrays.
218-
constant = isinstance(output, tt.TensorConstant)
229+
constant = isinstance(output, (tt.TensorConstant, theano.scalar.basic.ScalarConstant))
219230
too_large = constant and (output.data.size > cls.max_line_width * cls.max_line_height)
220231

221232
if constant and not too_large:
222233
# Print constants that aren't too large
223234
if using_latex and output.ndim > 0:
224-
out_name = sympy_latex(SympyArray(output.data))
235+
out_name = latex_print_array(output.data)
225236
else:
226237
out_name = str(output.data)
227-
elif isinstance(output, tt.TensorVariable) or constant:
238+
elif (
239+
isinstance(
240+
output,
241+
(tt.TensorVariable, theano.scalar.basic.Scalar, theano.scalar.basic.ScalarVariable),
242+
)
243+
or constant
244+
):
228245
# Process name and shape
229246

230247
# Attempt to get the original variable, in case this is a cloned
@@ -238,7 +255,7 @@ def process(cls, output, pstate):
238255

239256
shape_strings = pstate.preamble_dict.setdefault("shape_strings", OrderedDict())
240257
shape_strings[output] = shape_info
241-
else:
258+
else: # pragma: no cover
242259
raise TypeError(f"Type {type(output)} not handled by variable printer")
243260

244261
pstate.memo[output] = out_name
@@ -268,7 +285,7 @@ def process_variable_name(cls, output, pstate):
268285
_ = [available_names.pop(v.name, None) for v in fgraph.variables]
269286
setattr(pstate, "available_names", available_names)
270287

271-
if output.name:
288+
if getattr(output, "name", None):
272289
# Observed an existing name; remove it.
273290
out_name = output.name
274291
available_names.pop(out_name, None)
@@ -524,11 +541,18 @@ def __call__(self, *args, latex_env="equation", latex_label=None):
524541

525542
# The order here is important!
526543
tt_pprint.printers.insert(
527-
0, (lambda pstate, r: isinstance(r, tt.Variable), VariableWithShapePrinter)
544+
0,
545+
(
546+
lambda pstate, r: isinstance(r, (theano.scalar.basic.Scalar, tt.Variable)),
547+
VariableWithShapePrinter,
548+
),
528549
)
529550
tt_pprint.printers.insert(
530551
0,
531-
(lambda pstate, r: r.owner and isinstance(r.owner.op, RandomVariable), RandomVariablePrinter()),
552+
(
553+
lambda pstate, r: getattr(r, "owner", None) and isinstance(r.owner.op, RandomVariable),
554+
RandomVariablePrinter(),
555+
),
532556
)
533557

534558

@@ -538,9 +562,9 @@ def process(self, output, pstate):
538562
return pstate.memo[output]
539563

540564
pprinter = pstate.pprinter
541-
node = output.owner
565+
node = getattr(output, "owner", None)
542566

543-
if node is None or not isinstance(node.op, Observed):
567+
if node is None or not isinstance(node.op, Observed): # pragma: no cover
544568
raise TypeError(f"Node Op is not of type `Observed`: {node.op}")
545569

546570
val = node.inputs[0]

tests/theano/test_printing.py

Lines changed: 192 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,211 @@
1+
import textwrap
2+
13
import theano.tensor as tt
24

3-
from symbolic_pymc.theano.random_variables import NormalRV
4-
from symbolic_pymc.theano.printing import tt_pprint
5+
from symbolic_pymc.theano.random_variables import NormalRV, observed
6+
from symbolic_pymc.theano.printing import tt_pprint, tt_tprint
57

68

79
def test_notex_print():
810

911
tt_normalrv_noname_expr = tt.scalar("b") * NormalRV(tt.scalar("\\mu"), tt.scalar("\\sigma"))
10-
expected = "b in R, \\mu in R, \\sigma in R\na ~ N(\\mu, \\sigma**2) in R\n(b * a)"
11-
assert tt_pprint(tt_normalrv_noname_expr) == expected
12+
expected = textwrap.dedent(
13+
r"""
14+
b in R, \mu in R, \sigma in R
15+
a ~ N(\mu, \sigma**2) in R
16+
(b * a)
17+
"""
18+
)
19+
assert tt_pprint(tt_normalrv_noname_expr) == expected.strip()
1220

1321
# Make sure the constant shape is shown in values and not symbols.
1422
tt_normalrv_name_expr = tt.scalar("b") * NormalRV(
1523
tt.scalar("\\mu"), tt.scalar("\\sigma"), size=[2, 1], name="X"
1624
)
17-
expected = "b in R, \\mu in R, \\sigma in R\nX ~ N(\\mu, \\sigma**2) in R**(2 x 1)\n(b * X)"
18-
assert tt_pprint(tt_normalrv_name_expr) == expected
25+
expected = textwrap.dedent(
26+
r"""
27+
b in R, \mu in R, \sigma in R
28+
X ~ N(\mu, \sigma**2) in R**(2 x 1)
29+
(b * X)
30+
"""
31+
)
32+
assert tt_pprint(tt_normalrv_name_expr) == expected.strip()
33+
34+
tt_2_normalrv_noname_expr = tt.matrix("M") * NormalRV(
35+
tt.scalar("\\mu_2"), tt.scalar("\\sigma_2")
36+
)
37+
tt_2_normalrv_noname_expr *= tt.scalar("b") * NormalRV(
38+
tt_2_normalrv_noname_expr, tt.scalar("\\sigma")
39+
) + tt.scalar("c")
40+
expected = textwrap.dedent(
41+
r"""
42+
M in R**(N^M_0 x N^M_1), \mu_2 in R, \sigma_2 in R
43+
b in R, \sigma in R, c in R
44+
a ~ N(\mu_2, \sigma_2**2) in R, d ~ N((M * a), \sigma**2) in R**(N^d_0 x N^d_1)
45+
((M * a) * ((b * d) + c))
46+
"""
47+
)
48+
assert tt_pprint(tt_2_normalrv_noname_expr) == expected.strip()
49+
50+
expected = textwrap.dedent(
51+
r"""
52+
b in Z, c in Z, M in R**(N^M_0 x N^M_1)
53+
M[b, c]
54+
"""
55+
)
56+
# TODO: "c" should be "1".
57+
assert (
58+
tt_pprint(tt.matrix("M")[tt.iscalar("a"), tt.constant(1, dtype="int")]) == expected.strip()
59+
)
60+
61+
expected = textwrap.dedent(
62+
r"""
63+
M in R**(N^M_0 x N^M_1)
64+
M[1]
65+
"""
66+
)
67+
assert tt_pprint(tt.matrix("M")[1]) == expected.strip()
68+
69+
expected = textwrap.dedent(
70+
r"""
71+
M in N**(N^M_0)
72+
M[2:4:0]
73+
"""
74+
)
75+
assert tt_pprint(tt.vector("M", dtype="uint32")[0:4:2]) == expected.strip()
76+
77+
norm_rv = NormalRV(tt.scalar("\\mu"), tt.scalar("\\sigma"))
78+
rv_obs = observed(tt.constant(1.0, dtype=norm_rv.dtype), norm_rv)
79+
80+
expected = textwrap.dedent(
81+
r"""
82+
\mu in R, \sigma in R
83+
a ~ N(\mu, \sigma**2) in R
84+
a = 1.0
85+
"""
86+
)
87+
assert tt_pprint(rv_obs) == expected.strip()
88+
89+
90+
def test_tex_print():
91+
92+
tt_normalrv_noname_expr = tt.scalar("b") * NormalRV(tt.scalar("\\mu"), tt.scalar("\\sigma"))
93+
expected = textwrap.dedent(
94+
r"""
95+
\begin{equation}
96+
\begin{gathered}
97+
b \in \mathbb{R}, \,\mu \in \mathbb{R}, \,\sigma \in \mathbb{R}
98+
\\
99+
a \sim \operatorname{N}\left(\mu, {\sigma}^{2}\right)\, \in \mathbb{R}
100+
\end{gathered}
101+
\\
102+
(b \odot a)
103+
\end{equation}
104+
"""
105+
)
106+
assert tt_tprint(tt_normalrv_noname_expr) == expected.strip()
107+
108+
tt_normalrv_name_expr = tt.scalar("b") * NormalRV(
109+
tt.scalar("\\mu"), tt.scalar("\\sigma"), size=[2, 1], name="X"
110+
)
111+
expected = textwrap.dedent(
112+
r"""
113+
\begin{equation}
114+
\begin{gathered}
115+
b \in \mathbb{R}, \,\mu \in \mathbb{R}, \,\sigma \in \mathbb{R}
116+
\\
117+
X \sim \operatorname{N}\left(\mu, {\sigma}^{2}\right)\, \in \mathbb{R}^{2 \times 1}
118+
\end{gathered}
119+
\\
120+
(b \odot X)
121+
\end{equation}
122+
"""
123+
)
124+
assert tt_tprint(tt_normalrv_name_expr) == expected.strip()
19125

20126
tt_2_normalrv_noname_expr = tt.matrix("M") * NormalRV(
21127
tt.scalar("\\mu_2"), tt.scalar("\\sigma_2")
22128
)
23129
tt_2_normalrv_noname_expr *= tt.scalar("b") * NormalRV(
24130
tt_2_normalrv_noname_expr, tt.scalar("\\sigma")
25131
) + tt.scalar("c")
26-
expected = "M in R**(N^M_0 x N^M_1), \\mu_2 in R, \\sigma_2 in R\nb in R, \\sigma in R, c in R\na ~ N(\\mu_2, \\sigma_2**2) in R, d ~ N((M * a), \\sigma**2) in R**(N^d_0 x N^d_1)\n((M * a) * ((b * d) + c))"
27-
assert tt_pprint(tt_2_normalrv_noname_expr) == expected
132+
expected = textwrap.dedent(
133+
r"""
134+
\begin{equation}
135+
\begin{gathered}
136+
M \in \mathbb{R}^{N^{M}_{0} \times N^{M}_{1}}
137+
\\
138+
\mu_2 \in \mathbb{R}, \,\sigma_2 \in \mathbb{R}
139+
\\
140+
b \in \mathbb{R}, \,\sigma \in \mathbb{R}, \,c \in \mathbb{R}
141+
\\
142+
a \sim \operatorname{N}\left(\mu_2, {\sigma_2}^{2}\right)\, \in \mathbb{R}
143+
\\
144+
d \sim \operatorname{N}\left((M \odot a), {\sigma}^{2}\right)\, \in \mathbb{R}^{N^{d}_{0} \times N^{d}_{1}}
145+
\end{gathered}
146+
\\
147+
((M \odot a) \odot ((b \odot d) + c))
148+
\end{equation}
149+
"""
150+
)
151+
assert tt_tprint(tt_2_normalrv_noname_expr) == expected.strip()
152+
153+
expected = textwrap.dedent(
154+
r"""
155+
\begin{equation}
156+
\begin{gathered}
157+
b \in \mathbb{Z}, \,c \in \mathbb{Z}, \,M \in \mathbb{R}^{N^{M}_{0} \times N^{M}_{1}}
158+
\end{gathered}
159+
\\
160+
M\left[b, \,c\right]
161+
\end{equation}
162+
"""
163+
)
164+
# TODO: "c" should be "1".
165+
assert (
166+
tt_tprint(tt.matrix("M")[tt.iscalar("a"), tt.constant(1, dtype="int")]) == expected.strip()
167+
)
168+
169+
expected = textwrap.dedent(
170+
r"""
171+
\begin{equation}
172+
\begin{gathered}
173+
M \in \mathbb{R}^{N^{M}_{0} \times N^{M}_{1}}
174+
\end{gathered}
175+
\\
176+
M\left[1\right]
177+
\end{equation}
178+
"""
179+
)
180+
assert tt_tprint(tt.matrix("M")[1]) == expected.strip()
181+
182+
expected = textwrap.dedent(
183+
r"""
184+
\begin{equation}
185+
\begin{gathered}
186+
M \in \mathbb{N}^{N^{M}_{0}}
187+
\end{gathered}
188+
\\
189+
M\left[2:4:0\right]
190+
\end{equation}
191+
"""
192+
)
193+
assert tt_tprint(tt.vector("M", dtype="uint32")[0:4:2]) == expected.strip()
194+
195+
norm_rv = NormalRV(tt.scalar("\\mu"), tt.scalar("\\sigma"))
196+
rv_obs = observed(tt.constant(1.0, dtype=norm_rv.dtype), norm_rv)
197+
198+
expected = textwrap.dedent(
199+
r"""
200+
\begin{equation}
201+
\begin{gathered}
202+
\mu \in \mathbb{R}, \,\sigma \in \mathbb{R}
203+
\\
204+
a \sim \operatorname{N}\left(\mu, {\sigma}^{2}\right)\, \in \mathbb{R}
205+
\end{gathered}
206+
\\
207+
a = 1.0
208+
\end{equation}
209+
"""
210+
)
211+
assert tt_tprint(rv_obs) == expected.strip()

0 commit comments

Comments (0)