|
8 | 8 |
|
9 | 9 | import pytensor
|
10 | 10 | import pytensor.scalar as ps
|
| 11 | +import pytensor.tensor as pt |
11 | 12 | import tests.unittest_tools as utt
|
| 13 | +from pytensor.compile.function import function |
12 | 14 | from pytensor.compile.mode import Mode
|
13 | 15 | from pytensor.configdefaults import config
|
14 | 16 | from pytensor.graph.basic import Apply, Variable
|
@@ -893,6 +895,25 @@ def test_invalid_static_shape(self):
|
893 | 895 | ):
|
894 | 896 | x + y
|
895 | 897 |
|
| 898 | + @pytest.mark.parametrize( |
| 899 | + "shape_x, shape_y, op, np_op", |
| 900 | + [ |
| 901 | + ((3, 5), (7, 1, 3), pt.add, np.add), |
| 902 | + ((2, 3), (1, 4), pt.mul, np.multiply), |
| 903 | + ], |
| 904 | + ) |
| 905 | + def test_outer(self, shape_x, shape_y, op, np_op): |
| 906 | + x = tensor(dtype=np.float64, shape=shape_x) |
| 907 | + y = tensor(dtype=np.float64, shape=shape_y) |
| 908 | + |
| 909 | + z = op.outer(x, y) |
| 910 | + |
| 911 | + f = function([x, y], z) |
| 912 | + x1 = np.ones(shape_x) |
| 913 | + y1 = np.ones(shape_y) |
| 914 | + |
| 915 | + np.testing.assert_array_equal(f(x1, y1), np_op.outer(x1, y1)) |
| 916 | + |
896 | 917 |
|
897 | 918 | def test_not_implemented_elemwise_grad():
|
898 | 919 | # Regression test for unimplemented gradient in an Elemwise Op.
|
|
0 commit comments