
Make Split a view op #344


Merged 1 commit on Jun 15, 2023.

This PR makes `Split` a view op: `Split.perform` now returns the views produced by `np.split` instead of copying them, and declares the output-to-input aliasing via `view_map`. The C implementation still makes a copy (hence the `TODO` added to the docstring), and a Numba regression test is added for https://github.com/pymc-devs/pytensor/issues/343.
4 changes: 3 additions & 1 deletion pytensor/tensor/basic.py
@@ -1903,6 +1903,7 @@ class Split(COp):
b == [3, 4]
c == [5]

+    TODO: Don't make a copy in C impl
"""

len_splits = None
@@ -1913,6 +1914,7 @@ class Split(COp):

def __init__(self, len_splits):
self.len_splits = int(len_splits)
+        self.view_map = {i: [0] for i in range(self.len_splits)}

def __str__(self):
return f"{self.__class__.__name__ }{{{self.len_splits}}}"
@@ -1949,7 +1951,7 @@ def perform(self, node, inputs, outputs):

split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
for i, out in enumerate(split_outs):
-            outputs[i][0] = out.copy()
+            outputs[i][0] = out

def infer_shape(self, fgraph, node, in_shapes):
axis = node.inputs[1]
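Why dropping the `.copy()` can be expressed this way: `np.split` returns views into its input, so each output of `perform` now aliases the input buffer, and the new `view_map` entry tells PyTensor's graph machinery (e.g. the destroy handler and in-place rewrites) about that aliasing. A minimal NumPy-only sketch of the behavior being relied on (illustrative, not part of the PR):

```python
import numpy as np

x = np.arange(5.0)
# np.split returns views, mirroring what Split.perform now outputs
a, b = np.split(x, np.cumsum([3]))

assert a.base is x and b.base is x  # both outputs share x's buffer
x[0] = 42.0
assert a[0] == 42.0  # mutating x is visible through the views
```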
27 changes: 26 additions & 1 deletion tests/link/numba/test_tensor_basic.py
@@ -4,10 +4,11 @@
import pytensor.scalar as aes
import pytensor.tensor as at
import pytensor.tensor.basic as atb
-from pytensor import config
+from pytensor import config, function
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar import Add
from pytensor.tensor.shape import Unbroadcast
from tests.link.numba.test_basic import (
compare_numba_and_py,
@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes):
)


+def test_Split_view():
+    # https://github.com/pymc-devs/pytensor/issues/343
+    x1 = at.matrix("x1")
+    x2 = at.matrix("x2", shape=(None, 1))
+    v = at.vector("v", shape=(2,), dtype=int)
+    out = at.split(x1, v, n_splits=2, axis=1)[0] + x2
+
+    fn = function([x1, x2, v], out, mode="NUMBA")
+    # Check that the addition of split[0] and x2 is not in place
+    add_op = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(add_op.scalar_op, Add)
+    assert not add_op.inplace_pattern
+
+    rng = np.random.default_rng(123)
+    test_x1 = rng.normal(size=(2, 2))
+    test_x2 = rng.normal(size=(2, 1))
+    test_v = np.array([1, 1])
+
+    np.testing.assert_allclose(
+        fn(test_x1, test_x2, test_v).copy(),
+        fn(test_x1, test_x2, test_v).copy(),
+    )


@pytest.mark.parametrize(
"val, offset",
[
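This regression test targets the failure mode from issue #343: with `split(...)[0]` now a view of `x1`, an elementwise `Add` that ran in place would mutate the caller's input, so calling `fn` twice with the same arguments could return different results; hence the check that `inplace_pattern` is empty and the double-call comparison. A rough NumPy illustration of the hazard being prevented (hypothetical, not PyTensor code):

```python
import numpy as np

x1 = np.ones((2, 2))
first_col = np.split(x1, [1], axis=1)[0]  # a view into x1
first_col += 1.0                          # in-place add through the view
assert x1[0, 0] == 2.0                    # the caller's x1 was silently mutated
```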
25 changes: 19 additions & 6 deletions tests/tensor/rewriting/test_basic.py
@@ -1372,15 +1372,28 @@ def test_local_useless_split():

f_rewritten(np.random.random((4, 4)).astype(config.floatX), [4])
f_not_rewritten(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1])
-    graph_rewritten = f_rewritten.maker.fgraph.toposort()
-    graph_not_rewritten = f_not_rewritten.maker.fgraph.toposort()
+    graph_rewritten = f_rewritten.maker.fgraph
+    graph_not_rewritten = f_not_rewritten.maker.fgraph

-    assert isinstance(graph_rewritten[-1].op, DeepCopyOp)
-    assert len(graph_not_rewritten) == 1
-    assert isinstance(graph_not_rewritten[0].op, Split)
+    assert all(
+        isinstance(out.owner.op, DeepCopyOp) for out in graph_not_rewritten.outputs
+    )
+    assert all(isinstance(out.owner.op, DeepCopyOp) for out in graph_rewritten.outputs)
+
+    assert sum(isinstance(node.op, Split) for node in graph_rewritten.apply_nodes) == 0
+    assert (
+        sum(isinstance(node.op, Split) for node in graph_not_rewritten.apply_nodes) == 1
+    )
+
+    assert sum(isinstance(node.op, Assert) for node in graph_rewritten.apply_nodes) == 2
+    assert (
+        sum(isinstance(node.op, Assert) for node in graph_not_rewritten.apply_nodes)
+        == 0
+    )

+    # The DeepCopy Ops don't have traces, so we can't check "all"
     assert check_stack_trace(f_rewritten, ops_to_check=[Assert])
-    assert check_stack_trace(f_not_rewritten, ops_to_check="all")
+    assert check_stack_trace(f_not_rewritten, ops_to_check=[Split])


@pytest.mark.parametrize("i", list(range(1, 4)))
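Background for these assertion changes: `local_useless_split` removes a `Split` with `n_splits == 1`, passing the input through guarded by `Assert` nodes that check the split really is a no-op, while the non-rewritten graph keeps its `Split`; the test now counts node types over `fgraph.apply_nodes` instead of relying on toposort positions, since the `DeepCopyOp`s on the outputs make positional checks brittle. A conceptual sketch of what the rewritten graph computes (assumed axis and condition, not the rewrite's actual code):

```python
import pytensor.tensor as at
from pytensor.raise_op import Assert

x = at.matrix("x")
splits = at.vector("splits", dtype="int64")

# Split(1)(x, 1, splits)[0] is useless: it conceptually becomes x itself,
# guarded so the single declared split size matches the axis length.
out = Assert("Split{1} is useless")(x, at.eq(splits[0], x.shape[1]))
```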
100 changes: 60 additions & 40 deletions tests/tensor/test_basic.py
@@ -11,7 +11,7 @@
import pytensor.tensor.math as tm
from pytensor import compile, config, function, shared
from pytensor.compile.io import In, Out
-from pytensor.compile.mode import get_default_mode
+from pytensor.compile.mode import Mode, get_default_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply
@@ -2002,45 +2002,65 @@ def test_split_static_shape(self):
y = Split(2)(x, 0, [s, 5 - s])[0]
assert y.type.shape == (None,)


-def test_join_inplace():
-    # Test join to work inplace.
-    #
-    # This function tests the case when several elements are passed to the
-    # join function but all except one of them are empty. In this case join
-    # should work inplace and the output should be the view of the non-empty
-    # element.
-    s = lscalar()
-    x = vector("x")
-    z = at.zeros((s,))
-
-    join = Join(view=0)
-    c = join(0, x, z, z)
-
-    f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
-
-    data = np.array([3, 4, 5], dtype=config.floatX)
-
-    if config.mode not in ["DebugMode", "DEBUG_MODE"]:
-        assert f(data, 0) is data
-    assert np.allclose(f(data, 0), [3, 4, 5])
-
-
-def test_join_oneInput():
-    # Test join when only 1 input is given.
-    #
-    # This functions tests the case when concatenate is called
-    # on an array of tensors but the array has only one element.
-    # In this case, we would like to avoid the computational
-    # overhead of concatenation of one element.
-    x_0 = fmatrix()
-    x_1 = fmatrix()
-    x_2 = fvector()
-    join_0 = at.concatenate([x_0], axis=1)
-    join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1)
-
-    assert join_0 is x_0
-    assert join_1 is not x_0
+    def test_join_inplace(self):
+        # Test join to work inplace.
+        #
+        # This function tests the case when several elements are passed to the
+        # join function but all except one of them are empty. In this case join
+        # should work inplace and the output should be the view of the non-empty
+        # element.
+        s = lscalar()
+        x = vector("x")
+        z = at.zeros((s,))
+
+        join = Join(view=0)
+        c = join(0, x, z, z)
+
+        f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
+
+        data = np.array([3, 4, 5], dtype=config.floatX)
+
+        if config.mode not in ["DebugMode", "DEBUG_MODE"]:
+            assert f(data, 0) is data
+        assert np.allclose(f(data, 0), [3, 4, 5])
+
+    def test_join_oneInput(self):
+        # Test join when only 1 input is given.
+        #
+        # This functions tests the case when concatenate is called
+        # on an array of tensors but the array has only one element.
+        # In this case, we would like to avoid the computational
+        # overhead of concatenation of one element.
+        x_0 = fmatrix()
+        x_1 = fmatrix()
+        x_2 = fvector()
+        join_0 = at.concatenate([x_0], axis=1)
+        join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1)
+
+        assert join_0 is x_0
+        assert join_1 is not x_0

+    @pytest.mark.parametrize("linker", ("py", "c"))
+    def test_split_view(self, linker):
+        x = vector("x")
+        axis = 0
+        op = Split(len_splits=3)
+        assert op.view_map == {0: [0], 1: [0], 2: [0]}
+        splits = op(x, axis, [0, 3, 2])
+
+        mode = Mode(linker)
+        f = pytensor.function(
+            [In(x, borrow=True)], [Out(s, borrow=True) for s in splits], mode=mode
+        )
+        x_test = np.arange(5, dtype=config.floatX)
+        res = f(x_test)
+        for r, expected in zip(res, ([], [0, 1, 2], [3, 4])):
+            assert np.allclose(r, expected)
+            if linker == "py":
+                assert r.base is x_test
+            else:
+                # C impl always makes a copy
+                assert r.base is not x_test


def test_TensorFromScalar():
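On the py/C asymmetry in `test_split_view`: the Python implementation now returns genuine views (`r.base is x_test`), while the C implementation still copies, which is exactly the `TODO: Don't make a copy in C impl` left in the docstring. Note also that callers only observe the aliasing when they opt in with `borrow=True` on both input and output, as the test does; with default containers, PyTensor copies at the function boundary. A small sketch of that default behavior (assumed sizes, not from the PR):

```python
import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.basic import Split

x = at.vector("x")
outs = Split(len_splits=2)(x, 0, [2, 3])
f = pytensor.function([x], outs)  # no borrow flags

x_test = np.arange(5.0, dtype=pytensor.config.floatX)
res = f(x_test)
assert all(r.base is not x_test for r in res)  # outputs copied at the boundary
```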