@@ -2158,66 +2158,6 @@ def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
2158
2158
expected = make_expected (x_val , y_val )
2159
2159
np .testing .assert_allclose (f (x_val , y_val ), expected )
2160
2160
2161
- def test_matmul (self ):
2162
- """Test matmul function with various input shapes."""
2163
- rng = np .random .default_rng (seed = utt .fetch_seed ())
2164
-
2165
- # Test matrix-matrix
2166
- x = matrix ()
2167
- y = matrix ()
2168
- z = matmul (x , y )
2169
- f = function ([x , y ], z )
2170
-
2171
- x_val = random (3 , 4 , rng = rng ).astype (config .floatX )
2172
- y_val = random (4 , 5 , rng = rng ).astype (config .floatX )
2173
- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2174
-
2175
- # Test vector-matrix
2176
- x = vector ()
2177
- y = matrix ()
2178
- z = matmul (x , y )
2179
- f = function ([x , y ], z )
2180
-
2181
- x_val = random (3 , rng = rng ).astype (config .floatX )
2182
- y_val = random (3 , 4 , rng = rng ).astype (config .floatX )
2183
- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2184
-
2185
- # Test matrix-vector
2186
- x = matrix ()
2187
- y = vector ()
2188
- z = matmul (x , y )
2189
- f = function ([x , y ], z )
2190
-
2191
- x_val = random (3 , 4 , rng = rng ).astype (config .floatX )
2192
- y_val = random (4 , rng = rng ).astype (config .floatX )
2193
- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2194
-
2195
- # Test vector-vector
2196
- x = vector ()
2197
- y = vector ()
2198
- z = matmul (x , y )
2199
- f = function ([x , y ], z )
2200
-
2201
- x_val = random (3 , rng = rng ).astype (config .floatX )
2202
- y_val = random (3 , rng = rng ).astype (config .floatX )
2203
- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2204
-
2205
- # Test batched
2206
- x = tensor3 ()
2207
- y = tensor3 ()
2208
- z = matmul (x , y )
2209
- f = function ([x , y ], z )
2210
-
2211
- x_val = random (2 , 3 , 4 , rng = rng ).astype (config .floatX )
2212
- y_val = random (2 , 4 , 5 , rng = rng ).astype (config .floatX )
2213
- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2214
-
2215
- # Test error cases
2216
- x = scalar ()
2217
- y = scalar ()
2218
- with pytest .raises (ValueError ):
2219
- matmul (x , y )
2220
-
2221
2161
2222
2162
class TestTensordot :
2223
2163
def TensorDot (self , axes ):
0 commit comments