@@ -34,19 +34,66 @@ def padded_where(x, to_len, padval=-1):
34
34
class PartialOrder (Transform ):
35
35
"""Create a PartialOrder transform
36
36
37
- This is a more flexible version of the pymc ordered transform that
37
+ A more flexible version of the pymc ordered transform that
38
38
allows specifying a (strict) partial order on the elements.
39
39
40
- It works in O(N*D) in runtime, but takes O(N^3) in initialization,
41
- where N is the number of nodes in the dag and
42
- D is the maximum in-degree of a node in the transitive reduction.
43
-
40
+ Examples
41
+ --------
42
+ .. code:: python
43
+
44
+ import numpy as np
45
+ import pymc as pm
46
+ import pymc_extras as pmx
47
+
48
+ # Define two partial orders on 4 elements
49
+ # am[i,j] = 1 means i < j
50
+ adj_mats = np.array([
51
+ # 0 < {1, 2} < 3
52
+ [[0, 1, 1, 0],
53
+ [0, 0, 0, 1],
54
+ [0, 0, 0, 1],
55
+ [0, 0, 0, 0]],
56
+
57
+ # 1 < 0 < 3 < 2
58
+ [[0, 0, 0, 1],
59
+ [1, 0, 0, 0],
60
+ [0, 0, 0, 0],
61
+ [0, 0, 1, 0]],
62
+ ])
63
+
64
+ # Create the partial order from the adjacency matrices
65
+ po = pmx.PartialOrder(adj_mats)
66
+
67
+ with pm.Model() as model:
68
+ # Generate 3 samples from both partial orders
69
+ pm.Normal("po_vals", shape=(3,2,4), transform=po,
70
+ initval=po.initvals((3,2,4)))
71
+
72
+ idata = pm.sample()
73
+
74
+ # Verify that for first po, the zeroth element is always the smallest
75
+ assert (idata.posterior['po_vals'][:,:,:,0,0] <
76
+ idata.posterior['po_vals'][:,:,:,0,1:]).all()
77
+
78
+ # Verify that for second po, the second element is always the largest
79
+ assert (idata.posterior['po_vals'][:,:,:,1,2] >=
80
+ idata.posterior['po_vals'][:,:,:,1,:]).all()
81
+
82
+ Technical notes
83
+ ----------------
84
+ Partial order needs to be strict, i.e. without equalities.
85
+ A DAG defining the partial order is sufficient, as transitive closure is automatically computed.
86
+ Code works in O(N*D) in runtime, but takes O(N^3) in initialization,
87
+ where N is the number of nodes in the dag and D is the maximum
88
+ in-degree of a node in the transitive reduction.
44
89
"""
45
90
46
91
name = "partial_order"
47
92
48
93
def __init__ (self , adj_mat ):
49
94
"""
95
+ Initialize the PartialOrder transform
96
+
50
97
Parameters
51
98
----------
52
99
adj_mat: ndarray
@@ -99,10 +146,43 @@ def __init__(self, adj_mat):
99
146
self .dag = np .swapaxes (dag_T , - 2 , - 1 )
100
147
self .is_start = np .all (self .dag [..., :, :] == - 1 , axis = - 1 )
101
148
102
- def initvals (self , lower = - 1 , upper = 1 ):
149
+ def initvals (self , shape = None , lower = - 1 , upper = 1 ):
150
+ """
151
+ Create a set of appropriate initial values for the variable.
152
+ NB! It is important that proper initial values are used,
153
+ as only properly ordered values are in the range of the transform.
154
+
155
+ Parameters
156
+ ----------
157
+ shape: tuple, default None
158
+ shape of the initial values. If None, adj_mat[:-1] is used
159
+ lower: float, default -1
160
+ lower bound for the initial values
161
+ upper: float, default 1
162
+ upper bound for the initial values
163
+
164
+ Returns
165
+ -------
166
+ vals: ndarray
167
+ initial values for the transformed variable
168
+ """
169
+
170
+ if shape is None :
171
+ shape = self .dag .shape [:- 1 ]
172
+
173
+ if shape [- len (self .dag .shape [:- 1 ]) :] != self .dag .shape [:- 1 ]:
174
+ raise ValueError ("Shape must match the shape of the adjacency matrix" )
175
+
176
+ # Create the initial values
103
177
vals = np .linspace (lower , upper , self .dag .shape [- 2 ])
104
178
inds = np .argsort (self .ts_inds , axis = - 1 )
105
- return vals [inds ]
179
+ ivals = vals [inds ]
180
+
181
+ # Expand the initial values to the extra dimensions
182
+ extra_dims = shape [: - len (self .dag .shape [:- 1 ])]
183
+ ivals = np .tile (ivals , extra_dims + tuple ([1 ] * len (self .dag .shape [:- 1 ])))
184
+
185
+ return ivals
106
186
107
187
def backward (self , value , * inputs ):
108
188
minv = dtype_minval (value .dtype )
0 commit comments