|
4 | 4 | import pytensor.scalar as aes
|
5 | 5 | import pytensor.tensor as at
|
6 | 6 | import pytensor.tensor.basic as atb
|
7 |
| -from pytensor import config |
| 7 | +from pytensor import config, function |
8 | 8 | from pytensor.compile.sharedvalue import SharedVariable
|
9 | 9 | from pytensor.graph.basic import Constant
|
10 | 10 | from pytensor.graph.fg import FunctionGraph
|
| 11 | +from pytensor.scalar import Add |
11 | 12 | from pytensor.tensor.shape import Unbroadcast
|
12 | 13 | from tests.link.numba.test_basic import (
|
13 | 14 | compare_numba_and_py,
|
@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes):
|
332 | 333 | )
|
333 | 334 |
|
334 | 335 |
|
| 336 | +def test_Split_not_inplace(): |
| 337 | + # https://github.com/pymc-devs/pytensor/issues/343 |
| 338 | + x1 = at.matrix("x1") |
| 339 | + x2 = at.matrix("x2", shape=(None, 1)) |
| 340 | + v = at.vector("v", shape=(2,), dtype=int) |
| 341 | + out = at.split(x1, v, n_splits=2, axis=1)[0] + x2 |
| 342 | + |
| 343 | + fn = function([x1, x2, v], out, mode="NUMBA") |
| 344 | + # Check that the addition of split[0] and x2 is in place |
| 345 | + add_op = fn.maker.fgraph.outputs[0].owner.op |
| 346 | + assert isinstance(add_op.scalar_op, Add) |
| 347 | + assert add_op.inplace_pattern == {0: 0} |
| 348 | + |
| 349 | + rng = np.random.default_rng(123) |
| 350 | + test_x1 = rng.normal(size=(2, 2)) |
| 351 | + test_x2 = rng.normal(size=(2, 1)) |
| 352 | + test_v = np.array([1, 1]) |
| 353 | + |
| 354 | + np.testing.assert_allclose( |
| 355 | + fn(test_x1, test_x2, test_v).copy(), |
| 356 | + fn(test_x1, test_x2, test_v).copy(), |
| 357 | + ) |
| 358 | + |
| 359 | + |
335 | 360 | @pytest.mark.parametrize(
|
336 | 361 | "val, offset",
|
337 | 362 | [
|
|
0 commit comments