Skip to content

Commit ada6716

Browse files
twieckiClaude
and
Claude
committed
Remove redundant test_matmul
- The `matmul` function was already well-tested elsewhere - Focus our tests specifically on the three new helper functions 🤖 Generated with Claude Code Co-Authored-By: Claude <[email protected]>
1 parent 6f0f14c commit ada6716

File tree

1 file changed

+0
-60
lines changed

1 file changed

+0
-60
lines changed

tests/tensor/test_math.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,66 +2158,6 @@ def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
21582158
expected = make_expected(x_val, y_val)
21592159
np.testing.assert_allclose(f(x_val, y_val), expected)
21602160

2161-
def test_matmul(self):
2162-
"""Test matmul function with various input shapes."""
2163-
rng = np.random.default_rng(seed=utt.fetch_seed())
2164-
2165-
# Test matrix-matrix
2166-
x = matrix()
2167-
y = matrix()
2168-
z = matmul(x, y)
2169-
f = function([x, y], z)
2170-
2171-
x_val = random(3, 4, rng=rng).astype(config.floatX)
2172-
y_val = random(4, 5, rng=rng).astype(config.floatX)
2173-
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
2174-
2175-
# Test vector-matrix
2176-
x = vector()
2177-
y = matrix()
2178-
z = matmul(x, y)
2179-
f = function([x, y], z)
2180-
2181-
x_val = random(3, rng=rng).astype(config.floatX)
2182-
y_val = random(3, 4, rng=rng).astype(config.floatX)
2183-
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
2184-
2185-
# Test matrix-vector
2186-
x = matrix()
2187-
y = vector()
2188-
z = matmul(x, y)
2189-
f = function([x, y], z)
2190-
2191-
x_val = random(3, 4, rng=rng).astype(config.floatX)
2192-
y_val = random(4, rng=rng).astype(config.floatX)
2193-
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
2194-
2195-
# Test vector-vector
2196-
x = vector()
2197-
y = vector()
2198-
z = matmul(x, y)
2199-
f = function([x, y], z)
2200-
2201-
x_val = random(3, rng=rng).astype(config.floatX)
2202-
y_val = random(3, rng=rng).astype(config.floatX)
2203-
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
2204-
2205-
# Test batched
2206-
x = tensor3()
2207-
y = tensor3()
2208-
z = matmul(x, y)
2209-
f = function([x, y], z)
2210-
2211-
x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2212-
y_val = random(2, 4, 5, rng=rng).astype(config.floatX)
2213-
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
2214-
2215-
# Test error cases
2216-
x = scalar()
2217-
y = scalar()
2218-
with pytest.raises(ValueError):
2219-
matmul(x, y)
2220-
22212161

22222162
class TestTensordot:
22232163
def TensorDot(self, axes):

0 commit comments

Comments
 (0)