|
2 | 2 | import pytest
|
3 | 3 |
|
4 | 4 | from pytensor.compile.function import function
|
5 |
| -from pytensor.compile.mode import Mode |
6 | 5 | from pytensor.configdefaults import config
|
7 | 6 | from pytensor.graph.fg import FunctionGraph
|
8 |
| -from pytensor.graph.op import get_test_value |
9 |
| -from pytensor.graph.rewriting.db import RewriteDatabaseQuery |
10 |
| -from pytensor.link.jax import JAXLinker |
11 |
| -from pytensor.tensor import blas as pt_blas |
12 | 7 | from pytensor.tensor import nlinalg as pt_nlinalg
|
13 |
| -from pytensor.tensor.math import Argmax, Max, maximum |
14 |
| -from pytensor.tensor.math import max as pt_max |
15 |
| -from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector |
| 8 | +from pytensor.tensor.type import matrix |
16 | 9 | from tests.link.jax.test_basic import compare_jax_and_py
|
17 | 10 |
|
18 | 11 |
|
19 | 12 | jax = pytest.importorskip("jax")
|
20 | 13 |
|
21 | 14 |
|
22 |
| -def test_jax_BatchedDot(): |
23 |
| - # tensor3 . tensor3 |
24 |
| - a = tensor3("a") |
25 |
| - a.tag.test_value = ( |
26 |
| - np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) |
27 |
| - ) |
28 |
| - b = tensor3("b") |
29 |
| - b.tag.test_value = ( |
30 |
| - np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) |
31 |
| - ) |
32 |
| - out = pt_blas.BatchedDot()(a, b) |
33 |
| - fgraph = FunctionGraph([a, b], [out]) |
34 |
| - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) |
35 |
| - |
36 |
| - # A dimension mismatch should raise a TypeError for compatibility |
37 |
| - inputs = [get_test_value(a)[:-1], get_test_value(b)] |
38 |
| - opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) |
39 |
| - jax_mode = Mode(JAXLinker(), opts) |
40 |
| - pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) |
41 |
| - with pytest.raises(TypeError): |
42 |
| - pytensor_jax_fn(*inputs) |
43 |
| - |
44 |
| - |
45 | 15 | def test_jax_basic_multiout():
|
46 | 16 | rng = np.random.default_rng(213234)
|
47 | 17 |
|
@@ -79,45 +49,6 @@ def assert_fn(x, y):
|
79 | 49 | compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
|
80 | 50 |
|
81 | 51 |
|
82 |
| -def test_jax_max_and_argmax(): |
83 |
| - # Test that a single output of a multi-output `Op` can be used as input to |
84 |
| - # another `Op` |
85 |
| - x = dvector() |
86 |
| - mx = Max([0])(x) |
87 |
| - amx = Argmax([0])(x) |
88 |
| - out = mx * amx |
89 |
| - out_fg = FunctionGraph([x], [out]) |
90 |
| - compare_jax_and_py(out_fg, [np.r_[1, 2]]) |
91 |
| - |
92 |
| - |
93 |
| -def test_tensor_basics(): |
94 |
| - y = vector("y") |
95 |
| - y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) |
96 |
| - x = vector("x") |
97 |
| - x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) |
98 |
| - A = matrix("A") |
99 |
| - A.tag.test_value = np.empty((2, 2), dtype=config.floatX) |
100 |
| - alpha = scalar("alpha") |
101 |
| - alpha.tag.test_value = np.array(3.0, dtype=config.floatX) |
102 |
| - beta = scalar("beta") |
103 |
| - beta.tag.test_value = np.array(5.0, dtype=config.floatX) |
104 |
| - |
105 |
| - # This should be converted into a `Gemv` `Op` when the non-JAX compatible |
106 |
| - # optimizations are turned on; however, when using JAX mode, it should |
107 |
| - # leave the expression alone. |
108 |
| - out = y.dot(alpha * A).dot(x) + beta * y |
109 |
| - fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) |
110 |
| - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) |
111 |
| - |
112 |
| - out = maximum(y, x) |
113 |
| - fgraph = FunctionGraph([y, x], [out]) |
114 |
| - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) |
115 |
| - |
116 |
| - out = pt_max(y) |
117 |
| - fgraph = FunctionGraph([y], [out]) |
118 |
| - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) |
119 |
| - |
120 |
| - |
121 | 52 | def test_pinv():
|
122 | 53 | x = matrix("x")
|
123 | 54 | x_inv = pt_nlinalg.pinv(x)
|
|
0 commit comments