Skip to content

Commit e29bea4

Browse files
twieckiClaude
and
Claude
committed
Address PR feedback for matrix-vector operations
- Improve docstrings with concrete shape examples - Explicitly state equivalence to NumPy functions - Simplify tests into a single parametrized test - Add dtype parameter test to ensure full coverage - Keep implementation minimal by relying on Blockwise checks 🤖 Generated with Claude Code Co-Authored-By: Claude <[email protected]>
1 parent ada6716 commit e29bea4

File tree

2 files changed

+50
-57
lines changed

2 files changed

+50
-57
lines changed

pytensor/tensor/math.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4145,17 +4145,24 @@ def vecdot(
41454145
41464146
Notes
41474147
-----
4148-
This is similar to `np.vecdot` and computes the dot product of
4148+
This is equivalent to `numpy.vecdot` and computes the dot product of
41494149
vectors along the last axis of both inputs. Broadcasting is supported
41504150
across all other dimensions.
41514151
41524152
Examples
41534153
--------
41544154
>>> import pytensor.tensor as pt
4155-
>>> x = pt.matrix("x")
4156-
>>> y = pt.matrix("y")
4157-
>>> z = pt.vecdot(x, y)
4158-
>>> # Equivalent to np.sum(x * y, axis=-1)
4155+
>>> # Vector dot product with shape (5,) inputs
4156+
>>> x = pt.vector("x") # shape (5,)
4157+
>>> y = pt.vector("y") # shape (5,)
4158+
>>> z = pt.vecdot(x, y) # scalar output
4159+
>>> # Equivalent to numpy.vecdot(x, y) or numpy.sum(x * y)
4160+
>>>
4161+
>>> # With batched inputs of shape (3, 5)
4162+
>>> x_batch = pt.matrix("x") # shape (3, 5)
4163+
>>> y_batch = pt.matrix("y") # shape (3, 5)
4164+
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
4165+
>>> # Equivalent to numpy.sum(x_batch * y_batch, axis=-1)
41594166
"""
41604167
x1 = as_tensor_variable(x1)
41614168
x2 = as_tensor_variable(x2)
@@ -4199,17 +4206,16 @@ def matvec(
41994206
Examples
42004207
--------
42014208
>>> import pytensor.tensor as pt
4202-
>>> import numpy as np
42034209
>>> # Matrix-vector product
4204-
>>> A = pt.matrix("A") # shape (M, K)
4205-
>>> v = pt.vector("v") # shape (K,)
4206-
>>> result = pt.matvec(A, v) # shape (M,)
4207-
>>> # Equivalent to np.matmul(A, v)
4210+
>>> A = pt.matrix("A") # shape (3, 4)
4211+
>>> v = pt.vector("v") # shape (4,)
4212+
>>> result = pt.matvec(A, v) # shape (3,)
4213+
>>> # Equivalent to numpy.matmul(A, v)
42084214
>>>
42094215
>>> # Batched matrix-vector product
4210-
>>> batched_A = pt.tensor3("A") # shape (B, M, K)
4211-
>>> batched_v = pt.matrix("v") # shape (B, K)
4212-
>>> result = pt.matvec(batched_A, batched_v) # shape (B, M)
4216+
>>> batched_A = pt.tensor3("A") # shape (2, 3, 4)
4217+
>>> batched_v = pt.matrix("v") # shape (2, 4)
4218+
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
42134219
"""
42144220
x1 = as_tensor_variable(x1)
42154221
x2 = as_tensor_variable(x2)
@@ -4252,17 +4258,16 @@ def vecmat(
42524258
Examples
42534259
--------
42544260
>>> import pytensor.tensor as pt
4255-
>>> import numpy as np
42564261
>>> # Vector-matrix product
4257-
>>> v = pt.vector("v") # shape (K,)
4258-
>>> A = pt.matrix("A") # shape (K, N)
4259-
>>> result = pt.vecmat(v, A) # shape (N,)
4260-
>>> # Equivalent to np.matmul(v, A)
4262+
>>> v = pt.vector("v") # shape (3,)
4263+
>>> A = pt.matrix("A") # shape (3, 4)
4264+
>>> result = pt.vecmat(v, A) # shape (4,)
4265+
>>> # Equivalent to numpy.matmul(v, A)
42614266
>>>
42624267
>>> # Batched vector-matrix product
4263-
>>> batched_v = pt.matrix("v") # shape (B, K)
4264-
>>> batched_A = pt.tensor3("A") # shape (B, K, N)
4265-
>>> result = pt.vecmat(batched_v, batched_A) # shape (B, N)
4268+
>>> batched_v = pt.matrix("v") # shape (2, 3)
4269+
>>> batched_A = pt.tensor3("A") # shape (2, 3, 4)
4270+
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
42664271
"""
42674272
x1 = as_tensor_variable(x1)
42684273
x2 = as_tensor_variable(x2)

tests/tensor/test_math.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,55 +2080,36 @@ def is_super_shape(var1, var2):
20802080

20812081

20822082
class TestMatrixVectorOps:
2083-
def test_vecdot(self):
2084-
"""Test vecdot function with various input shapes."""
2085-
rng = np.random.default_rng(seed=utt.fetch_seed())
2086-
2087-
# Test vector-vector
2088-
x = vector()
2089-
y = vector()
2090-
z = vecdot(x, y)
2091-
f = function([x, y], z)
2092-
x_val = random(5, rng=rng).astype(config.floatX)
2093-
y_val = random(5, rng=rng).astype(config.floatX)
2094-
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
2095-
2096-
# Test batched vectors
2097-
x = tensor3()
2098-
y = tensor3()
2099-
z = vecdot(x, y)
2100-
f = function([x, y], z)
2101-
2102-
x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2103-
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2104-
expected = np.sum(x_val * y_val, axis=-1)
2105-
np.testing.assert_allclose(f(x_val, y_val), expected)
2083+
"""Test vecdot, matvec, and vecmat helper functions."""
21062084

21072085
@pytest.mark.parametrize(
2108-
"func,x_shape,y_shape,make_expected",
2086+
"func,x_shape,y_shape,np_func,batch_axis",
21092087
[
2110-
# matvec tests - Matrix(M,K) @ Vector(K) -> Vector(M)
2111-
(matvec, (3, 4), (4,), lambda x, y: np.dot(x, y)),
2112-
# matvec batch tests - Tensor3(B,M,K) @ Matrix(B,K) -> Matrix(B,M)
2088+
# vecdot
2089+
(vecdot, (5,), (5,), lambda x, y: np.dot(x, y), None),
2090+
(vecdot, (3, 5), (3, 5), lambda x, y: np.sum(x * y, axis=-1), -1),
2091+
# matvec
2092+
(matvec, (3, 4), (4,), lambda x, y: np.dot(x, y), None),
21132093
(
21142094
matvec,
21152095
(2, 3, 4),
21162096
(2, 4),
21172097
lambda x, y: np.array([np.dot(x[i], y[i]) for i in range(len(x))]),
2098+
0,
21182099
),
2119-
# vecmat tests - Vector(K) @ Matrix(K,N) -> Vector(N)
2120-
(vecmat, (3,), (3, 4), lambda x, y: np.dot(x, y)),
2121-
# vecmat batch tests - Matrix(B,K) @ Tensor3(B,K,N) -> Matrix(B,N)
2100+
# vecmat
2101+
(vecmat, (3,), (3, 4), lambda x, y: np.dot(x, y), None),
21222102
(
21232103
vecmat,
21242104
(2, 3),
21252105
(2, 3, 4),
21262106
lambda x, y: np.array([np.dot(x[i], y[i]) for i in range(len(x))]),
2107+
0,
21272108
),
21282109
],
21292110
)
2130-
def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
2131-
"""Parametrized test for matvec and vecmat functions."""
2111+
def test_matrix_vector_ops(self, func, x_shape, y_shape, np_func, batch_axis):
2112+
"""Test all matrix-vector helper functions."""
21322113
rng = np.random.default_rng(seed=utt.fetch_seed())
21332114

21342115
# Create PyTensor variables with appropriate dimensions
@@ -2146,18 +2127,25 @@ def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
21462127
else:
21472128
y = tensor3()
21482129

2149-
# Apply the function
2130+
# Test basic functionality
21502131
z = func(x, y)
21512132
f = function([x, y], z)
21522133

2153-
# Create random values
21542134
x_val = random(*x_shape, rng=rng).astype(config.floatX)
21552135
y_val = random(*y_shape, rng=rng).astype(config.floatX)
21562136

2157-
# Compare with the expected result
2158-
expected = make_expected(x_val, y_val)
2137+
expected = np_func(x_val, y_val)
21592138
np.testing.assert_allclose(f(x_val, y_val), expected)
21602139

2140+
# Test with dtype parameter (to improve code coverage)
2141+
# Use float64 to ensure we can detect the difference
2142+
z_dtype = func(x, y, dtype="float64")
2143+
f_dtype = function([x, y], z_dtype)
2144+
2145+
result = f_dtype(x_val, y_val)
2146+
assert result.dtype == np.float64
2147+
np.testing.assert_allclose(result, expected)
2148+
21612149

21622150
class TestTensordot:
21632151
def TensorDot(self, axes):

0 commit comments

Comments
 (0)