Commit f4536c3

Make Split a view_op
This allows the outputs to be views of the inputs. The Python and Numba implementations already return views; the C implementation still performs a copy.
1 parent 91966e8 commit f4536c3
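
For context, a minimal NumPy sketch (not part of the commit) of why dropping the `.copy()` in `Split.perform` is enough to turn the outputs into views: `np.split` returns basic slices of its input, and basic slices share the input's memory.

```python
import numpy as np

x = np.arange(6)
splits = [2, 2, 2]
# Same call as in Split.perform below: split points are the running sums of sizes
parts = np.split(x, np.cumsum(splits[:-1]), axis=0)
assert all(p.base is x for p in parts)  # each piece aliases x's buffer
parts[0][0] = 99
assert x[0] == 99  # mutating a view mutates the original array
```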

4 files changed: +108 -48 lines

pytensor/tensor/basic.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1903,6 +1903,7 @@ class Split(COp):
     b == [3, 4]
     c == [5]
 
+    TODO: Don't make a copy in C impl
     """
 
     len_splits = None
@@ -1913,6 +1914,7 @@ class Split(COp):
 
     def __init__(self, len_splits):
         self.len_splits = int(len_splits)
+        self.view_map = {i: [0] for i in range(self.len_splits)}
 
     def __str__(self):
         return f"{self.__class__.__name__ }{{{self.len_splits}}}"
@@ -1949,7 +1951,7 @@ def perform(self, node, inputs, outputs):
 
         split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
         for i, out in enumerate(split_outs):
-            outputs[i][0] = out.copy()
+            outputs[i][0] = out
 
     def infer_shape(self, fgraph, node, in_shapes):
         axis = node.inputs[1]
```
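
The `view_map` set in `__init__` is the Op's declaration of aliasing: it maps each output index to the list of input indices that output may be a view of. For `Split`, every output may alias input 0, the tensor being split. A quick sketch, with the expected value taken from the diff and from `test_split_view` further down:

```python
from pytensor.tensor.basic import Split

op = Split(len_splits=3)
# Each of the three outputs may alias input 0 (the tensor being split);
# inputs 1 (axis) and 2 (split sizes) are never aliased.
assert op.view_map == {0: [0], 1: [0], 2: [0]}
```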

tests/link/numba/test_tensor_basic.py

Lines changed: 26 additions & 1 deletion
```diff
@@ -4,10 +4,11 @@
 import pytensor.scalar as aes
 import pytensor.tensor as at
 import pytensor.tensor.basic as atb
-from pytensor import config
+from pytensor import config, function
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar import Add
 from pytensor.tensor.shape import Unbroadcast
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes):
     )
 
 
+def test_Split_view():
+    # https://github.com/pymc-devs/pytensor/issues/343
+    x1 = at.matrix("x1")
+    x2 = at.matrix("x2", shape=(None, 1))
+    v = at.vector("v", shape=(2,), dtype=int)
+    out = at.split(x1, v, n_splits=2, axis=1)[0] + x2
+
+    fn = function([x1, x2, v], out, mode="NUMBA")
+    # Check that the addition of split[0] and x2 is not in place
+    add_op = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(add_op.scalar_op, Add)
+    assert not add_op.inplace_pattern
+
+    rng = np.random.default_rng(123)
+    test_x1 = rng.normal(size=(2, 2))
+    test_x2 = rng.normal(size=(2, 1))
+    test_v = np.array([1, 1])
+
+    np.testing.assert_allclose(
+        fn(test_x1, test_x2, test_v).copy(),
+        fn(test_x1, test_x2, test_v).copy(),
+    )
+
+
 @pytest.mark.parametrize(
     "val, offset",
     [
```
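
The `inplace_pattern` assertion above guards against a concrete hazard: now that `split(x1, ...)[0]` is a view of `x1`, an in-place addition through it would silently overwrite `x1`. A standalone NumPy analogue of that failure mode (illustrative, not from the test suite):

```python
import numpy as np

x1 = np.ones((2, 2))
x2 = np.ones((2, 1))
view = x1[:, :1]            # stands in for split(x1, ...)[0]
np.add(view, x2, out=view)  # an in-place add through the view...
assert x1[0, 0] == 2.0      # ...corrupts the original input
```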

tests/tensor/rewriting/test_basic.py

Lines changed: 19 additions & 6 deletions
```diff
@@ -1372,15 +1372,28 @@ def test_local_useless_split():
 
     f_rewritten(np.random.random((4, 4)).astype(config.floatX), [4])
     f_not_rewritten(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1])
-    graph_rewritten = f_rewritten.maker.fgraph.toposort()
-    graph_not_rewritten = f_not_rewritten.maker.fgraph.toposort()
+    graph_rewritten = f_rewritten.maker.fgraph
+    graph_not_rewritten = f_not_rewritten.maker.fgraph
 
-    assert isinstance(graph_rewritten[-1].op, DeepCopyOp)
-    assert len(graph_not_rewritten) == 1
-    assert isinstance(graph_not_rewritten[0].op, Split)
+    assert all(
+        isinstance(out.owner.op, DeepCopyOp) for out in graph_not_rewritten.outputs
+    )
+    assert all(isinstance(out.owner.op, DeepCopyOp) for out in graph_rewritten.outputs)
+
+    assert sum(isinstance(node.op, Split) for node in graph_rewritten.apply_nodes) == 0
+    assert (
+        sum(isinstance(node.op, Split) for node in graph_not_rewritten.apply_nodes) == 1
+    )
+
+    assert sum(isinstance(node.op, Assert) for node in graph_rewritten.apply_nodes) == 2
+    assert (
+        sum(isinstance(node.op, Assert) for node in graph_not_rewritten.apply_nodes)
+        == 0
+    )
 
+    # The DeepCopy Ops don't have traces, so we can't check "all"
     assert check_stack_trace(f_rewritten, ops_to_check=[Assert])
-    assert check_stack_trace(f_not_rewritten, ops_to_check="all")
+    assert check_stack_trace(f_not_rewritten, ops_to_check=[Split])
 
 
 @pytest.mark.parametrize("i", list(range(1, 4)))
```
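
A note on the new `DeepCopyOp` assertions (my reading of the change, not stated in the commit message): because `Split` outputs are now declared views, PyTensor protects user-facing outputs that alias an input by inserting a `DeepCopyOp`, unless the caller opts out with `Out(..., borrow=True)`. A small sketch of that protection under default settings:

```python
import numpy as np
import pytensor
from pytensor.tensor import vector
from pytensor.tensor.basic import Split

x = vector("x")
parts = Split(len_splits=2)(x, 0, [2, 3])
f = pytensor.function([x], parts)  # outputs protected by DeepCopyOp by default
res = f(np.arange(5, dtype=pytensor.config.floatX))
assert all(r.base is None for r in res)  # fresh copies, not views of the input
```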

tests/tensor/test_basic.py

Lines changed: 60 additions & 40 deletions
```diff
@@ -11,7 +11,7 @@
 import pytensor.tensor.math as tm
 from pytensor import compile, config, function, shared
 from pytensor.compile.io import In, Out
-from pytensor.compile.mode import get_default_mode
+from pytensor.compile.mode import Mode, get_default_mode
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.gradient import grad, hessian
 from pytensor.graph.basic import Apply
@@ -2002,45 +2002,65 @@ def test_split_static_shape(self):
         y = Split(2)(x, 0, [s, 5 - s])[0]
         assert y.type.shape == (None,)
 
-
-def test_join_inplace():
-    # Test join to work inplace.
-    #
-    # This function tests the case when several elements are passed to the
-    # join function but all except one of them are empty. In this case join
-    # should work inplace and the output should be the view of the non-empty
-    # element.
-    s = lscalar()
-    x = vector("x")
-    z = at.zeros((s,))
-
-    join = Join(view=0)
-    c = join(0, x, z, z)
-
-    f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
-
-    data = np.array([3, 4, 5], dtype=config.floatX)
-
-    if config.mode not in ["DebugMode", "DEBUG_MODE"]:
-        assert f(data, 0) is data
-    assert np.allclose(f(data, 0), [3, 4, 5])
-
-
-def test_join_oneInput():
-    # Test join when only 1 input is given.
-    #
-    # This functions tests the case when concatenate is called
-    # on an array of tensors but the array has only one element.
-    # In this case, we would like to avoid the computational
-    # overhead of concatenation of one element.
-    x_0 = fmatrix()
-    x_1 = fmatrix()
-    x_2 = fvector()
-    join_0 = at.concatenate([x_0], axis=1)
-    join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1)
-
-    assert join_0 is x_0
-    assert join_1 is not x_0
+    def test_join_inplace(self):
+        # Test join to work inplace.
+        #
+        # This function tests the case when several elements are passed to the
+        # join function but all except one of them are empty. In this case join
+        # should work inplace and the output should be the view of the non-empty
+        # element.
+        s = lscalar()
+        x = vector("x")
+        z = at.zeros((s,))
+
+        join = Join(view=0)
+        c = join(0, x, z, z)
+
+        f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
+
+        data = np.array([3, 4, 5], dtype=config.floatX)
+
+        if config.mode not in ["DebugMode", "DEBUG_MODE"]:
+            assert f(data, 0) is data
+        assert np.allclose(f(data, 0), [3, 4, 5])
+
+    def test_join_oneInput(self):
+        # Test join when only 1 input is given.
+        #
+        # This functions tests the case when concatenate is called
+        # on an array of tensors but the array has only one element.
+        # In this case, we would like to avoid the computational
+        # overhead of concatenation of one element.
+        x_0 = fmatrix()
+        x_1 = fmatrix()
+        x_2 = fvector()
+        join_0 = at.concatenate([x_0], axis=1)
+        join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1)
+
+        assert join_0 is x_0
+        assert join_1 is not x_0
+
+    @pytest.mark.parametrize("linker", ("py", "c"))
+    def test_split_view(self, linker):
+        x = vector("x")
+        axis = 0
+        op = Split(len_splits=3)
+        assert op.view_map == {0: [0], 1: [0], 2: [0]}
+        splits = op(x, axis, [0, 3, 2])
+
+        mode = Mode(linker)
+        f = pytensor.function(
+            [In(x, borrow=True)], [Out(s, borrow=True) for s in splits], mode=mode
+        )
+        x_test = np.arange(5, dtype=config.floatX)
+        res = f(x_test)
+        for r, expected in zip(res, ([], [0, 1, 2], [3, 4])):
+            assert np.allclose(r, expected)
+            if linker == "py":
+                assert r.base is x_test
+            else:
+                # C impl always makes a copy
+                assert r.base is not x_test
 
 
 def test_TensorFromScalar():
```
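
`test_split_view` relies on NumPy's `.base` attribute to detect aliasing; the idiom in isolation:

```python
import numpy as np

x = np.arange(5)
v = x[1:3]
assert v.base is x            # a view records the array whose memory it borrows
assert x.copy().base is None  # a fresh copy owns its own memory
```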
