4
4
import sys
5
5
import warnings
6
6
7
- from typing import Callable , List , Optional
7
+ from typing import Callable , Dict , List , Optional , Sequence , Union
8
8
9
+ from pymc .initial_point import StartDict
9
10
from pymc .sampling import _init_jitter
10
11
11
12
xla_flags = os .getenv ("XLA_FLAGS" , "" )
@@ -130,7 +131,7 @@ def _sample_stats_to_xarray(posterior):
130
131
return data
131
132
132
133
133
- def _get_log_likelihood (model , samples ):
134
+ def _get_log_likelihood (model : Model , samples ) -> Dict :
134
135
"""Compute log-likelihood for all observations"""
135
136
data = {}
136
137
for v in model .observed_RVs :
@@ -142,8 +143,13 @@ def _get_log_likelihood(model, samples):
142
143
143
144
144
145
def _get_batched_jittered_initial_points (
145
- model , chains , initvals , random_seed , jitter = True , jitter_max_retries = 10
146
- ):
146
+ model : Model ,
147
+ chains : int ,
148
+ initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]],
149
+ random_seed : int ,
150
+ jitter : bool = True ,
151
+ jitter_max_retries : int = 10 ,
152
+ ) -> Union [np .ndarray , List [np .ndarray ]]:
147
153
"""Get jittered initial point in format expected by NumPyro MCMC kernel
148
154
149
155
Returns
@@ -173,19 +179,19 @@ def _get_batched_jittered_initial_points(
173
179
174
180
175
181
def sample_numpyro_nuts (
176
- draws = 1000 ,
177
- tune = 1000 ,
178
- chains = 4 ,
179
- target_accept = 0.8 ,
180
- random_seed = None ,
181
- initvals = None ,
182
- model = None ,
182
+ draws : int = 1000 ,
183
+ tune : int = 1000 ,
184
+ chains : int = 4 ,
185
+ target_accept : float = 0.8 ,
186
+ random_seed : int = None ,
187
+ initvals : Optional [ Union [ StartDict , Sequence [ Optional [ StartDict ]]]] = None ,
188
+ model : Optional [ Model ] = None ,
183
189
var_names = None ,
184
- progress_bar = True ,
185
- keep_untransformed = False ,
186
- chain_method = "parallel" ,
187
- idata_kwargs = None ,
188
- nuts_kwargs = None ,
190
+ progress_bar : bool = True ,
191
+ keep_untransformed : bool = False ,
192
+ chain_method : str = "parallel" ,
193
+ idata_kwargs : Optional [ Dict ] = None ,
194
+ nuts_kwargs : Optional [ Dict ] = None ,
189
195
):
190
196
from numpyro .infer import MCMC , NUTS
191
197
0 commit comments