|
11 | 11 | import pytensor.tensor.math as tm
|
12 | 12 | from pytensor import compile, config, function, shared
|
13 | 13 | from pytensor.compile.io import In, Out
|
14 |
| -from pytensor.compile.mode import get_default_mode |
| 14 | +from pytensor.compile.mode import Mode, get_default_mode |
15 | 15 | from pytensor.compile.ops import DeepCopyOp
|
16 | 16 | from pytensor.gradient import grad, hessian
|
17 | 17 | from pytensor.graph.basic import Apply
|
@@ -2002,45 +2002,65 @@ def test_split_static_shape(self):
|
2002 | 2002 | y = Split(2)(x, 0, [s, 5 - s])[0]
|
2003 | 2003 | assert y.type.shape == (None,)
|
2004 | 2004 |
|
2005 |
| - |
2006 |
| -def test_join_inplace(): |
2007 |
| - # Test join to work inplace. |
2008 |
| - # |
2009 |
| - # This function tests the case when several elements are passed to the |
2010 |
| - # join function but all except one of them are empty. In this case join |
2011 |
| - # should work inplace and the output should be the view of the non-empty |
2012 |
| - # element. |
2013 |
| - s = lscalar() |
2014 |
| - x = vector("x") |
2015 |
| - z = at.zeros((s,)) |
2016 |
| - |
2017 |
| - join = Join(view=0) |
2018 |
| - c = join(0, x, z, z) |
2019 |
| - |
2020 |
| - f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True)) |
2021 |
| - |
2022 |
| - data = np.array([3, 4, 5], dtype=config.floatX) |
2023 |
| - |
2024 |
| - if config.mode not in ["DebugMode", "DEBUG_MODE"]: |
2025 |
| - assert f(data, 0) is data |
2026 |
| - assert np.allclose(f(data, 0), [3, 4, 5]) |
2027 |
| - |
2028 |
| - |
2029 |
| -def test_join_oneInput(): |
2030 |
| - # Test join when only 1 input is given. |
2031 |
| - # |
2032 |
| - # This functions tests the case when concatenate is called |
2033 |
| - # on an array of tensors but the array has only one element. |
2034 |
| - # In this case, we would like to avoid the computational |
2035 |
| - # overhead of concatenation of one element. |
2036 |
| - x_0 = fmatrix() |
2037 |
| - x_1 = fmatrix() |
2038 |
| - x_2 = fvector() |
2039 |
| - join_0 = at.concatenate([x_0], axis=1) |
2040 |
| - join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1) |
2041 |
| - |
2042 |
| - assert join_0 is x_0 |
2043 |
| - assert join_1 is not x_0 |
| 2005 | + def test_join_inplace(self): |
| 2006 | + # Test join to work inplace. |
| 2007 | + # |
| 2008 | + # This function tests the case when several elements are passed to the |
| 2009 | + # join function but all except one of them are empty. In this case join |
| 2010 | + # should work inplace and the output should be the view of the non-empty |
| 2011 | + # element. |
| 2012 | + s = lscalar() |
| 2013 | + x = vector("x") |
| 2014 | + z = at.zeros((s,)) |
| 2015 | + |
| 2016 | + join = Join(view=0) |
| 2017 | + c = join(0, x, z, z) |
| 2018 | + |
| 2019 | + f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True)) |
| 2020 | + |
| 2021 | + data = np.array([3, 4, 5], dtype=config.floatX) |
| 2022 | + |
| 2023 | + if config.mode not in ["DebugMode", "DEBUG_MODE"]: |
| 2024 | + assert f(data, 0) is data |
| 2025 | + assert np.allclose(f(data, 0), [3, 4, 5]) |
| 2026 | + |
| 2027 | + def test_join_oneInput(self): |
| 2028 | + # Test join when only 1 input is given. |
| 2029 | + # |
| 2030 | + # This functions tests the case when concatenate is called |
| 2031 | + # on an array of tensors but the array has only one element. |
| 2032 | + # In this case, we would like to avoid the computational |
| 2033 | + # overhead of concatenation of one element. |
| 2034 | + x_0 = fmatrix() |
| 2035 | + x_1 = fmatrix() |
| 2036 | + x_2 = fvector() |
| 2037 | + join_0 = at.concatenate([x_0], axis=1) |
| 2038 | + join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1) |
| 2039 | + |
| 2040 | + assert join_0 is x_0 |
| 2041 | + assert join_1 is not x_0 |
| 2042 | + |
| 2043 | + @pytest.mark.parametrize("linker", ("py", "c")) |
| 2044 | + def test_split_view(self, linker): |
| 2045 | + x = vector("x") |
| 2046 | + axis = 0 |
| 2047 | + op = Split(len_splits=3) |
| 2048 | + assert op.view_map == {0: [0], 1: [0], 2: [0]} |
| 2049 | + splits = op(x, axis, [0, 3, 2]) |
| 2050 | + |
| 2051 | + mode = Mode(linker) |
| 2052 | + f = pytensor.function( |
| 2053 | + [In(x, borrow=True)], [Out(s, borrow=True) for s in splits], mode=mode |
| 2054 | + ) |
| 2055 | + x_test = np.arange(5, dtype=config.floatX) |
| 2056 | + res = f(x_test) |
| 2057 | + for r, expected in zip(res, ([], [0, 1, 2], [3, 4])): |
| 2058 | + assert np.allclose(r, expected) |
| 2059 | + if linker == "py": |
| 2060 | + assert r.base is x_test |
| 2061 | + else: |
| 2062 | + # C impl always makes a copy |
| 2063 | + assert r.base is not x_test |
2044 | 2064 |
|
2045 | 2065 |
|
2046 | 2066 | def test_TensorFromScalar():
|
|
0 commit comments