Skip to content

Commit 4a8de57

Browse files
committed
Add optional rng argument to Metropolis Proposals
1 parent c3833bc commit 4a8de57

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

pymc/step_methods/metropolis.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Callable, Dict, List, Tuple
14+
from typing import Any, Callable, Dict, List, Optional, Tuple
1515

1616
import numpy as np
1717
import numpy.random as nr
@@ -50,50 +50,64 @@
5050

5151

5252
class Proposal:
53-
def __init__(self, s):
53+
def __init__(self, s, rng_seed: Optional[int] = None):
5454
self.s = s
55+
self.rng = np.random.default_rng(rng_seed)
5556

5657

5758
class NormalProposal(Proposal):
58-
def __call__(self):
59-
return nr.normal(scale=self.s)
59+
def __call__(self, rng: Optional[np.random.Generator] = None):
60+
if rng is None:
61+
rng = self.rng
62+
return rng.normal(scale=self.s)
6063

6164

6265
class UniformProposal(Proposal):
63-
def __call__(self):
64-
return nr.uniform(low=-self.s, high=self.s, size=len(self.s))
66+
def __call__(self, rng: Optional[np.random.Generator] = None):
67+
if rng is None:
68+
rng = self.rng
69+
return rng.uniform(low=-self.s, high=self.s, size=len(self.s))
6570

6671

6772
class CauchyProposal(Proposal):
68-
def __call__(self):
69-
return nr.standard_cauchy(size=np.size(self.s)) * self.s
73+
def __call__(self, rng: Optional[np.random.Generator] = None):
74+
if rng is None:
75+
rng = self.rng
76+
return rng.standard_cauchy(size=np.size(self.s)) * self.s
7077

7178

7279
class LaplaceProposal(Proposal):
73-
def __call__(self):
80+
def __call__(self, rng: Optional[np.random.Generator] = None):
81+
if rng is None:
82+
rng = self.rng
7483
size = np.size(self.s)
75-
return (nr.standard_exponential(size=size) - nr.standard_exponential(size=size)) * self.s
84+
return (rng.standard_exponential(size=size) - rng.standard_exponential(size=size)) * self.s
7685

7786

7887
class PoissonProposal(Proposal):
79-
def __call__(self):
80-
return nr.poisson(lam=self.s, size=np.size(self.s)) - self.s
88+
def __call__(self, rng: Optional[np.random.Generator] = None):
89+
if rng is None:
90+
rng = self.rng
91+
return rng.poisson(lam=self.s, size=np.size(self.s)) - self.s
8192

8293

8394
class MultivariateNormalProposal(Proposal):
84-
def __init__(self, s):
85-
n, m = s.shape
95+
def __init__(self, *args, **kwargs):
96+
super().__init__(*args, **kwargs)
97+
n, m = self.s.shape
8698
if n != m:
8799
raise ValueError("Covariance matrix is not symmetric.")
88100
self.n = n
89-
self.chol = scipy.linalg.cholesky(s, lower=True)
101+
self.chol = scipy.linalg.cholesky(self.s, lower=True)
90102

91-
def __call__(self, num_draws=None):
103+
def __call__(self, num_draws=None, rng: Optional[np.random.Generator] = None):
104+
if rng is None:
105+
rng = self.rng
92106
if num_draws is not None:
93-
b = np.random.randn(self.n, num_draws)
107+
b = rng.normal(size=(self.n, num_draws))
94108
return np.dot(self.chol, b).T
95109
else:
96-
b = np.random.randn(self.n)
110+
b = rng.normal(size=self.n)
97111
return np.dot(self.chol, b)
98112

99113

0 commit comments

Comments
 (0)