32
32
33
33
def laplace (
34
34
vars : Sequence [Variable ],
35
- draws = 1_000 ,
35
+ draws : Optional [ int ] = 1000 ,
36
36
model = None ,
37
37
random_seed : Optional [RandomSeed ] = None ,
38
38
progressbar = True ,
@@ -49,9 +49,9 @@ def laplace(
49
49
vars : Sequence[Variable]
50
50
A sequence of variables for which the Laplace approximation of the posterior distribution
51
51
is to be created.
52
- draws : int, optional, default=1_000
52
+ draws : Optional[ int] with default=1_000
53
53
The number of draws to sample from the posterior distribution for creating the approximation.
54
- For draws=0 only the fit of the Laplace approximation is returned
54
+ For draws=None only the fit of the Laplace approximation is returned
55
55
model : object, optional, default=None
56
56
The model object that defines the posterior distribution. If None, the default model will be used.
57
57
random_seed : Optional[RandomSeed], optional, default=None
@@ -103,7 +103,8 @@ def laplace(
103
103
# See https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html
104
104
untransformed_m = remove_value_transforms (transformed_m )
105
105
untransformed_vars = [untransformed_m [v .name ] for v in vars ]
106
- hessian = pm .find_hessian (point = map , vars = untransformed_vars , model = untransformed_m )
106
+ hessian = pm .find_hessian (
107
+ point = map , vars = untransformed_vars , model = untransformed_m )
107
108
108
109
if np .linalg .det (hessian ) == 0 :
109
110
raise np .linalg .LinAlgError ("Hessian is singular." )
@@ -113,12 +114,13 @@ def laplace(
113
114
114
115
chains = 1
115
116
116
- if draws != 0 :
117
+ if draws is not None :
117
118
samples = rng .multivariate_normal (mean , cov , size = (chains , draws ))
118
119
119
120
data_vars = {}
120
121
for i , var in enumerate (vars ):
121
- data_vars [str (var )] = xr .DataArray (samples [:, :, i ], dims = ("chain" , "draw" ))
122
+ data_vars [str (var )] = xr .DataArray (
123
+ samples [:, :, i ], dims = ("chain" , "draw" ))
122
124
123
125
coords = {"chain" : np .arange (chains ), "draw" : np .arange (draws )}
124
126
ds = xr .Dataset (data_vars , coords = coords )
@@ -136,13 +138,15 @@ def laplace(
136
138
def addFitToInferenceData (vars , idata , mean , covariance ):
137
139
coord_names = [v .name for v in vars ]
138
140
# Convert to xarray DataArray
139
- mean_dataarray = xr .DataArray (mean , dims = ["rows" ], coords = {"rows" : coord_names })
141
+ mean_dataarray = xr .DataArray (
142
+ mean , dims = ["rows" ], coords = {"rows" : coord_names })
140
143
cov_dataarray = xr .DataArray (
141
144
covariance , dims = ["rows" , "columns" ], coords = {"rows" : coord_names , "columns" : coord_names }
142
145
)
143
146
144
147
# Create xarray dataset
145
- dataset = xr .Dataset ({"mean_vector" : mean_dataarray , "covariance_matrix" : cov_dataarray })
148
+ dataset = xr .Dataset ({"mean_vector" : mean_dataarray ,
149
+ "covariance_matrix" : cov_dataarray })
146
150
147
151
idata .add_groups (fit = dataset )
148
152
0 commit comments