Skip to content

Commit 52e07bd

Browse files
twieckiClaude
and
Claude
committed
Address PR feedback for matrix-vector operations
- Improve docstrings with concrete shape examples - Explicitly state equivalence to NumPy functions - Simplify tests into a single parametrized test - Add dtype parameter test to ensure full coverage - Keep implementation minimal by relying on Blockwise checks 🤖 Generated with Claude Code Co-Authored-By: Claude <[email protected]>
1 parent 1ad46be commit 52e07bd

File tree

2 files changed

+50
-57
lines changed

2 files changed

+50
-57
lines changed

pytensor/tensor/math.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2864,17 +2864,24 @@ def vecdot(
28642864
28652865
Notes
28662866
-----
2867-
This is similar to `np.vecdot` and computes the dot product of
2867+
This is equivalent to `numpy.vecdot` and computes the dot product of
28682868
vectors along the last axis of both inputs. Broadcasting is supported
28692869
across all other dimensions.
28702870
28712871
Examples
28722872
--------
28732873
>>> import pytensor.tensor as pt
2874-
>>> x = pt.matrix("x")
2875-
>>> y = pt.matrix("y")
2876-
>>> z = pt.vecdot(x, y)
2877-
>>> # Equivalent to np.sum(x * y, axis=-1)
2874+
>>> # Vector dot product with shape (5,) inputs
2875+
>>> x = pt.vector("x") # shape (5,)
2876+
>>> y = pt.vector("y") # shape (5,)
2877+
>>> z = pt.vecdot(x, y) # scalar output
2878+
>>> # Equivalent to numpy.vecdot(x, y) or numpy.sum(x * y)
2879+
>>>
2880+
>>> # With batched inputs of shape (3, 5)
2881+
>>> x_batch = pt.matrix("x") # shape (3, 5)
2882+
>>> y_batch = pt.matrix("y") # shape (3, 5)
2883+
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
2884+
>>> # Equivalent to numpy.sum(x_batch * y_batch, axis=-1)
28782885
"""
28792886
x1 = as_tensor_variable(x1)
28802887
x2 = as_tensor_variable(x2)
@@ -2918,17 +2925,16 @@ def matvec(
29182925
Examples
29192926
--------
29202927
>>> import pytensor.tensor as pt
2921-
>>> import numpy as np
29222928
>>> # Matrix-vector product
2923-
>>> A = pt.matrix("A") # shape (M, K)
2924-
>>> v = pt.vector("v") # shape (K,)
2925-
>>> result = pt.matvec(A, v) # shape (M,)
2926-
>>> # Equivalent to np.matmul(A, v)
2929+
>>> A = pt.matrix("A") # shape (3, 4)
2930+
>>> v = pt.vector("v") # shape (4,)
2931+
>>> result = pt.matvec(A, v) # shape (3,)
2932+
>>> # Equivalent to numpy.matmul(A, v)
29272933
>>>
29282934
>>> # Batched matrix-vector product
2929-
>>> batched_A = pt.tensor3("A") # shape (B, M, K)
2930-
>>> batched_v = pt.matrix("v") # shape (B, K)
2931-
>>> result = pt.matvec(batched_A, batched_v) # shape (B, M)
2935+
>>> batched_A = pt.tensor3("A") # shape (2, 3, 4)
2936+
>>> batched_v = pt.matrix("v") # shape (2, 4)
2937+
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
29322938
"""
29332939
x1 = as_tensor_variable(x1)
29342940
x2 = as_tensor_variable(x2)
@@ -2971,17 +2977,16 @@ def vecmat(
29712977
Examples
29722978
--------
29732979
>>> import pytensor.tensor as pt
2974-
>>> import numpy as np
29752980
>>> # Vector-matrix product
2976-
>>> v = pt.vector("v") # shape (K,)
2977-
>>> A = pt.matrix("A") # shape (K, N)
2978-
>>> result = pt.vecmat(v, A) # shape (N,)
2979-
>>> # Equivalent to np.matmul(v, A)
2981+
>>> v = pt.vector("v") # shape (3,)
2982+
>>> A = pt.matrix("A") # shape (3, 4)
2983+
>>> result = pt.vecmat(v, A) # shape (4,)
2984+
>>> # Equivalent to numpy.matmul(v, A)
29802985
>>>
29812986
>>> # Batched vector-matrix product
2982-
>>> batched_v = pt.matrix("v") # shape (B, K)
2983-
>>> batched_A = pt.tensor3("A") # shape (B, K, N)
2984-
>>> result = pt.vecmat(batched_v, batched_A) # shape (B, N)
2987+
>>> batched_v = pt.matrix("v") # shape (2, 3)
2988+
>>> batched_A = pt.tensor3("A") # shape (2, 3, 4)
2989+
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
29852990
"""
29862991
x1 = as_tensor_variable(x1)
29872992
x2 = as_tensor_variable(x2)

tests/tensor/test_math.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,55 +2080,36 @@ def is_super_shape(var1, var2):
20802080

20812081

20822082
class TestMatrixVectorOps:
2083-
def test_vecdot(self):
2084-
"""Test vecdot function with various input shapes."""
2085-
rng = np.random.default_rng(seed=utt.fetch_seed())
2086-
2087-
# Test vector-vector
2088-
x = vector()
2089-
y = vector()
2090-
z = vecdot(x, y)
2091-
f = function([x, y], z)
2092-
x_val = random(5, rng=rng).astype(config.floatX)
2093-
y_val = random(5, rng=rng).astype(config.floatX)
2094-
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
2095-
2096-
# Test batched vectors
2097-
x = tensor3()
2098-
y = tensor3()
2099-
z = vecdot(x, y)
2100-
f = function([x, y], z)
2101-
2102-
x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2103-
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2104-
expected = np.sum(x_val * y_val, axis=-1)
2105-
np.testing.assert_allclose(f(x_val, y_val), expected)
2083+
"""Test vecdot, matvec, and vecmat helper functions."""
21062084

21072085
@pytest.mark.parametrize(
2108-
"func,x_shape,y_shape,make_expected",
2086+
"func,x_shape,y_shape,np_func,batch_axis",
21092087
[
2110-
# matvec tests - Matrix(M,K) @ Vector(K) -> Vector(M)
2111-
(matvec, (3, 4), (4,), lambda x, y: np.dot(x, y)),
2112-
# matvec batch tests - Tensor3(B,M,K) @ Matrix(B,K) -> Matrix(B,M)
2088+
# vecdot
2089+
(vecdot, (5,), (5,), lambda x, y: np.dot(x, y), None),
2090+
(vecdot, (3, 5), (3, 5), lambda x, y: np.sum(x * y, axis=-1), -1),
2091+
# matvec
2092+
(matvec, (3, 4), (4,), lambda x, y: np.dot(x, y), None),
21132093
(
21142094
matvec,
21152095
(2, 3, 4),
21162096
(2, 4),
21172097
lambda x, y: np.array([np.dot(x[i], y[i]) for i in range(len(x))]),
2098+
0,
21182099
),
2119-
# vecmat tests - Vector(K) @ Matrix(K,N) -> Vector(N)
2120-
(vecmat, (3,), (3, 4), lambda x, y: np.dot(x, y)),
2121-
# vecmat batch tests - Matrix(B,K) @ Tensor3(B,K,N) -> Matrix(B,N)
2100+
# vecmat
2101+
(vecmat, (3,), (3, 4), lambda x, y: np.dot(x, y), None),
21222102
(
21232103
vecmat,
21242104
(2, 3),
21252105
(2, 3, 4),
21262106
lambda x, y: np.array([np.dot(x[i], y[i]) for i in range(len(x))]),
2107+
0,
21272108
),
21282109
],
21292110
)
2130-
def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
2131-
"""Parametrized test for matvec and vecmat functions."""
2111+
def test_matrix_vector_ops(self, func, x_shape, y_shape, np_func, batch_axis):
2112+
"""Test all matrix-vector helper functions."""
21322113
rng = np.random.default_rng(seed=utt.fetch_seed())
21332114

21342115
# Create PyTensor variables with appropriate dimensions
@@ -2146,18 +2127,25 @@ def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
21462127
else:
21472128
y = tensor3()
21482129

2149-
# Apply the function
2130+
# Test basic functionality
21502131
z = func(x, y)
21512132
f = function([x, y], z)
21522133

2153-
# Create random values
21542134
x_val = random(*x_shape, rng=rng).astype(config.floatX)
21552135
y_val = random(*y_shape, rng=rng).astype(config.floatX)
21562136

2157-
# Compare with the expected result
2158-
expected = make_expected(x_val, y_val)
2137+
expected = np_func(x_val, y_val)
21592138
np.testing.assert_allclose(f(x_val, y_val), expected)
21602139

2140+
# Test with dtype parameter (to improve code coverage)
2141+
# Use float64 to ensure we can detect the difference
2142+
z_dtype = func(x, y, dtype="float64")
2143+
f_dtype = function([x, y], z_dtype)
2144+
2145+
result = f_dtype(x_val, y_val)
2146+
assert result.dtype == np.float64
2147+
np.testing.assert_allclose(result, expected)
2148+
21612149

21622150
class TestTensordot:
21632151
def TensorDot(self, axes):

0 commit comments

Comments
 (0)