Skip to content

Commit 64e9375

Browse files
ricardoV94twiecki
authored andcommitted
Add type hints to several functions in sampling_jax.py
1 parent dd4b940 commit 64e9375

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

pymc/sampling_jax.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import sys
55
import warnings
66

7-
from typing import Callable, List, Optional
7+
from typing import Callable, Dict, List, Optional, Sequence, Union
88

9+
from pymc.initial_point import StartDict
910
from pymc.sampling import _init_jitter
1011

1112
xla_flags = os.getenv("XLA_FLAGS", "")
@@ -130,7 +131,7 @@ def _sample_stats_to_xarray(posterior):
130131
return data
131132

132133

133-
def _get_log_likelihood(model, samples):
134+
def _get_log_likelihood(model: Model, samples) -> Dict:
134135
"""Compute log-likelihood for all observations"""
135136
data = {}
136137
for v in model.observed_RVs:
@@ -142,8 +143,13 @@ def _get_log_likelihood(model, samples):
142143

143144

144145
def _get_batched_jittered_initial_points(
145-
model, chains, initvals, random_seed, jitter=True, jitter_max_retries=10
146-
):
146+
model: Model,
147+
chains: int,
148+
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
149+
random_seed: int,
150+
jitter: bool = True,
151+
jitter_max_retries: int = 10,
152+
) -> Union[np.ndarray, List[np.ndarray]]:
147153
"""Get jittered initial point in format expected by NumPyro MCMC kernel
148154
149155
Returns
@@ -173,19 +179,19 @@ def _get_batched_jittered_initial_points(
173179

174180

175181
def sample_numpyro_nuts(
176-
draws=1000,
177-
tune=1000,
178-
chains=4,
179-
target_accept=0.8,
180-
random_seed=None,
181-
initvals=None,
182-
model=None,
182+
draws: int = 1000,
183+
tune: int = 1000,
184+
chains: int = 4,
185+
target_accept: float = 0.8,
186+
random_seed: int = None,
187+
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
188+
model: Optional[Model] = None,
183189
var_names=None,
184-
progress_bar=True,
185-
keep_untransformed=False,
186-
chain_method="parallel",
187-
idata_kwargs=None,
188-
nuts_kwargs=None,
190+
progress_bar: bool = True,
191+
keep_untransformed: bool = False,
192+
chain_method: str = "parallel",
193+
idata_kwargs: Optional[Dict] = None,
194+
nuts_kwargs: Optional[Dict] = None,
189195
):
190196
from numpyro.infer import MCMC, NUTS
191197

0 commit comments

Comments
 (0)