@@ -103,8 +103,7 @@ def laplace(
103
103
# See https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html
104
104
untransformed_m = remove_value_transforms (transformed_m )
105
105
untransformed_vars = [untransformed_m [v .name ] for v in vars ]
106
- hessian = pm .find_hessian (
107
- point = map , vars = untransformed_vars , model = untransformed_m )
106
+ hessian = pm .find_hessian (point = map , vars = untransformed_vars , model = untransformed_m )
108
107
109
108
if np .linalg .det (hessian ) == 0 :
110
109
raise np .linalg .LinAlgError ("Hessian is singular." )
@@ -119,8 +118,7 @@ def laplace(
119
118
120
119
data_vars = {}
121
120
for i , var in enumerate (vars ):
122
- data_vars [str (var )] = xr .DataArray (
123
- samples [:, :, i ], dims = ("chain" , "draw" ))
121
+ data_vars [str (var )] = xr .DataArray (samples [:, :, i ], dims = ("chain" , "draw" ))
124
122
125
123
coords = {"chain" : np .arange (chains ), "draw" : np .arange (draws )}
126
124
ds = xr .Dataset (data_vars , coords = coords )
@@ -138,15 +136,13 @@ def laplace(
138
136
def addFitToInferenceData (vars , idata , mean , covariance ):
139
137
coord_names = [v .name for v in vars ]
140
138
# Convert to xarray DataArray
141
- mean_dataarray = xr .DataArray (
142
- mean , dims = ["rows" ], coords = {"rows" : coord_names })
139
+ mean_dataarray = xr .DataArray (mean , dims = ["rows" ], coords = {"rows" : coord_names })
143
140
cov_dataarray = xr .DataArray (
144
141
covariance , dims = ["rows" , "columns" ], coords = {"rows" : coord_names , "columns" : coord_names }
145
142
)
146
143
147
144
# Create xarray dataset
148
- dataset = xr .Dataset ({"mean_vector" : mean_dataarray ,
149
- "covariance_matrix" : cov_dataarray })
145
+ dataset = xr .Dataset ({"mean_vector" : mean_dataarray , "covariance_matrix" : cov_dataarray })
150
146
151
147
idata .add_groups (fit = dataset )
152
148
0 commit comments