Skip to content

Commit b485769

Browse files
authored
fix docstrings (#28)
1 parent fbf0fa8 commit b485769

File tree

5 files changed

+21
-31
lines changed

5 files changed

+21
-31
lines changed

pymc_experimental/bart/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@
2020
__all__ = ["BART", "PGBART"]
2121

2222

23-
2423
import pymc as pm
24+
2525
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]

pymc_experimental/bart/bart.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525

2626

2727
class BARTRV(RandomVariable):
28-
"""
29-
Base class for BART
30-
"""
28+
"""Base class for BART."""
3129

3230
name = "BART"
3331
ndim_supp = 1

pymc_experimental/bart/pgbart.py

+9-18
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
class PGBART(ArrayStepShared):
3434
"""
35-
Particle Gibbs BART sampling step
35+
Particle Gibbs BART sampling step.
3636
3737
Parameters
3838
----------
@@ -208,9 +208,7 @@ def astep(self, _):
208208
return self.sum_trees, [stats]
209209

210210
def normalize(self, particles):
211-
"""
212-
Use logsumexp trick to get W_t and softmax to get normalized_weights
213-
"""
211+
"""Use logsumexp trick to get W_t and softmax to get normalized_weights."""
214212
log_w = np.array([p.log_weight for p in particles])
215213
log_w_max = log_w.max()
216214
log_w_ = log_w - log_w_max
@@ -224,9 +222,7 @@ def normalize(self, particles):
224222
return W_t, normalized_weights
225223

226224
def init_particles(self, tree_id: int) -> np.ndarray:
227-
"""
228-
Initialize particles
229-
"""
225+
"""Initialize particles."""
230226
p = self.all_particles[tree_id]
231227
particles = [p]
232228
particles.append(copy(p))
@@ -238,7 +234,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
238234

239235
def update_weight(self, particle, old=False):
240236
"""
241-
Update the weight of a particle
237+
Update the weight of a particle.
242238
243239
Since the prior is used as the proposal, the weights are updated additively as the ratio of
244240
the new and old log-likelihoods.
@@ -253,19 +249,15 @@ def update_weight(self, particle, old=False):
253249

254250
@staticmethod
255251
def competence(var, has_grad):
256-
"""
257-
PGBART is only suitable for BART distributions
258-
"""
252+
"""PGBART is only suitable for BART distributions."""
259253
dist = getattr(var.owner, "op", None)
260254
if isinstance(dist, BARTRV):
261255
return Competence.IDEAL
262256
return Competence.INCOMPATIBLE
263257

264258

265259
class ParticleTree:
266-
"""
267-
Particle tree
268-
"""
260+
"""Particle tree."""
269261

270262
def __init__(self, tree):
271263
self.tree = tree.copy() # keeps the tree that we care at the moment
@@ -340,6 +332,7 @@ def rvs(self):
340332
def compute_prior_probability(alpha):
341333
"""
342334
Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
335+
343336
Taken from equation 19 in [Rockova2018].
344337
345338
Parameters
@@ -463,7 +456,7 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
463456

464457

465458
def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
466-
"""Draw Gaussian distributed leaf values"""
459+
"""Draw Gaussian distributed leaf values."""
467460
if Y_mu_pred.size == 0:
468461
return 0
469462
else:
@@ -504,9 +497,7 @@ def discrete_uniform_sampler(upper_value):
504497

505498

506499
class NormalSampler:
507-
"""
508-
Cache samples from a standard normal distribution
509-
"""
500+
"""Cache samples from a standard normal distribution."""
510501

511502
def __init__(self):
512503
self.size = 1000

pymc_experimental/bart/tree.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
class Tree:
24-
"""Full binary tree
24+
"""Full binary tree.
2525
2626
A full binary tree is a tree where each node has exactly zero or two children.
2727
This structure is used as the basic component of the Bayesian Additive Regression Tree (BART)
@@ -135,17 +135,16 @@ def _traverse_tree(self, x, node_index=0):
135135
@staticmethod
136136
def init_tree(leaf_node_value, idx_data_points):
137137
"""
138+
Initialize tree.
138139
139140
Parameters
140141
----------
141142
leaf_node_value
142143
idx_data_points
143-
m : int
144-
number of trees in BART
145144
146145
Returns
147146
-------
148-
147+
tree
149148
"""
150149
new_tree = Tree(len(idx_data_points))
151150
new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)

pymc_experimental/bart/utils.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Utility functions for variable selection and BART interpretability."""
2+
13
import arviz as az
24
import matplotlib.pyplot as plt
35
import numpy as np
@@ -10,7 +12,7 @@
1012

1113
def predict(idata, rng, X_new=None, size=None, excluded=None):
1214
"""
13-
Generate samples from the BART-posterior
15+
Generate samples from the BART-posterior.
1416
1517
Parameters
1618
----------
@@ -75,7 +77,7 @@ def plot_dependence(
7577
ax=None,
7678
):
7779
"""
78-
Partial dependence or individual conditional expectation plot
80+
Partial dependence or individual conditional expectation plot.
7981
8082
Parameters
8183
----------
@@ -107,12 +109,12 @@ def plot_dependence(
107109
instances : int
108110
Number of instances of X to plot. Only relevant for ``kind="ice"`` plots.
109111
random_seed : int
110-
random_seed used to sample from the posterior. Defaults to None.
112+
Seed used to sample from the posterior. Defaults to None.
111113
sharey : bool
112114
Controls sharing of properties among y-axes. Defaults to True.
113115
rug : bool
114116
Whether to include a rugplot. Defaults to True.
115-
smooth=True,
117+
smooth : bool
116118
If True the result will be smoothed by first computing a linear interpolation of the data
117119
over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
118120
Defaults to True.
@@ -302,7 +304,7 @@ def plot_dependence(
302304

303305
def plot_variable_importance(idata, labels=None, figsize=None, samples=100, random_seed=None):
304306
"""
305-
Estimates variable importance from the BART-posterior
307+
Estimate variable importance from the BART-posterior.
306308
307309
Parameters
308310
----------

0 commit comments

Comments
 (0)