|
10 | 10 | import pytensor.tensor.basic as ptb
|
11 | 11 | import pytensor.tensor.math as ptm
|
12 | 12 | from pytensor import compile, config, function, shared
|
| 13 | +from pytensor.compile import SharedVariable |
13 | 14 | from pytensor.compile.io import In, Out
|
14 | 15 | from pytensor.compile.mode import Mode, get_default_mode
|
15 | 16 | from pytensor.compile.ops import DeepCopyOp
|
@@ -4565,3 +4566,37 @@ def core_np(x):
|
4565 | 4566 | vectorize_pt(x_test),
|
4566 | 4567 | vectorize_np(x_test),
|
4567 | 4568 | )
|
| 4569 | + |
| 4570 | + |
| 4571 | +@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)]) |
| 4572 | +@pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"]) |
| 4573 | +@config.change_flags(cxx="") # C code not needed |
| 4574 | +def test_vectorize_join(axis, broadcasting_y): |
| 4575 | + # Signature for join along intermediate axis |
| 4576 | + signature = "(a,b1,c),(a,b2,c)->(a,b,c)" |
| 4577 | + |
| 4578 | + def core_pt(x, y): |
| 4579 | + return join(axis, x, y) |
| 4580 | + |
| 4581 | + def core_np(x, y): |
| 4582 | + return np.concatenate([x, y], axis=axis.eval()) |
| 4583 | + |
| 4584 | + x = tensor(shape=(4, 2, 3, 5)) |
| 4585 | + y_shape = {"none": (4, 2, 3, 5), "implicit": (2, 3, 5), "explicit": (1, 2, 3, 5)} |
| 4586 | + y = tensor(shape=y_shape[broadcasting_y]) |
| 4587 | + |
| 4588 | + vectorize_pt = function([x, y], vectorize(core_pt, signature=signature)(x, y)) |
| 4589 | + |
| 4590 | + blockwise_needed = isinstance(axis, SharedVariable) or broadcasting_y != "none" |
| 4591 | + has_blockwise = any( |
| 4592 | + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes |
| 4593 | + ) |
| 4594 | + assert has_blockwise == blockwise_needed |
| 4595 | + |
| 4596 | + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) |
| 4597 | + y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype) |
| 4598 | + vectorize_np = np.vectorize(core_np, signature=signature) |
| 4599 | + np.testing.assert_allclose( |
| 4600 | + vectorize_pt(x_test, y_test), |
| 4601 | + vectorize_np(x_test, y_test), |
| 4602 | + ) |
0 commit comments