Skip to content

Commit 64c1464

Browse files
bwengalstwiecki
andauthored
v4 refactor for GP module (#5055)
* np.int -> int, fix np DepricationWarning * remove shape arg from non-kron implementations, TODO for is_observed in Marginal, mark for deprecation * np.int -> int in gp/util.py * force all mean_func, cov_func args to GP constructors to be required kwargs (often default zero mean_func is used) * fix predictt functions, rename to _predict_at. because theano -> aesara * fix TP tests, force mean_func, cov_func to be req kwarg * fix TP reparameterization to sample studentt instead of chi2/norm * change naming shape->size where appropriate * add deprecation warning for is_observed * add jitter arg for covs headed for cholesky decomps, previously fixed at 1e-6. add deprecation warning for is_observed arg * clean up trivial aesara.function usage to .eval() instead * fix gp.util.replace_with_values to handle case with no symbolic values, .eval() works * jitter=0 for conditonals/predicts, fix replace_with_values calls * fix more tests - use model.logp instead of variable.logp - set req kwargs cov_func and mean_func - fix weirdly small scale on some input X, y - move predict calls into model block - the two kron models outstanding * black stuff * small fixes to get kronlatent and kronmarginal to pass * remove leftover prints * dep warning -> future warning * roll back mkl and mkl-service version * fix precommit * remove old DeprecationWarning * improve tests cleanup gp.util.replace_with_values and add docstrings * fix pre-commit issue * fix precommit on cov.py * fix comment * dont force blas version in windows dev enviornment (roll back changes) * update release notes * add removed ... line from release notes * add link to PR * remove is_observed usage from TestMarginalVsLatent * remove is_observed usage from TestMarginalVsMarginalSparse * Update RELEASE-NOTES.md Co-authored-by: Thomas Wiecki <[email protected]>
1 parent d74537c commit 64c1464

File tree

6 files changed

+498
-329
lines changed

6 files changed

+498
-329
lines changed

RELEASE-NOTES.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,16 @@ All of the above apply to:
4545
- Changes to the BART implementation:
4646
- A BART variable can be combined with other random variables. The `inv_link` argument has been removed (see [4914](https://github.com/pymc-devs/pymc3/pull/4914)).
4747
- Moved BART to its own module (see [5058](https://github.com/pymc-devs/pymc3/pull/5058)).
48+
- Changes to the Gaussian Process (GP) submodule (see [5055](https://github.com/pymc-devs/pymc/pull/5055)):
49+
- For all implementations, `gp.Latent`, `gp.Marginal` etc., `cov_func` and `mean_func` are required kwargs.
50+
- In Windows test conda environment the `mkl` version is fixed to verison 2020.4, and `mkl-service` is fixed to `2.3.0`. This was required for `gp.MarginalKron` to function properly.
51+
- `gp.MvStudentT` uses rotated samples from `StudentT` directly now, instead of sampling from `pm.Chi2` and then from `pm.Normal`.
52+
- The "jitter" parameter, or the diagonal noise term added to Gram matrices such that the Cholesky is numerically stable, is now exposed to the user instead of hard-coded. See the function `gp.util.stabilize`.
53+
- The `is_observed` arguement for `gp.Marginal*` implementations has been deprecated.
54+
- In the gp.utils file, the `kmeans_inducing_points` function now passes through `kmeans_kwargs` to scipy's k-means function.
55+
- The function `replace_with_values` function has been added to `gp.utils`.
4856
- ...
4957

50-
5158
### Expected breaks
5259
+ New API was already available in `v3`.
5360
+ Old API had deprecation warnings since at least `3.11.0` (2021-01).

conda-envs/windows-environment-test-py38.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ dependencies:
1212
- fastprogress>=0.2.0
1313
- h5py>=2.7
1414
- libpython
15-
- mkl-service
15+
- mkl==2020.4
16+
- mkl-service==2.3.0
1617
- m2w64-toolchain
1718
- numpy>=1.15.0
1819
- pandas

pymc/gp/cov.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, input_dim, active_dims=None):
6464
if active_dims is None:
6565
self.active_dims = np.arange(input_dim)
6666
else:
67-
self.active_dims = np.asarray(active_dims, np.int)
67+
self.active_dims = np.asarray(active_dims, int)
6868

6969
def __call__(self, X, Xs=None, diag=False):
7070
r"""
@@ -152,7 +152,9 @@ def __array_wrap__(self, result):
152152
elif isinstance(result[0][0], Prod):
153153
return result[0][0].factor_list[0] * A
154154
else:
155-
raise RuntimeError
155+
raise TypeError(
156+
f"Unknown Covariance combination type {result[0][0]}. Known types are `Add` or `Prod`."
157+
)
156158

157159

158160
class Combination(Covariance):

0 commit comments

Comments
 (0)