Skip to content

Commit 0ef1ffd

Browse files
twieckiclaude
andcommitted
Simplify matrix/vector helper functions
- Remove redundant dimension checks that Blockwise already handles - Streamline test cases while keeping essential coverage - Based on PR feedback from Ricardo 🤖 Generated with Claude Code Co-Authored-By: Claude <[email protected]>
1 parent 12598e3 commit 0ef1ffd

File tree

2 files changed

+0
-66
lines changed

2 files changed

+0
-66
lines changed

pytensor/tensor/math.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4147,11 +4147,6 @@ def vecdot(
41474147
out : ndarray
41484148
The vector dot product of the inputs computed along the specified axes.
41494149
4150-
Raises
4151-
------
4152-
ValueError
4153-
If either input is a scalar value.
4154-
41554150
Notes
41564151
-----
41574152
This is similar to `dot` but with broadcasting. It computes the dot product
@@ -4161,9 +4156,6 @@ def vecdot(
41614156
x1 = as_tensor_variable(x1)
41624157
x2 = as_tensor_variable(x2)
41634158

4164-
if x1.type.ndim == 0 or x2.type.ndim == 0:
4165-
raise ValueError("vecdot operand cannot be scalar")
4166-
41674159
# Handle negative axis
41684160
if axis < 0:
41694161
x1_axis = axis % x1.type.ndim
@@ -4209,12 +4201,6 @@ def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
42094201
out : ndarray
42104202
The matrix-vector product with shape (..., M).
42114203
4212-
Raises
4213-
------
4214-
ValueError
4215-
If any input is a scalar or if the trailing dimension of x2 does not match
4216-
the second-to-last dimension of x1.
4217-
42184204
Notes
42194205
-----
42204206
This is similar to `matmul` where the second argument is a vector,
@@ -4224,15 +4210,6 @@ def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
42244210
x1 = as_tensor_variable(x1)
42254211
x2 = as_tensor_variable(x2)
42264212

4227-
if x1.type.ndim == 0 or x2.type.ndim == 0:
4228-
raise ValueError("matvec operand cannot be scalar")
4229-
4230-
if x1.type.ndim < 2:
4231-
raise ValueError("First input to matvec must have at least 2 dimensions")
4232-
4233-
if x2.type.ndim < 1:
4234-
raise ValueError("Second input to matvec must have at least 1 dimension")
4235-
42364213
out = _matrix_vec_prod(x1, x2)
42374214

42384215
if dtype is not None:
@@ -4260,12 +4237,6 @@ def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
42604237
out : ndarray
42614238
The vector-matrix product with shape (..., N).
42624239
4263-
Raises
4264-
------
4265-
ValueError
4266-
If any input is a scalar or if the last dimension of x1 does not match
4267-
the second-to-last dimension of x2.
4268-
42694240
Notes
42704241
-----
42714242
This is similar to `matmul` where the first argument is a vector,
@@ -4275,15 +4246,6 @@ def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
42754246
x1 = as_tensor_variable(x1)
42764247
x2 = as_tensor_variable(x2)
42774248

4278-
if x1.type.ndim == 0 or x2.type.ndim == 0:
4279-
raise ValueError("vecmat operand cannot be scalar")
4280-
4281-
if x1.type.ndim < 1:
4282-
raise ValueError("First input to vecmat must have at least 1 dimension")
4283-
4284-
if x2.type.ndim < 2:
4285-
raise ValueError("Second input to vecmat must have at least 2 dimensions")
4286-
42874249
out = _vec_matrix_prod(x1, x2)
42884250

42894251
if dtype is not None:

tests/tensor/test_math.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,12 +2116,6 @@ def test_vecdot(self):
21162116
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
21172117
np.testing.assert_allclose(f(x_val, y_val), np.sum(x_val * y_val, axis=2))
21182118

2119-
# Test error cases
2120-
x = scalar()
2121-
y = scalar()
2122-
with pytest.raises(ValueError):
2123-
vecdot(x, y)
2124-
21252119
def test_matvec(self):
21262120
"""Test matvec function with various input shapes."""
21272121
rng = np.random.default_rng(seed=utt.fetch_seed())
@@ -2147,17 +2141,6 @@ def test_matvec(self):
21472141
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
21482142
np.testing.assert_allclose(f(x_val, y_val), expected)
21492143

2150-
# Test error cases
2151-
x = vector()
2152-
y = vector()
2153-
with pytest.raises(ValueError):
2154-
matvec(x, y)
2155-
2156-
x = scalar()
2157-
y = vector()
2158-
with pytest.raises(ValueError):
2159-
matvec(x, y)
2160-
21612144
def test_vecmat(self):
21622145
"""Test vecmat function with various input shapes."""
21632146
rng = np.random.default_rng(seed=utt.fetch_seed())
@@ -2183,17 +2166,6 @@ def test_vecmat(self):
21832166
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
21842167
np.testing.assert_allclose(f(x_val, y_val), expected)
21852168

2186-
# Test error cases
2187-
x = matrix()
2188-
y = vector()
2189-
with pytest.raises(ValueError):
2190-
vecmat(x, y)
2191-
2192-
x = scalar()
2193-
y = matrix()
2194-
with pytest.raises(ValueError):
2195-
vecmat(x, y)
2196-
21972169
def test_matmul(self):
21982170
"""Test matmul function with various input shapes."""
21992171
rng = np.random.default_rng(seed=utt.fetch_seed())

0 commit comments

Comments
 (0)