Commit 5d6e94c

Merge branch 'pymc-devs:main' into gpv4update
2 parents d5474a7 + 64d8396 commit 5d6e94c

37 files changed: +1472 −490 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ repos:
       exclude: ^requirements-dev\.txt$
     - id: trailing-whitespace
 - repo: https://github.com/PyCQA/isort
-  rev: 5.9.3
+  rev: 5.10.1
   hooks:
     - id: isort
       name: isort

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ docker exec -it pymc jupyter notebook list
 ## Style guide

 We have configured a pre-commit hook that checks for `black`-compliant code style.
-We encourage you to configure the pre-commit hook as described in the [PyMC Python Code Style Wiki Page](https://github.com/pymc-devs/pymc/wiki/PyMC-Python-Code-Style), because it will automatically enforce the code style on your commits.
+We encourage you to configure the pre-commit hook as described in the [PyMC Python Code Style Wiki Page](https://github.com/pymc-devs/pymc/wiki/Python-Code-Style), because it will automatically enforce the code style on your commits.

 Similarly, consult the [PyMC's Jupyter Notebook Style](https://github.com/pymc-devs/pymc/wiki/PyMC-Jupyter-Notebook-Style-Guide) guides for notebooks.

RELEASE-NOTES.md

Lines changed: 3 additions & 2 deletions
@@ -93,8 +93,9 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
 - `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cummulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - New features for BART:
-  - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
-  - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
+  - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
+  - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
+  - Modify how particle weights are computed. This improves accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
 - `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
 - ...

conda-envs/environment-dev-py37.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - aeppl>=0.0.13
+  - aeppl=0.0.15
   - aesara>=2.2.6
   - arviz>=0.11.4
   - cachetools>=4.2.1
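Note on the pin: in conda match-spec syntax `aeppl=0.0.15` is a prefix match on version 0.0.15 (any build), replacing the earlier lower bound `>=0.0.13` with a hard pin; the same change is repeated in each environment file below.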

conda-envs/environment-dev-py38.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - aeppl>=0.0.13
+  - aeppl=0.0.15
   - aesara>=2.2.6
   - arviz>=0.11.4
   - cachetools>=4.2.1

conda-envs/environment-dev-py39.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - aeppl>=0.0.13
+  - aeppl=0.0.15
   - aesara>=2.2.6
   - arviz>=0.11.4
   - cachetools>=4.2.1

conda-envs/environment-test-py37.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - aeppl>=0.0.13
+  - aeppl=0.0.15
   - aesara>=2.2.6
   - arviz>=0.11.4
   - cachetools>=4.2.1

conda-envs/environment-test-py38.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - aeppl>=0.0.13
+  - aeppl==0.0.15
   - aesara>=2.2.6
   - arviz>=0.11.4
   - cachetools>=4.2.1

conda-envs/environment-test-py39.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - aeppl>=0.0.13
+  - aeppl=0.0.15
   - aesara>=2.2.6
   - arviz>=0.11.4
   - cachetools

conda-envs/windows-environment-dev-py38.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ channels:
   - defaults
 dependencies:
 # base dependencies (see install guide for Windows)
-  - aeppl>=0.0.13
+  - aeppl=0.0.15
   - aesara>=2.2.6
   - arviz>=0.11.4
   - cachetools>=4.2.1

conda-envs/windows-environment-test-py38.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ channels:
   - defaults
 dependencies:
 # base dependencies (see install guide for Windows)
-  - aeppl>=0.0.13
+  - aeppl=0.0.15
   - aesara>=2.2.6
   - arviz>=0.11.2
   - cachetools

docs/source/conf.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@
 import os
 import sys

-
 # If extensions (or modules to document with autodoc) are in another directory,
 # add these directories to sys.path here. If the directory is relative to the
 # documentation root, use os.path.abspath to make it absolute, like shown here.

pymc/aesaraf.py

Lines changed: 11 additions & 1 deletion
@@ -335,7 +335,7 @@ def rvs_to_value_vars(
     initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
     **kwargs,
 ) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]:
-    """Replace random variables in graphs with their value variables.
+    """Clone and replace random variables in graphs with their value variables.

     This will *not* recompute test values in the resulting graphs.

@@ -383,6 +383,16 @@ def transform_replacements(var, replacements):
         # Walk the transformed variable and make replacements
         return [trans_rv_value]

+    # Clone original graphs
+    inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
+    equiv = clone_get_equiv(inputs, graphs, False, False, {})
+    graphs = [equiv[n] for n in graphs]
+
+    if initial_replacements:
+        initial_replacements = {
+            equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
+        }
+
     return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs)
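The new pre-cloning step means callers get back a fresh graph instead of having their original variables mutated in place. A minimal standalone sketch of the same idiom, assuming only that Aesara is installed (the variables here are illustrative, not pymc's):

```python
import aesara.tensor as at
from aesara.graph.basic import Constant, clone_get_equiv, graph_inputs

x = at.scalar("x")
y = x + 1  # original graph

graphs = [y]
# Collect non-constant inputs, then build an old-variable -> clone mapping
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = clone_get_equiv(inputs, graphs, False, False, {})
cloned = [equiv[g] for g in graphs]

# Replacements done on `cloned` can no longer mutate the caller's `y`
assert cloned[0] is not y
```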

pymc/backends/arviz.py

Lines changed: 21 additions & 19 deletions
@@ -42,6 +42,26 @@
 Var = Any  # pylint: disable=invalid-name


+def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
+    """If there are observations available, return them as a dictionary."""
+    if model is None:
+        return None
+
+    observations = {}
+    for obs in model.observed_RVs:
+        aux_obs = getattr(obs.tag, "observations", None)
+        if aux_obs is not None:
+            try:
+                obs_data = extract_obs_data(aux_obs)
+                observations[obs.name] = obs_data
+            except TypeError:
+                warnings.warn(f"Could not extract data from symbolic observation {obs}")
+        else:
+            warnings.warn(f"No data for observation {obs}")
+
+    return observations
+
+
 class _DefaultTrace:
     """
     Utility for collecting samples into a dictionary.
@@ -196,25 +216,7 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
         self.dims = {**model_dims, **self.dims}

         self.density_dist_obs = density_dist_obs
-        self.observations = self.find_observations()
-
-    def find_observations(self) -> Optional[Dict[str, Var]]:
-        """If there are observations available, return them as a dictionary."""
-        if self.model is None:
-            return None
-        observations = {}
-        for obs in self.model.observed_RVs:
-            aux_obs = getattr(obs.tag, "observations", None)
-            if aux_obs is not None:
-                try:
-                    obs_data = extract_obs_data(aux_obs)
-                    observations[obs.name] = obs_data
-                except TypeError:
-                    warnings.warn(f"Could not extract data from symbolic observation {obs}")
-            else:
-                warnings.warn(f"No data for observation {obs}")
-
-        return observations
+        self.observations = find_observations(self.model)

     def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
         """Split MultiTrace object into posterior and warmup.

pymc/bart/pgbart.py

Lines changed: 42 additions & 34 deletions
@@ -101,9 +101,10 @@ class PGBART(ArrayStepShared):
         Number of particles for the conditional SMC sampler. Defaults to 10
     max_stages : int
         Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
-    batch : int
+    batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
-        during tuning and 20% after tuning.
+        during tuning and 20% after tuning. If a tuple is passed the first element is the batch size
+        during tuning and the second the batch size after tuning.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
@@ -138,9 +139,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.alpha = self.bart.alpha
         self.k = self.bart.k
         self.response = self.bart.response
-        self.split_prior = self.bart.split_prior
-        if self.split_prior is None:
-            self.split_prior = np.ones(self.X.shape[1])
+        self.alpha_vec = self.bart.split_prior
+        if self.alpha_vec is None:
+            self.alpha_vec = np.ones(self.X.shape[1])

         self.init_mean = self.Y.mean()
         # if data is binary
@@ -149,7 +150,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
             self.mu_std = 6 / (self.k * self.m ** 0.5)
         # maybe we need to check for count data
         else:
-            self.mu_std = self.Y.std() / (self.k * self.m ** 0.5)
+            self.mu_std = (2 * self.Y.std()) / (self.k * self.m ** 0.5)

         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
@@ -167,14 +168,18 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo

         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
-        self.ssv = SampleSplittingVariable(self.split_prior)
+        self.ssv = SampleSplittingVariable(self.alpha_vec)

         self.tune = True
-        self.idx = 0
-        self.batch = batch

-        if self.batch == "auto":
-            self.batch = max(1, int(self.m * 0.1))
+        if batch == "auto":
+            self.batch = (max(1, int(self.m * 0.1)), max(1, int(self.m * 0.2)))
+        else:
+            if isinstance(batch, (tuple, list)):
+                self.batch = batch
+            else:
+                self.batch = (batch, batch)
+
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
         self.len_indices = len(self.indices)
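For reference, the branching above reduces to a small normalization rule: `batch` always ends up as a `(during_tuning, after_tuning)` pair. A standalone sketch (the helper name is illustrative, not pymc's API):

```python
def normalize_batch(batch, m):
    """Return (batch size during tuning, batch size after tuning)."""
    if batch == "auto":
        return (max(1, int(m * 0.1)), max(1, int(m * 0.2)))
    if isinstance(batch, (tuple, list)):
        return tuple(batch)
    return (batch, batch)

assert normalize_batch("auto", 50) == (5, 10)
assert normalize_batch(4, 50) == (4, 4)
assert normalize_batch((2, 8), 50) == (2, 8)
```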
@@ -187,6 +192,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.all_particles = []
         for i in range(self.m):
             self.a_tree.tree_id = i
+            self.a_tree.leaf_node_value = (
+                self.init_mean / self.m + self.normal.random() * self.mu_std,
+            )
             p = ParticleTree(
                 self.a_tree,
                 self.init_log_weight,
@@ -201,20 +209,16 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         sum_trees_output = q.data
         variable_inclusion = np.zeros(self.num_variates, dtype="int")

-        if self.idx == self.m:
-            self.idx = 0
-
-        for tree_id in range(self.idx, self.idx + self.batch):
-            if tree_id >= self.m:
-                break
+        tree_ids = np.random.randint(0, self.m, size=self.batch[~self.tune])
+        for tree_id in tree_ids:
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
             # Compute the sum of trees without the tree we are attempting to replace
             self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()

             # The old tree is not growing so we update the weights only once.
-            self.update_weight(particles[0])
+            self.update_weight(particles[0], new=True)
             for t in range(self.max_stages):
                 # Sample each particle (try to grow each tree), except for the first one.
                 for p in particles[1:]:
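The `self.batch[~self.tune]` lookup leans on Python's bool-as-int semantics: `~True == -2` indexes the first element of the 2-tuple and `~False == -1` the second. A plain-Python demonstration (not pymc code):

```python
batch = (10, 20)  # (batch during tuning, batch after tuning)

tune = True
assert batch[~tune] == 10  # ~True == -2 -> first element

tune = False
assert batch[~tune] == 20  # ~False == -1 -> second element
```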
@@ -235,15 +239,15 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                     if tree_grew:
                         self.update_weight(p)
                 # Normalize weights
-                W_t, normalized_weights = self.normalize(particles)
+                W_t, normalized_weights = self.normalize(particles[1:])

                 # Resample all but first particle
-                re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
+                re_n_w = normalized_weights
                 new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
                 particles[1:] = particles[new_indices]

                 # Set the new weights
-                for p in particles:
+                for p in particles[1:]:
                     p.log_weight = W_t

                 # Check if particles can keep growing, otherwise stop iterating
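Normalization now runs over `particles[1:]` only: in this conditional SMC scheme the first particle carries the previous tree and is never resampled. A standalone sketch of the resampling step with illustrative weights (not pymc code):

```python
import numpy as np

rng = np.random.default_rng(0)
particles = np.array(["old", "a", "b", "c"], dtype=object)

log_w = np.array([-1.2, -0.7, -2.1])  # log-weights of particles[1:] only
w = np.exp(log_w - log_w.max())
p = w / w.sum()  # plays the role of normalized_weights

indices = np.arange(1, len(particles))  # resample candidates: 1..n-1
new_indices = rng.choice(indices, size=indices.size, p=p)
particles[1:] = particles[new_indices]  # particles[0] is never replaced
```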
@@ -254,23 +258,25 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                 if all(non_available_nodes_for_expansion):
                     break

+            for p in particles[1:]:
+                p.log_weight = p.old_likelihood_logp
+
+            _, normalized_weights = self.normalize(particles)
             # Get the new tree and update
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
-            self.all_trees[self.idx] = new_tree
+            self.all_trees[tree_id] = new_tree
             new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
             self.all_particles[tree_id] = new_particle
             sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()

             if self.tune:
+                self.ssv = SampleSplittingVariable(self.alpha_vec)
                 for index in new_particle.used_variates:
-                    self.split_prior[index] += 1
-                    self.ssv = SampleSplittingVariable(self.split_prior)
+                    self.alpha_vec[index] += 1
             else:
-                self.batch = max(1, int(self.m * 0.2))
                 for index in new_particle.used_variates:
                     variable_inclusion[index] += 1
-            self.idx += 1

         stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
         sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
@@ -323,7 +329,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:

         return np.array(particles)

-    def update_weight(self, particle: List[ParticleTree]) -> None:
+    def update_weight(self, particle: List[ParticleTree], new=False) -> None:
         """
         Update the weight of a particle

@@ -333,20 +339,22 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
         new_likelihood = self.likelihood_logp(
             self.sum_trees_output_noi + particle.tree.predict_output()
         )
-        particle.log_weight += new_likelihood - particle.old_likelihood_logp
-        particle.old_likelihood_logp = new_likelihood
+        if new:
+            particle.log_weight = new_likelihood
+        else:
+            particle.log_weight += new_likelihood - particle.old_likelihood_logp
+            particle.old_likelihood_logp = new_likelihood


 class SampleSplittingVariable:
-    def __init__(self, alpha_prior):
+    def __init__(self, alpha_vec):
         """
-        Sample splitting variables proportional to `alpha_prior`.
+        Sample splitting variables proportional to `alpha_vec`.

-        This is equivalent as sampling weights from a Dirichlet distribution with `alpha_prior`
-        parameter and then using those weights to sample from the available spliting variables.
+        This is equivalent to compute the posterior mean of a Dirichlet-Multinomial model.
         This enforce sparsity.
         """
-        self.enu = list(enumerate(np.cumsum(alpha_prior / alpha_prior.sum())))
+        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

     def rvs(self):
         r = np.random.random()
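`SampleSplittingVariable.rvs` draws a splitting variable by inverse-CDF sampling over the normalized `alpha_vec` weights. A self-contained re-implementation sketch for illustration (the class name is ours, not pymc's):

```python
import numpy as np

class SplitSampler:
    def __init__(self, alpha_vec):
        # Cumulative normalized weights; normalizing alpha_vec corresponds
        # to the posterior mean of a Dirichlet-Multinomial model.
        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

    def rvs(self):
        # Inverse-CDF draw: return the first index whose cumulative
        # weight covers the uniform random number r.
        r = np.random.random()
        for index, cum_weight in self.enu:
            if r <= cum_weight:
                return index

sampler = SplitSampler(np.array([1.0, 1.0, 8.0]))  # variable 2 is favored
print(sampler.rvs())
```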

pymc/distributions/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -13,10 +13,8 @@
 # limitations under the License.

 from pymc.distributions.logprob import (  # isort:skip
-    _logcdf,
     logcdf,
     logp,
-    logcdfpt,
     logp_transform,
     logpt,
     logpt_sum,
@@ -195,6 +193,5 @@
     "logp",
     "logp_transform",
     "logcdf",
-    "_logcdf",
     "logpt_sum",
 ]
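One practical consequence, assuming nothing else re-exports these names: `_logcdf` and `logcdfpt` can no longer be imported from `pymc.distributions`, so downstream code should rely on the remaining public helpers:

```python
# On this branch, only the public helpers remain re-exported:
from pymc.distributions import logcdf, logp, logpt  # still available
# from pymc.distributions import _logcdf, logcdfpt  # no longer re-exported
```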
