Skip to content

Commit 450e7f6

Browse files
authored
Support PyTensor deterministic operations as observations (#7656)
1 parent 62335ac commit 450e7f6

File tree

4 files changed

+52 -14 lines changed

4 files changed

+52 -14 lines changed

pymc/data.py

Lines changed: 3 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -30,13 +30,12 @@
3030
from pytensor.compile.sharedvalue import SharedVariable
3131
from pytensor.graph.basic import Variable
3232
from pytensor.raise_op import Assert
33-
from pytensor.scalar import Cast
34-
from pytensor.tensor.elemwise import Elemwise
3533
from pytensor.tensor.random.basic import IntegersRV
3634
from pytensor.tensor.variable import TensorConstant, TensorVariable
3735

3836
import pymc as pm
3937

38+
from pymc.logprob.utils import rvs_in_graph
4039
from pymc.pytensorf import convert_data
4140
from pymc.vartypes import isgenerator
4241

@@ -111,13 +110,7 @@ def is_valid_observed(v) -> bool:
111110
return True
112111

113112
return (
114-
# The only PyTensor operation we allow on observed data is type casting
115-
# Although we could allow for any graph that does not depend on other RVs
116-
(
117-
isinstance(v.owner.op, Elemwise)
118-
and isinstance(v.owner.op.scalar_op, Cast)
119-
and is_valid_observed(v.owner.inputs[0])
120-
)
113+
not rvs_in_graph(v)
121114
# Or Minibatch
122115
or (
123116
isinstance(v.owner.op, MinibatchOp)
@@ -148,7 +141,7 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
148141
for i, v in enumerate(tensors):
149142
if not is_valid_observed(v):
150143
raise ValueError(
151-
f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
144+
f"{i}: {v} is not valid for Minibatch, only non-random variables are allowed"
152145
)
153146

154147
upper = tensors[0].shape[0]

pymc/pytensorf.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -164,6 +164,12 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
164164
mask[mask_idx] = 1
165165
return np.ma.MaskedArray(array_data, mask)
166166

167+
from pymc.logprob.utils import rvs_in_graph
168+
169+
if not inputvars(x) and not rvs_in_graph(x):
170+
cheap_eval_mode = Mode(linker="py", optimizer=None)
171+
return x.eval(mode=cheap_eval_mode)
172+
167173
raise TypeError(f"Data cannot be extracted from {x}")
168174

169175

tests/test_data.py

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -509,11 +509,17 @@ def test_allowed(self):
509509
mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
510510
assert isinstance(mb.owner.op, MinibatchOp)
511511

512-
with pytest.raises(ValueError, match="not valid for Minibatch"):
513-
pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
512+
mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
513+
assert isinstance(mb.owner.op, MinibatchOp)
514+
515+
for mb in pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20):
516+
assert isinstance(mb.owner.op, MinibatchOp)
514517

515-
with pytest.raises(ValueError, match="not valid for Minibatch"):
516-
pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
518+
def test_not_allowed(self):
519+
data = pt.random.normal(loc=self.data, scale=1)
520+
521+
with pytest.raises(ValueError):
522+
pm.Minibatch(data, batch_size=20)
517523

518524
def test_assert(self):
519525
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
@@ -530,3 +536,21 @@ def test_multiple_vars(self):
530536
[draw_mA, draw_mB] = pm.draw([mA, mB])
531537
assert draw_mA.shape == (10,)
532538
np.testing.assert_allclose(draw_mA, -draw_mB)
539+
540+
541+
def test_scaling_data_works_in_likelihood() -> None:
542+
data = np.array([10, 11, 12, 13, 14, 15])
543+
544+
with pm.Model():
545+
target = pm.Data("target", data)
546+
scale = 12
547+
scaled_target = target / scale
548+
mu = pm.Normal("mu", mu=0, sigma=1)
549+
pm.Normal("x", mu=mu, sigma=1, observed=scaled_target)
550+
551+
idata = pm.sample(10, chains=1, tune=100)
552+
553+
np.testing.assert_allclose(
554+
idata.observed_data["x"].values,
555+
data / scale,
556+
)

tests/test_pytensorf.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,21 @@ def test_minibatch_variable(self):
195195
assert isinstance(res, np.ndarray)
196196
np.testing.assert_array_equal(res, y)
197197

198+
def test_pytensor_operations(self):
199+
x = np.array([1, 2, 3])
200+
target = 1 + 3 * pt.as_tensor_variable(x)
201+
202+
res = extract_obs_data(target)
203+
assert isinstance(res, np.ndarray)
204+
np.testing.assert_array_equal(res, np.array([4, 7, 10]))
205+
206+
def test_pytensor_operations_raises(self):
207+
x = pt.scalar("x")
208+
target = 1 + 3 * x
209+
210+
with pytest.raises(TypeError, match="Data cannot be extracted from"):
211+
extract_obs_data(target)
212+
198213

199214
@pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
200215
def test_convert_data(input_dtype):

0 commit comments

Comments (0)