
Commit d8a6279

Merge branch 'main' into ciguaran_integrate_blackjax_smc
2 parents: ebd3f9c + 7799595

File tree: 5 files changed (+59 / -28 lines)

pymc/logprob/rewriting.py

Lines changed: 6 additions & 0 deletions

@@ -70,6 +70,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.rewriting.basic import register_canonicalize
+from pytensor.tensor.rewriting.math import local_exp_over_1_plus_exp
 from pytensor.tensor.rewriting.shape import ShapeFeature
 from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax
 from pytensor.tensor.subtensor import (
@@ -359,7 +360,12 @@ def incsubtensor_rv_replace(fgraph, node):

 logprob_rewrites_db = SequenceDB()
 logprob_rewrites_db.name = "logprob_rewrites_db"
+# Introduce sigmoid. We do it before canonicalization so that useless mul are removed next
+logprob_rewrites_db.register(
+    "local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic"
+)
 logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
+# Split max_and_argmax
 logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic")

 # These rewrites convert un-measurable variables into their measurable forms,
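The practical effect of registering local_exp_over_1_plus_exp ahead of canonicalization is that a hand-written logistic expression collapses to a single sigmoid before the measurable-transform rewrites run, so PyMC can derive its logp. A minimal sketch of what this enables, using the public pm.logp API (my example, not part of the diff):

import pytensor.tensor as pt
import pymc as pm

base_rv = pt.random.normal(name="base_rv")
vv = pt.scalar("vv")

# A logistic function written out by hand; inside the logprob graph the new
# rewrite canonicalizes exp(x) / (1 + exp(x)) into sigmoid(x) first.
y = pt.exp(base_rv) / (1 + pt.exp(base_rv))

# Equivalent to pm.logp(pt.sigmoid(base_rv), vv) once the rewrite has run.
print(pm.logp(y, vv).eval({vv: 0.3}))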

pymc/math.py

Lines changed: 27 additions & 1 deletion

@@ -32,13 +32,21 @@
 from pytensor.tensor import (
     abs,
     and_,
+    arccos,
+    arccosh,
+    arcsin,
+    arcsinh,
+    arctan,
+    arctanh,
     broadcast_to,
     ceil,
     clip,
     concatenate,
     constant,
     cos,
     cosh,
+    cumprod,
+    cumsum,
     dot,
     eq,
     erf,
@@ -59,8 +67,10 @@
     logsumexp,
     lt,
     matmul,
+    max,
     maximum,
     mean,
+    min,
     minimum,
     neq,
     ones,
@@ -94,13 +104,21 @@
 __all__ = [
     "abs",
     "and_",
+    "arccos",
+    "arccosh",
+    "arcsin",
+    "arcsinh",
+    "arctan",
+    "arctanh",
     "broadcast_to",
     "ceil",
     "clip",
     "concatenate",
     "constant",
     "cos",
     "cosh",
+    "cumprod",
+    "cumsum",
     "dot",
     "eq",
     "erf",
@@ -121,14 +139,17 @@
     "logsumexp",
     "lt",
     "matmul",
+    "max",
     "maximum",
     "mean",
+    "min",
     "minimum",
     "neq",
     "ones",
     "ones_like",
     "or_",
     "prod",
+    "round",
     "sgn",
     "sigmoid",
     "sin",
@@ -258,7 +279,7 @@ def kron_diag(*diags):
     return reduce(flat_outer, diags)


-def tround(*args, **kwargs):
+def round(*args, **kwargs):
     """
     Temporary function to silence round warning in PyTensor. Please remove
     when the warning disappears.
@@ -267,6 +288,11 @@ def tround(*args, **kwargs):
     return pt.round(*args, **kwargs)


+def tround(*args, **kwargs):
+    warnings.warn("tround is deprecated. Use round instead.")
+    return round(*args, **kwargs)
+
+
 def logdiffexp(a, b):
     """log(exp(a) - exp(b))"""
     return a + pt.log1mexp(b - a)
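For context, the newly re-exported names behave like the existing pm.math wrappers around pytensor.tensor, and round now replaces the deprecated tround. A quick usage sketch (my example, not part of the diff):

import numpy as np
import pymc as pm

x = np.linspace(0.1, 0.9, 5)
print(pm.math.arctanh(x).eval())  # newly exported inverse trig/hyperbolic helpers
print(pm.math.cumsum(x).eval())   # cumulative ops are re-exported as well
print(pm.math.round(x).eval())    # preferred spelling going forward
print(pm.math.tround(x).eval())   # still works, but warns "tround is deprecated"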

pymc/model/core.py

Lines changed: 0 additions & 9 deletions

@@ -1347,15 +1347,6 @@ def make_obs_var(
             "Dimensionality of data and RV don't match.", actual=data.ndim, expected=rv_var.ndim
         )

-        if pytensor.config.compute_test_value != "off":
-            test_value = getattr(rv_var.tag, "test_value", None)
-
-            if test_value is not None:
-                # We try to reuse the old test value
-                rv_var.tag.test_value = np.broadcast_to(test_value, rv_var.shape)
-            else:
-                rv_var.tag.test_value = data
-
         mask = getattr(data, "mask", None)
         if mask is not None:
             impute_message = (
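The deleted branch seeded rv_var.tag.test_value from the observed data by hand; after this change, PyTensor is left to compute the test value of the observed RV itself, which the new test in tests/model/test_core.py (below) verifies. A standalone reproduction might look like this, assuming change_flags is also usable as a context manager (my sketch, not part of the diff):

import numpy as np
import pytensor
import pytensor.tensor as pt
import pymc as pm

data = np.zeros(100)
with pytensor.config.change_flags(compute_test_value="raise"):
    with pm.Model():
        obs = pm.Normal("obs", mu=pt.zeros_like(data), sigma=1, observed=data)
    # The test value propagates through the graph without make_obs_var
    # setting it explicitly.
    print(obs.tag.test_value.shape)  # (100,)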

tests/logprob/test_transforms.py

Lines changed: 18 additions & 18 deletions

@@ -1127,31 +1127,31 @@ def test_cosh_rv_transform():
     )


-TRANSFORMATIONS = {
-    "log1p": (pt.log1p, lambda x: pt.log(1 + x)),
-    "softplus": (pt.softplus, lambda x: pt.log(1 + pt.exp(x))),
-    "log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(x))),
-    "log2": (pt.log2, lambda x: pt.log(x) / pt.log(2)),
-    "log10": (pt.log10, lambda x: pt.log(x) / pt.log(10)),
-    "exp2": (pt.exp2, lambda x: pt.exp(pt.log(2) * x)),
-    "expm1": (pt.expm1, lambda x: pt.exp(x) - 1),
-    "sigmoid": (pt.sigmoid, lambda x: 1 / (1 + pt.exp(-x))),
-}
-
-
-@pytest.mark.parametrize("transform", TRANSFORMATIONS.keys())
-def test_special_log_exp_transforms(transform):
+@pytest.mark.parametrize(
+    "canonical_func,raw_func",
+    [
+        (pt.log1p, lambda x: pt.log(1 + x)),
+        (pt.softplus, lambda x: pt.log(1 + pt.exp(x))),
+        (pt.log1mexp, lambda x: pt.log(1 - pt.exp(x))),
+        (pt.log2, lambda x: pt.log(x) / pt.log(2)),
+        (pt.log10, lambda x: pt.log(x) / pt.log(10)),
+        (pt.exp2, lambda x: pt.exp(pt.log(2) * x)),
+        (pt.expm1, lambda x: pt.exp(x) - 1),
+        (pt.sigmoid, lambda x: 1 / (1 + pt.exp(-x))),
+        (pt.sigmoid, lambda x: pt.exp(x) / (1 + pt.exp(x))),
+    ],
+)
+def test_special_log_exp_transforms(canonical_func, raw_func):
     base_rv = pt.random.normal(name="base_rv")
     vv = pt.scalar("vv")

-    transform_func, ref_func = TRANSFORMATIONS[transform]
-    transformed_rv = transform_func(base_rv)
-    ref_transformed_rv = ref_func(base_rv)
+    transformed_rv = raw_func(base_rv)
+    ref_transformed_rv = canonical_func(base_rv)

     logp_test = logp(transformed_rv, vv)
     logp_ref = logp(ref_transformed_rv, vv)

-    if transform in ["log2", "log10"]:
+    if canonical_func in (pt.log2, pt.log10):
         # in the cases of log2 and log10 floating point inprecision causes failure
         # from equal_computations so evaluate logp and check all close instead
         vv_test = np.array(0.25)
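The final branch covers log2 and log10, where the constants pt.log(2) and pt.log(10) leave graphs that are numerically equal but not structurally identical, so the test evaluates both logps and compares values instead of relying on equal_computations. Roughly, for the log2 case (my sketch of that fallback, not the test's exact code):

import numpy as np
import pytensor.tensor as pt
from pymc import logp

base_rv = pt.random.normal(name="base_rv")
vv = pt.scalar("vv")
vv_test = np.array(0.25)

# Compare evaluated values rather than graph structure.
logp_raw = logp(pt.log(base_rv) / pt.log(2), vv)
logp_ref = logp(pt.log2(base_rv), vv)
np.testing.assert_allclose(
    logp_raw.eval({vv: vv_test}),
    logp_ref.eval({vv: vv_test}),
)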

tests/model/test_core.py

Lines changed: 8 additions & 0 deletions

@@ -223,6 +223,14 @@ def test_observed_type(self):
         assert x1.type.dtype == X.type.dtype
         assert x2.type.dtype == X.type.dtype

+    @pytensor.config.change_flags(compute_test_value="raise")
+    def test_observed_compute_test_value(self):
+        data = np.zeros(100)
+        with pm.Model():
+            obs = pm.Normal("obs", mu=pt.zeros_like(data), sigma=1, observed=data)
+        assert obs.tag.test_value.shape == data.shape
+        assert obs.tag.test_value.dtype == data.dtype
+

 def test_duplicate_vars():
     with pytest.raises(ValueError) as err:
