Skip to content

Commit 4dd5273

Browse files
committed
Improved documentation with an example
1 parent b38c2f0 commit 4dd5273

File tree

2 files changed

+95
-9
lines changed

2 files changed

+95
-9
lines changed

pymc_extras/distributions/transforms/partial_order.py

+87-7
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,66 @@ def padded_where(x, to_len, padval=-1):
3434
class PartialOrder(Transform):
3535
"""Create a PartialOrder transform
3636
37-
This is a more flexible version of the pymc ordered transform that
37+
A more flexible version of the pymc ordered transform that
3838
allows specifying a (strict) partial order on the elements.
3939
40-
It works in O(N*D) in runtime, but takes O(N^3) in initialization,
41-
where N is the number of nodes in the dag and
42-
D is the maximum in-degree of a node in the transitive reduction.
43-
40+
Examples
41+
--------
42+
.. code:: python
43+
44+
import numpy as np
45+
import pymc as pm
46+
import pymc_extras as pmx
47+
48+
# Define two partial orders on 4 elements
49+
# am[i,j] = 1 means i < j
50+
adj_mats = np.array([
51+
# 0 < {1, 2} < 3
52+
[[0, 1, 1, 0],
53+
[0, 0, 0, 1],
54+
[0, 0, 0, 1],
55+
[0, 0, 0, 0]],
56+
57+
# 1 < 0 < 3 < 2
58+
[[0, 0, 0, 1],
59+
[1, 0, 0, 0],
60+
[0, 0, 0, 0],
61+
[0, 0, 1, 0]],
62+
])
63+
64+
# Create the partial order from the adjacency matrices
65+
po = pmx.PartialOrder(adj_mats)
66+
67+
with pm.Model() as model:
68+
# Generate 3 samples from both partial orders
69+
pm.Normal("po_vals", shape=(3,2,4), transform=po,
70+
initval=po.initvals((3,2,4)))
71+
72+
idata = pm.sample()
73+
74+
# Verify that for first po, the zeroth element is always the smallest
75+
assert (idata.posterior['po_vals'][:,:,:,0,0] <
76+
idata.posterior['po_vals'][:,:,:,0,1:]).all()
77+
78+
# Verify that for second po, the second element is always the largest
79+
assert (idata.posterior['po_vals'][:,:,:,1,2] >=
80+
idata.posterior['po_vals'][:,:,:,1,:]).all()
81+
82+
Technical notes
83+
----------------
84+
Partial order needs to be strict, i.e. without equalities.
85+
A DAG defining the partial order is sufficient, as transitive closure is automatically computed.
86+
Code works in O(N*D) in runtime, but takes O(N^3) in initialization,
87+
where N is the number of nodes in the dag and D is the maximum
88+
in-degree of a node in the transitive reduction.
4489
"""
4590

4691
name = "partial_order"
4792

4893
def __init__(self, adj_mat):
4994
"""
95+
Initialize the PartialOrder transform
96+
5097
Parameters
5198
----------
5299
adj_mat: ndarray
@@ -99,10 +146,43 @@ def __init__(self, adj_mat):
99146
self.dag = np.swapaxes(dag_T, -2, -1)
100147
self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)
101148

102-
def initvals(self, lower=-1, upper=1):
149+
def initvals(self, shape=None, lower=-1, upper=1):
150+
"""
151+
Create a set of appropriate initial values for the variable.
152+
NB! It is important that proper initial values are used,
153+
as only properly ordered values are in the range of the transform.
154+
155+
Parameters
156+
----------
157+
shape: tuple, default None
158+
shape of the initial values. If None, adj_mat[:-1] is used
159+
lower: float, default -1
160+
lower bound for the initial values
161+
upper: float, default 1
162+
upper bound for the initial values
163+
164+
Returns
165+
-------
166+
vals: ndarray
167+
initial values for the transformed variable
168+
"""
169+
170+
if shape is None:
171+
shape = self.dag.shape[:-1]
172+
173+
if shape[-len(self.dag.shape[:-1]) :] != self.dag.shape[:-1]:
174+
raise ValueError("Shape must match the shape of the adjacency matrix")
175+
176+
# Create the initial values
103177
vals = np.linspace(lower, upper, self.dag.shape[-2])
104178
inds = np.argsort(self.ts_inds, axis=-1)
105-
return vals[inds]
179+
ivals = vals[inds]
180+
181+
# Expand the initial values to the extra dimensions
182+
extra_dims = shape[: -len(self.dag.shape[:-1])]
183+
ivals = np.tile(ivals, extra_dims + tuple([1] * len(self.dag.shape[:-1])))
184+
185+
return ivals
106186

107187
def backward(self, value, *inputs):
108188
minv = dtype_minval(value.dtype)

tests/distributions/test_transform.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,17 @@ def test_forward_backward_dimensionality(self):
5555
def test_sample_model(self):
5656
po = PartialOrder(self.adj_mats)
5757
with pm.Model() as model:
58-
x = pm.Normal("x", size=(2, 4), transform=po, initval=po.initvals(-1, 1))
58+
x = pm.Normal(
59+
"x",
60+
size=(3, 2, 4),
61+
transform=po,
62+
initval=po.initvals(shape=(3, 2, 4), lower=-1, upper=1),
63+
)
5964
idata = pm.sample()
6065

6166
# Check that the order constraints are satisfied
62-
xvs = idata.posterior.x.values.transpose(2, 3, 0, 1)
67+
# Move chain, draw and "3" dimensions to the back
68+
xvs = idata.posterior.x.values.transpose(3, 4, 0, 1, 2)
6369
x0 = xvs[0] # 0 < {1, 2} < 3
6470
assert (
6571
(x0[0] < x0[1]).all()

0 commit comments

Comments
 (0)