Skip to content

Commit 2901888

Browse files
Remove uses of deprecated aesara.tensor.nlinalg.diag (#4501)
Co-authored-by: Brandon T. Willard <[email protected]>
1 parent cf662c9 commit 2901888

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

pymc3/distributions/dist_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def MvNormalLogp():
275275
n, k = delta.shape
276276
n, k = f(n), f(k)
277277
chol_cov = cholesky(cov)
278-
diag = aet.nlinalg.diag(chol_cov)
278+
diag = aet.diag(chol_cov)
279279
ok = aet.all(diag > 0)
280280

281281
chol_cov = aet.switch(ok, chol_cov, aet.fill(chol_cov, 1))
@@ -295,7 +295,7 @@ def dlogp(inputs, gradients):
295295
n, k = delta.shape
296296

297297
chol_cov = cholesky(cov)
298-
diag = aet.nlinalg.diag(chol_cov)
298+
diag = aet.diag(chol_cov)
299299
ok = aet.all(diag > 0)
300300

301301
chol_cov = aet.switch(ok, chol_cov, aet.fill(chol_cov, 1))

pymc3/distributions/multivariate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _quaddist(self, value):
143143

144144
def _quaddist_chol(self, delta):
145145
chol_cov = self.chol_cov
146-
diag = aet.nlinalg.diag(chol_cov)
146+
diag = aet.diag(chol_cov)
147147
# Check if the covariance matrix is positive definite.
148148
ok = aet.all(diag > 0)
149149
# If not, replace the diagonal. We return -inf later, but
@@ -160,7 +160,7 @@ def _quaddist_cov(self, delta):
160160

161161
def _quaddist_tau(self, delta):
162162
chol_tau = self.chol_tau
163-
diag = aet.nlinalg.diag(chol_tau)
163+
diag = aet.diag(chol_tau)
164164
# Check if the precision matrix is positive definite.
165165
ok = aet.all(diag > 0)
166166
# If not, replace the diagonal. We return -inf later, but
@@ -1668,7 +1668,7 @@ class MatrixNormal(Continuous):
16681668
16691669
# Setup left covariance matrix
16701670
scale = pm.Lognormal('scale', mu=np.log(true_scale), sigma=0.5)
1671-
rowcov = aet.nlinalg.diag([scale**(2*i) for i in range(m)])
1671+
rowcov = aet.diag([scale**(2*i) for i in range(m)])
16721672
16731673
vals = pm.MatrixNormal('vals', mu=mu, colchol=colchol, rowcov=rowcov,
16741674
observed=data, shape=(m, n))
@@ -1813,8 +1813,8 @@ def _trquaddist(self, value):
18131813
quaddist = self.solve_upper(colchol_cov.T, quaddist)
18141814
trquaddist = aet.nlinalg.trace(quaddist)
18151815

1816-
coldiag = aet.nlinalg.diag(colchol_cov)
1817-
rowdiag = aet.nlinalg.diag(rowchol_cov)
1816+
coldiag = aet.diag(colchol_cov)
1817+
rowdiag = aet.diag(rowchol_cov)
18181818
half_collogdet = aet.sum(aet.log(coldiag)) # logdet(M) = 2*Tr(log(L))
18191819
half_rowlogdet = aet.sum(aet.log(rowdiag)) # Using Cholesky: M = L L^T
18201820
return trquaddist, half_collogdet, half_rowlogdet
@@ -1958,7 +1958,7 @@ def _setup(self, covs, chols, evds, sigma):
19581958
else:
19591959
# Otherwise use cholesky as usual
19601960
self.chols = list(map(self.cholesky, self.covs))
1961-
self.chol_diags = list(map(aet.nlinalg.diag, self.chols))
1961+
self.chol_diags = list(map(aet.diag, self.chols))
19621962
self.sizes = aet.as_tensor_variable([chol.shape[0] for chol in self.chols])
19631963
self.N = aet.prod(self.sizes)
19641964
elif chols is not None:
@@ -1970,7 +1970,7 @@ def _setup(self, covs, chols, evds, sigma):
19701970
self._setup_evd(eigh_map)
19711971
else:
19721972
self.chols = chols
1973-
self.chol_diags = list(map(aet.nlinalg.diag, self.chols))
1973+
self.chol_diags = list(map(aet.diag, self.chols))
19741974
self.sizes = aet.as_tensor_variable([chol.shape[0] for chol in self.chols])
19751975
self.N = aet.prod(self.sizes)
19761976
else:

pymc3/math.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@
7070
where,
7171
zeros_like,
7272
)
73-
from aesara.tensor.nlinalg import det, extract_diag, matrix_dot, matrix_inverse, trace
73+
74+
try:
75+
from aesara.tensor.basic import extract_diag
76+
except ImportError:
77+
from aesara.tensor.nlinalg import extract_diag
78+
79+
80+
from aesara.tensor.nlinalg import det, matrix_dot, matrix_inverse, trace
7481
from aesara.tensor.nnet import sigmoid
7582
from scipy.linalg import block_diag as scipy_block_diag
7683

0 commit comments

Comments
 (0)