diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 9302d6d67d..8f5c9181d0 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -1903,6 +1903,7 @@ class Split(COp):
         b == [3, 4]
         c == [5]
 
+    TODO: Don't make a copy in C impl
     """
 
     len_splits = None
@@ -1913,6 +1914,7 @@ class Split(COp):
     def __init__(self, len_splits):
         self.len_splits = int(len_splits)
+        self.view_map = {i: [0] for i in range(self.len_splits)}
 
     def __str__(self):
         return f"{self.__class__.__name__ }{{{self.len_splits}}}"
 
@@ -1949,7 +1951,7 @@ def perform(self, node, inputs, outputs):
 
         split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
         for i, out in enumerate(split_outs):
-            outputs[i][0] = out.copy()
+            outputs[i][0] = out
 
     def infer_shape(self, fgraph, node, in_shapes):
         axis = node.inputs[1]
diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py
index 62aacff0a9..047bc18a98 100644
--- a/tests/link/numba/test_tensor_basic.py
+++ b/tests/link/numba/test_tensor_basic.py
@@ -4,10 +4,11 @@
 import pytensor.scalar as aes
 import pytensor.tensor as at
 import pytensor.tensor.basic as atb
-from pytensor import config
+from pytensor import config, function
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar import Add
 from pytensor.tensor.shape import Unbroadcast
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes):
     )
 
 
+def test_Split_view():
+    # https://github.com/pymc-devs/pytensor/issues/343
+    x1 = at.matrix("x1")
+    x2 = at.matrix("x2", shape=(None, 1))
+    v = at.vector("v", shape=(2,), dtype=int)
+    out = at.split(x1, v, n_splits=2, axis=1)[0] + x2
+
+    fn = function([x1, x2, v], out, mode="NUMBA")
+    # Check that the addition of split[0] and x2 is not in place
+    add_op = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(add_op.scalar_op, Add)
+    assert not add_op.inplace_pattern
+
+    rng = np.random.default_rng(123)
+    test_x1 = rng.normal(size=(2, 2))
+    test_x2 = rng.normal(size=(2, 1))
+    test_v = np.array([1, 1])
+
+    np.testing.assert_allclose(
+        fn(test_x1, test_x2, test_v).copy(),
+        fn(test_x1, test_x2, test_v).copy(),
+    )
+
+
 @pytest.mark.parametrize(
     "val, offset",
     [
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index 06694f6a67..dd7c184073 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1372,15 +1372,28 @@ def test_local_useless_split():
     f_rewritten(np.random.random((4, 4)).astype(config.floatX), [4])
     f_not_rewritten(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1])
 
-    graph_rewritten = f_rewritten.maker.fgraph.toposort()
-    graph_not_rewritten = f_not_rewritten.maker.fgraph.toposort()
+    graph_rewritten = f_rewritten.maker.fgraph
+    graph_not_rewritten = f_not_rewritten.maker.fgraph
 
-    assert isinstance(graph_rewritten[-1].op, DeepCopyOp)
-    assert len(graph_not_rewritten) == 1
-    assert isinstance(graph_not_rewritten[0].op, Split)
+    assert all(
+        isinstance(out.owner.op, DeepCopyOp) for out in graph_not_rewritten.outputs
+    )
+    assert all(isinstance(out.owner.op, DeepCopyOp) for out in graph_rewritten.outputs)
+
+    assert sum(isinstance(node.op, Split) for node in graph_rewritten.apply_nodes) == 0
+    assert (
+        sum(isinstance(node.op, Split) for node in graph_not_rewritten.apply_nodes) == 1
+    )
+
+    assert sum(isinstance(node.op, Assert) for node in graph_rewritten.apply_nodes) == 2
+    assert (
+        sum(isinstance(node.op, Assert) for node in graph_not_rewritten.apply_nodes)
+        == 0
+    )
+
     # The DeepCopy Ops don't have traces, so we can't check "all"
     assert check_stack_trace(f_rewritten, ops_to_check=[Assert])
-    assert check_stack_trace(f_not_rewritten, ops_to_check="all")
+    assert check_stack_trace(f_not_rewritten, ops_to_check=[Split])
 
 
 @pytest.mark.parametrize("i", list(range(1, 4)))
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index 804066ebb4..0be24733f1 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -11,7 +11,7 @@
 import pytensor.tensor.math as tm
 from pytensor import compile, config, function, shared
 from pytensor.compile.io import In, Out
-from pytensor.compile.mode import get_default_mode
+from pytensor.compile.mode import Mode, get_default_mode
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.gradient import grad, hessian
 from pytensor.graph.basic import Apply
@@ -2002,45 +2002,65 @@ def test_split_static_shape(self):
         y = Split(2)(x, 0, [s, 5 - s])[0]
         assert y.type.shape == (None,)
 
-
-def test_join_inplace():
-    # Test join to work inplace.
-    #
-    # This function tests the case when several elements are passed to the
-    # join function but all except one of them are empty. In this case join
-    # should work inplace and the output should be the view of the non-empty
-    # element.
-    s = lscalar()
-    x = vector("x")
-    z = at.zeros((s,))
-
-    join = Join(view=0)
-    c = join(0, x, z, z)
-
-    f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
-
-    data = np.array([3, 4, 5], dtype=config.floatX)
-
-    if config.mode not in ["DebugMode", "DEBUG_MODE"]:
-        assert f(data, 0) is data
-    assert np.allclose(f(data, 0), [3, 4, 5])
-
-
-def test_join_oneInput():
-    # Test join when only 1 input is given.
-    #
-    # This functions tests the case when concatenate is called
-    # on an array of tensors but the array has only one element.
-    # In this case, we would like to avoid the computational
-    # overhead of concatenation of one element.
-    x_0 = fmatrix()
-    x_1 = fmatrix()
-    x_2 = fvector()
-    join_0 = at.concatenate([x_0], axis=1)
-    join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1)
-
-    assert join_0 is x_0
-    assert join_1 is not x_0
+    def test_join_inplace(self):
+        # Test that join works inplace.
+        #
+        # This function tests the case when several elements are passed to the
+        # join function but all except one of them are empty. In this case join
+        # should work inplace and the output should be the view of the non-empty
+        # element.
+        s = lscalar()
+        x = vector("x")
+        z = at.zeros((s,))
+
+        join = Join(view=0)
+        c = join(0, x, z, z)
+
+        f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
+
+        data = np.array([3, 4, 5], dtype=config.floatX)
+
+        if config.mode not in ["DebugMode", "DEBUG_MODE"]:
+            assert f(data, 0) is data
+        assert np.allclose(f(data, 0), [3, 4, 5])
+
+    def test_join_oneInput(self):
+        # Test join when only 1 input is given.
+        #
+        # This function tests the case when concatenate is called
+        # on an array of tensors but the array has only one element.
+        # In this case, we would like to avoid the computational
+        # overhead of concatenation of one element.
+        x_0 = fmatrix()
+        x_1 = fmatrix()
+        x_2 = fvector()
+        join_0 = at.concatenate([x_0], axis=1)
+        join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1)
+
+        assert join_0 is x_0
+        assert join_1 is not x_0
+
+    @pytest.mark.parametrize("linker", ("py", "c"))
+    def test_split_view(self, linker):
+        x = vector("x")
+        axis = 0
+        op = Split(len_splits=3)
+        assert op.view_map == {0: [0], 1: [0], 2: [0]}
+        splits = op(x, axis, [0, 3, 2])
+
+        mode = Mode(linker)
+        f = pytensor.function(
+            [In(x, borrow=True)], [Out(s, borrow=True) for s in splits], mode=mode
+        )
+        x_test = np.arange(5, dtype=config.floatX)
+        res = f(x_test)
+        for r, expected in zip(res, ([], [0, 1, 2], [3, 4])):
+            assert np.allclose(r, expected)
+            if linker == "py":
+                assert r.base is x_test
+            else:
+                # C impl always makes a copy
+                assert r.base is not x_test
 
 
 def test_TensorFromScalar():
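
Note on the mechanics (an editor's sketch, not part of the patch): the core change is that `Split.perform` no longer copies the arrays returned by `np.split`; instead the Op declares `view_map = {i: [0] for i in range(self.len_splits)}`, which tells PyTensor's rewriter that every output aliases input 0, so destructive in-place Ops are kept away from those outputs (that is what `test_Split_view` checks through the empty `inplace_pattern`). Below is a minimal standalone demonstration of the NumPy view semantics the Python implementation now relies on; the variable names are made up for illustration:

    import numpy as np

    x = np.arange(5.0)
    # Same call pattern as Split.perform: the split points are the
    # cumulative sizes, and the size of the last piece is implied.
    a, b, c = np.split(x, np.cumsum([0, 3]))  # pieces of size 0, 3 and 2

    # np.split returns views, and NumPy collapses the base chain to the
    # array that owns the memory, so every piece aliases x directly.
    assert b.base is x and c.base is x

    x[0] = 42.0
    assert b[0] == 42.0  # mutating x is visible through the view

This is also why `test_split_view` can assert `r.base is x_test` under the "py" linker, while the C implementation (see the added TODO) still copies and therefore returns arrays that own their memory.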