11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import numpy as np
15
+ import pytest
16
+
17
+ from pytensor .compile import SharedVariable
14
18
from pytensor .graph import Constant
15
19
20
+ from pymc import Deterministic
16
21
from pymc .data import Data
17
22
from pymc .distributions import HalfNormal , Normal
18
23
from pymc .model import Model
19
24
from pymc .model .transform .optimization import freeze_dims_and_data
20
25
21
26
22
- def test_freeze_existing_rv_dims_and_data ():
27
+ def test_freeze_dims_and_data ():
23
28
with Model (coords = {"test_dim" : range (5 )}) as m :
24
- std = Data ("std " , [1 ])
29
+ std = Data ("test_data " , [1 ])
25
30
x = HalfNormal ("x" , std , dims = ("test_dim" ,))
26
31
y = Normal ("y" , shape = x .shape [0 ] + 1 )
27
32
@@ -34,18 +39,96 @@ def test_freeze_existing_rv_dims_and_data():
34
39
assert y_logp .type .shape == (None ,)
35
40
36
41
frozen_m = freeze_dims_and_data (m )
37
- std , x , y = frozen_m ["std " ], frozen_m ["x" ], frozen_m ["y" ]
42
+ data , x , y = frozen_m ["test_data " ], frozen_m ["x" ], frozen_m ["y" ]
38
43
x_logp , y_logp = frozen_m .logp (sum = False )
39
- assert isinstance (std , Constant )
44
+ assert isinstance (data , Constant )
40
45
assert x .type .shape == (5 ,)
41
46
assert y .type .shape == (6 ,)
42
47
assert x_logp .type .shape == (5 ,)
43
48
assert y_logp .type .shape == (6 ,)
44
49
50
+ # Test trying to update a frozen data or dim raises an informative error
51
+ with frozen_m :
52
+ with pytest .raises (TypeError , match = "The variable `test_data` must be a `SharedVariable`" ):
53
+ frozen_m .set_data ("test_data" , values = [2 ])
54
+ with pytest .raises (
55
+ TypeError , match = "The dim_length of `test_dim` must be a `SharedVariable`"
56
+ ):
57
+ frozen_m .set_dim ("test_dim" , new_length = 6 , coord_values = range (6 ))
58
+
59
+ # Test we can still update original model
60
+ with m :
61
+ m .set_data ("test_data" , values = [2 ])
62
+ m .set_dim ("test_dim" , new_length = 6 , coord_values = range (6 ))
63
+ assert m ["test_data" ].get_value () == [2 ]
64
+ assert m .dim_lengths ["test_dim" ].get_value () == 6
45
65
46
- def test_freeze_rv_dims_nothing_to_change ():
66
+
67
+ def test_freeze_dims_nothing_to_change ():
47
68
with Model (coords = {"test_dim" : range (5 )}) as m :
48
69
x = HalfNormal ("x" , shape = (5 ,))
49
70
y = Normal ("y" , shape = x .shape [0 ] + 1 )
50
71
51
72
assert m .point_logps () == freeze_dims_and_data (m ).point_logps ()
73
+
74
+
75
+ def test_freeze_dims_and_data_subset ():
76
+ with Model (coords = {"dim1" : range (3 ), "dim2" : range (5 )}) as m :
77
+ data1 = Data ("data1" , [1 , 2 , 3 ], dims = "dim1" )
78
+ data2 = Data ("data2" , [1 , 2 , 3 , 4 , 5 ], dims = "dim2" )
79
+ var1 = Normal ("var1" , dims = "dim1" )
80
+ var2 = Normal ("var2" , dims = "dim2" )
81
+ x = data1 * var1
82
+ y = data2 * var2
83
+ det = Deterministic ("det" , x [:, None ] + y [None , :])
84
+
85
+ assert det .type .shape == (None , None )
86
+
87
+ new_m = freeze_dims_and_data (m , dims = ["dim1" ], data = [])
88
+ assert new_m ["det" ].type .shape == (3 , None )
89
+ assert isinstance (new_m .dim_lengths ["dim1" ], Constant ) and new_m .dim_lengths ["dim1" ].data == 3
90
+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
91
+ assert isinstance (new_m ["data1" ], SharedVariable )
92
+ assert isinstance (new_m ["data2" ], SharedVariable )
93
+
94
+ new_m = freeze_dims_and_data (m , dims = ["dim2" ], data = [])
95
+ assert new_m ["det" ].type .shape == (None , 5 )
96
+ assert isinstance (new_m .dim_lengths ["dim1" ], SharedVariable )
97
+ assert isinstance (new_m .dim_lengths ["dim2" ], Constant ) and new_m .dim_lengths ["dim2" ].data == 5
98
+ assert isinstance (new_m ["data1" ], SharedVariable )
99
+ assert isinstance (new_m ["data2" ], SharedVariable )
100
+
101
+ new_m = freeze_dims_and_data (m , dims = ["dim1" , "dim2" ], data = [])
102
+ assert new_m ["det" ].type .shape == (3 , 5 )
103
+ assert isinstance (new_m .dim_lengths ["dim1" ], Constant ) and new_m .dim_lengths ["dim1" ].data == 3
104
+ assert isinstance (new_m .dim_lengths ["dim2" ], Constant ) and new_m .dim_lengths ["dim2" ].data == 5
105
+ assert isinstance (new_m ["data1" ], SharedVariable )
106
+ assert isinstance (new_m ["data2" ], SharedVariable )
107
+
108
+ new_m = freeze_dims_and_data (m , dims = [], data = ["data1" ])
109
+ assert new_m ["det" ].type .shape == (3 , None )
110
+ assert isinstance (new_m .dim_lengths ["dim1" ], SharedVariable )
111
+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
112
+ assert isinstance (new_m ["data1" ], Constant ) and np .all (new_m ["data1" ].data == [1 , 2 , 3 ])
113
+ assert isinstance (new_m ["data2" ], SharedVariable )
114
+
115
+ new_m = freeze_dims_and_data (m , dims = [], data = ["data2" ])
116
+ assert new_m ["det" ].type .shape == (None , 5 )
117
+ assert isinstance (new_m .dim_lengths ["dim1" ], SharedVariable )
118
+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
119
+ assert isinstance (new_m ["data1" ], SharedVariable )
120
+ assert isinstance (new_m ["data2" ], Constant ) and np .all (new_m ["data2" ].data == [1 , 2 , 3 , 4 , 5 ])
121
+
122
+ new_m = freeze_dims_and_data (m , dims = [], data = ["data1" , "data2" ])
123
+ assert new_m ["det" ].type .shape == (3 , 5 )
124
+ assert isinstance (new_m .dim_lengths ["dim1" ], SharedVariable )
125
+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
126
+ assert isinstance (new_m ["data1" ], Constant ) and np .all (new_m ["data1" ].data == [1 , 2 , 3 ])
127
+ assert isinstance (new_m ["data2" ], Constant ) and np .all (new_m ["data2" ].data == [1 , 2 , 3 , 4 , 5 ])
128
+
129
+ new_m = freeze_dims_and_data (m , dims = ["dim1" ], data = ["data2" ])
130
+ assert new_m ["det" ].type .shape == (3 , 5 )
131
+ assert isinstance (new_m .dim_lengths ["dim1" ], Constant ) and new_m .dim_lengths ["dim1" ].data == 3
132
+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
133
+ assert isinstance (new_m ["data1" ], SharedVariable )
134
+ assert isinstance (new_m ["data2" ], Constant ) and np .all (new_m ["data2" ].data == [1 , 2 , 3 , 4 , 5 ])
0 commit comments