Skip to content

Commit 8ae14c2

Browse files
committed
Implement vectorize_node for CheckAndRaise Op
1 parent 31a4df6 commit 8ae14c2

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

pytensor/raise_op.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pytensor.gradient import DisconnectedType
88
from pytensor.graph.basic import Apply, Variable
9+
from pytensor.graph.replace import _vectorize_node
910
from pytensor.link.c.op import COp
1011
from pytensor.link.c.params_type import ParamsType
1112
from pytensor.link.c.type import Generic
@@ -198,3 +199,21 @@ def __str__(self):
198199

199200

200201
assert_op = Assert()
202+
203+
204+
@_vectorize_node.register(CheckAndRaise)
205+
def vectorize_check_and_raise(op, node, batch_x, batch_cond):
206+
from pytensor.tensor.extra_ops import broadcast_arrays
207+
from pytensor.tensor.shape import shape_padright
208+
209+
batch_cond_dims = batch_cond.type.ndim
210+
211+
if batch_cond_dims:
212+
out = op(batch_x, batch_cond.all())
213+
# Condition may broadcast batch dims of x
214+
# We broadcast after the Check Op, so it can be removed more easily if not needed
215+
x_core_ndim = node.inputs[0].type.ndim
216+
batch_out, _ = broadcast_arrays(out, shape_padright(batch_cond, x_core_ndim))
217+
return batch_out.owner
218+
else:
219+
return op.make_node(batch_x, batch_cond)

tests/test_raise_op.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
import pytensor
66
import pytensor.tensor as pt
77
from pytensor.compile.mode import OPT_FAST_RUN, Mode
8+
from pytensor.graph import vectorize_graph
89
from pytensor.graph.basic import Constant, equal_computations
910
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
1011
from pytensor.scalar.basic import ScalarType, float64
1112
from pytensor.sparse import as_sparse_variable
13+
from pytensor.tensor.basic import second
14+
from pytensor.tensor.elemwise import DimShuffle
1215
from tests import unittest_tools as utt
1316

1417

@@ -184,3 +187,68 @@ def test_CheckAndRaise_sparse_variable():
184187
a2 = check_and_raise(aspe1, aspe2.sum() > 2)
185188
with pytest.raises(ValueError, match="sparse_check"):
186189
a2.sum().eval()
190+
191+
192+
@pytensor.config.change_flags(cxx="") # For speed-up
193+
def test_vectorize():
194+
floatX = pytensor.config.floatX
195+
x = pt.vector("x")
196+
y = pt.vector("y")
197+
cond = pt.all(y >= 0)
198+
out = assert_op(x, cond)
199+
200+
batch_x = pt.matrix("batch_x", shape=(2, None))
201+
batch_y = pt.matrix("batch_y", shape=(2, None))
202+
203+
test_x = np.arange(3).astype(floatX)
204+
test_y = np.arange(4).astype(floatX)
205+
test_batch_x = np.arange(6).reshape(2, 3).astype(floatX)
206+
test_batch_y = np.arange(8).reshape(2, 4).astype(floatX)
207+
208+
# Only x is batched
209+
vect_out = vectorize_graph(out, {x: batch_x, y: y})
210+
assert vect_out.type.shape == (2, None)
211+
assert isinstance(vect_out.owner.op, CheckAndRaise)
212+
np.testing.assert_array_equal(
213+
vect_out.eval({batch_x: test_batch_x, y: test_y}),
214+
test_batch_x,
215+
)
216+
with pytest.raises(AssertionError):
217+
vect_out.eval({batch_x: test_batch_x, y: -test_y})
218+
219+
# Only y is batched
220+
vect_out = vectorize_graph(out, {x: x, y: batch_y})
221+
assert vect_out.type.shape == (2, None)
222+
assert vect_out.owner.op == second # broadcast
223+
assert isinstance(vect_out.owner.inputs[1].owner.op, DimShuffle)
224+
assert isinstance(vect_out.owner.inputs[1].owner.inputs[0].owner.op, CheckAndRaise)
225+
np.testing.assert_array_equal(
226+
vect_out.eval({x: test_x, batch_y: test_batch_y}),
227+
np.broadcast_to(test_x, (2, *test_x.shape)),
228+
)
229+
with pytest.raises(AssertionError):
230+
vect_out.eval({x: test_x, batch_y: -test_batch_y})
231+
232+
# Both x, and y are batched
233+
vect_out = vectorize_graph(out, {x: batch_x, y: batch_y})
234+
assert vect_out.type.shape == (2, None)
235+
assert vect_out.owner.op == second
236+
assert isinstance(vect_out.owner.inputs[1].owner.op, CheckAndRaise)
237+
np.testing.assert_array_equal(
238+
vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
239+
test_batch_x,
240+
)
241+
with pytest.raises(AssertionError):
242+
vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})
243+
244+
# Both x, and y are batched and broadcast each other
245+
vect_out = vectorize_graph(out, {x: batch_x[:, None, :], y: batch_y[None, :, :]})
246+
assert vect_out.type.shape == (2, 2, None)
247+
assert vect_out.owner.op == second
248+
assert isinstance(vect_out.owner.inputs[1].owner.op, CheckAndRaise)
249+
np.testing.assert_array_equal(
250+
vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
251+
np.broadcast_to(test_batch_x[:, None, :], (2, *test_batch_x.shape)),
252+
)
253+
with pytest.raises(AssertionError):
254+
vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})

0 commit comments

Comments
 (0)