Skip to content

Commit b494851

Browse files
committed
Implement vectorize_node for CheckAndRaise Op
1 parent 89fe939 commit b494851

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

pytensor/raise_op.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pytensor.gradient import DisconnectedType
88
from pytensor.graph.basic import Apply, Variable
9+
from pytensor.graph.replace import _vectorize_node
910
from pytensor.link.c.op import COp
1011
from pytensor.link.c.params_type import ParamsType
1112
from pytensor.link.c.type import Generic
@@ -198,3 +199,19 @@ def __str__(self):
198199

199200

200201
assert_op = Assert()
202+
203+
204+
@_vectorize_node.register(CheckAndRaise)
205+
def vectorize_check_and_raise(op, node, batch_x, batch_cond):
206+
from pytensor.tensor.extra_ops import broadcast_arrays
207+
from pytensor.tensor.shape import shape_padright
208+
209+
batch_cond_dims = batch_cond.type.ndim
210+
211+
if batch_cond_dims:
212+
# Condition may broadcast batch dims of x
213+
x_core_ndim = node.inputs[0].type.ndim
214+
batch_x, _ = broadcast_arrays(batch_x, shape_padright(batch_cond, x_core_ndim))
215+
batch_cond = batch_cond.all()
216+
217+
return op.make_node(batch_x, batch_cond)

tests/test_raise_op.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytensor
66
import pytensor.tensor as pt
77
from pytensor.compile.mode import OPT_FAST_RUN, Mode
8+
from pytensor.graph import vectorize_graph
89
from pytensor.graph.basic import Constant, equal_computations
910
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
1011
from pytensor.scalar.basic import ScalarType, float64
@@ -184,3 +185,64 @@ def test_CheckAndRaise_sparse_variable():
184185
a2 = check_and_raise(aspe1, aspe2.sum() > 2)
185186
with pytest.raises(ValueError, match="sparse_check"):
186187
a2.sum().eval()
188+
189+
190+
@pytensor.config.change_flags(cxx="") # For speed-up
191+
def test_vectorize():
192+
floatX = pytensor.config.floatX
193+
x = pt.vector("x")
194+
y = pt.vector("y")
195+
cond = pt.all(y >= 0)
196+
out = assert_op(x, cond)
197+
198+
batch_x = pt.matrix("batch_x", shape=(2, None))
199+
batch_y = pt.matrix("batch_y", shape=(2, None))
200+
201+
test_x = np.arange(3).astype(floatX)
202+
test_y = np.arange(4).astype(floatX)
203+
test_batch_x = np.arange(6).reshape(2, 3).astype(floatX)
204+
test_batch_y = np.arange(8).reshape(2, 4).astype(floatX)
205+
206+
# Only x is batched
207+
vect_out = vectorize_graph(out, {x: batch_x, y: y})
208+
assert vect_out.type.shape == (2, None)
209+
assert isinstance(vect_out.owner.op, CheckAndRaise)
210+
np.testing.assert_array_equal(
211+
vect_out.eval({batch_x: test_batch_x, y: test_y}),
212+
test_batch_x,
213+
)
214+
with pytest.raises(AssertionError):
215+
vect_out.eval({batch_x: test_batch_x, y: -test_y})
216+
217+
# Only y is batched
218+
vect_out = vectorize_graph(out, {x: x, y: batch_y})
219+
assert vect_out.type.shape == (2, None)
220+
assert isinstance(vect_out.owner.op, CheckAndRaise)
221+
np.testing.assert_array_equal(
222+
vect_out.eval({x: test_x, batch_y: test_batch_y}),
223+
np.broadcast_to(test_x, (2, *test_x.shape)),
224+
)
225+
with pytest.raises(AssertionError):
226+
vect_out.eval({x: test_x, batch_y: -test_batch_y})
227+
228+
# Both x, and y are batched
229+
vect_out = vectorize_graph(out, {x: batch_x, y: batch_y})
230+
assert vect_out.type.shape == (2, None)
231+
assert isinstance(vect_out.owner.op, CheckAndRaise)
232+
np.testing.assert_array_equal(
233+
vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
234+
test_batch_x,
235+
)
236+
with pytest.raises(AssertionError):
237+
vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})
238+
239+
# Both x, and y are batched and broadcast each other
240+
vect_out = vectorize_graph(out, {x: batch_x[:, None, :], y: batch_y[None, :, :]})
241+
assert vect_out.type.shape == (2, 2, None)
242+
assert isinstance(vect_out.owner.op, CheckAndRaise)
243+
np.testing.assert_array_equal(
244+
vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
245+
np.broadcast_to(test_batch_x[:, None, :], (2, *test_batch_x.shape)),
246+
)
247+
with pytest.raises(AssertionError):
248+
vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})

0 commit comments

Comments
 (0)