Skip to content

Commit 0286eb8

Browse files
committed
Copy outputs in numba impl of Split
1 parent 91966e8 commit 0286eb8

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def split(tensor, axis, indices):
144144
# Work around for https://github.com/numba/numba/issues/8257
145145
axis = axis % tensor.ndim
146146
axis = numba_basic.to_scalar(axis)
147-
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis)
147+
# TODO: Add inplace optimization for Split, so we don't need to copy all the time
148+
return [a.copy() for a in np.split(tensor, np.cumsum(indices)[:-1], axis=axis)]
148149

149150
return split
150151

tests/link/numba/test_tensor_basic.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import pytensor.scalar as aes
55
import pytensor.tensor as at
66
import pytensor.tensor.basic as atb
7-
from pytensor import config
7+
from pytensor import config, function
88
from pytensor.compile.sharedvalue import SharedVariable
99
from pytensor.graph.basic import Constant
1010
from pytensor.graph.fg import FunctionGraph
11+
from pytensor.scalar import Add
1112
from pytensor.tensor.shape import Unbroadcast
1213
from tests.link.numba.test_basic import (
1314
compare_numba_and_py,
@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes):
332333
)
333334

334335

336+
def test_Split_not_inplace():
337+
# https://github.com/pymc-devs/pytensor/issues/343
338+
x1 = at.matrix("x1")
339+
x2 = at.matrix("x2", shape=(None, 1))
340+
v = at.vector("v", shape=(2,), dtype=int)
341+
out = at.split(x1, v, n_splits=2, axis=1)[0] + x2
342+
343+
fn = function([x1, x2, v], out, mode="NUMBA")
344+
# Check that the addition of split[0] and x2 is in place
345+
add_op = fn.maker.fgraph.outputs[0].owner.op
346+
assert isinstance(add_op.scalar_op, Add)
347+
assert add_op.inplace_pattern == {0: 0}
348+
349+
rng = np.random.default_rng(123)
350+
test_x1 = rng.normal(size=(2, 2))
351+
test_x2 = rng.normal(size=(2, 1))
352+
test_v = np.array([1, 1])
353+
354+
np.testing.assert_allclose(
355+
fn(test_x1, test_x2, test_v).copy(),
356+
fn(test_x1, test_x2, test_v).copy(),
357+
)
358+
359+
335360
@pytest.mark.parametrize(
336361
"val, offset",
337362
[

0 commit comments

Comments
 (0)