Skip to content

Commit 7dcafcf

Browse files
twiecki and claude committed
Address PR feedback
- Remove axis parameter from vecdot (no longer needed)
- Update type annotations to use TensorLike
- Add proper return type annotations
- Improve docstrings with examples
- Simplify test implementation and use pytest.parametrize
- Use simpler implementation for batched operations

🤖 Generated with Claude Code

Co-Authored-By: Claude <[email protected]>
1 parent b248e5b commit 7dcafcf

File tree

2 files changed

+115
-102
lines changed

2 files changed

+115
-102
lines changed

pytensor/tensor/math.py

Lines changed: 66 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2842,66 +2842,55 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
28422842

28432843

28442844
def vecdot(
2845-
x1: "ArrayLike",
2846-
x2: "ArrayLike",
2847-
axis: int = -1,
2845+
x1: "TensorLike",
2846+
x2: "TensorLike",
28482847
dtype: Optional["DTypeLike"] = None,
2849-
):
2850-
"""Compute the dot product of two vectors along specified dimensions.
2848+
) -> "TensorVariable":
2849+
"""Compute the vector dot product of two arrays.
28512850
28522851
Parameters
28532852
----------
28542853
x1, x2
2855-
Input arrays, scalars not allowed.
2856-
axis
2857-
The axis along which to compute the dot product. By default, the last
2858-
axes of the inputs are used.
2854+
Input arrays. Their last dimensions must have the same length; all other dimensions are broadcast against each other.
28592855
dtype
2860-
The desired data-type for the array. If not given, then the type will
2856+
The desired data-type for the result. If not given, then the type will
28612857
be determined as the minimum type required to hold the objects in the
28622858
sequence.
28632859
28642860
Returns
28652861
-------
2866-
out : ndarray
2867-
The vector dot product of the inputs computed along the specified axes.
2862+
TensorVariable
2863+
The vector dot product of the inputs.
28682864
28692865
Notes
28702866
-----
2871-
This is similar to `dot` but with broadcasting. It computes the dot product
2872-
along the specified axes, treating these as vectors, and broadcasts across
2873-
the remaining axes.
2867+
This is similar to `np.vecdot` and computes the dot product of
2868+
vectors along the last axis of both inputs. Broadcasting is supported
2869+
across all other dimensions.
2870+
2871+
Examples
2872+
--------
2873+
>>> import pytensor.tensor as pt
2874+
>>> x = pt.matrix("x")
2875+
>>> y = pt.matrix("y")
2876+
>>> z = pt.vecdot(x, y)
2877+
>>> # Equivalent to np.sum(x * y, axis=-1)
28742878
"""
28752879
x1 = as_tensor_variable(x1)
28762880
x2 = as_tensor_variable(x2)
28772881

2878-
# Handle negative axis
2879-
if axis < 0:
2880-
x1_axis = axis % x1.type.ndim
2881-
x2_axis = axis % x2.type.ndim
2882-
else:
2883-
x1_axis = axis
2884-
x2_axis = axis
2885-
2886-
# Move the axes to the end for dot product calculation
2887-
x1_perm = list(range(x1.type.ndim))
2888-
x1_perm.append(x1_perm.pop(x1_axis))
2889-
x1_transposed = x1.transpose(x1_perm)
2890-
2891-
x2_perm = list(range(x2.type.ndim))
2892-
x2_perm.append(x2_perm.pop(x2_axis))
2893-
x2_transposed = x2.transpose(x2_perm)
2894-
2895-
# Use the inner product operation
2896-
out = _inner_prod(x1_transposed, x2_transposed)
2882+
# Use the inner product operation along the last axis
2883+
out = _inner_prod(x1, x2)
28972884

28982885
if dtype is not None:
28992886
out = out.astype(dtype)
29002887

29012888
return out
29022889

29032890

2904-
def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
2891+
def matvec(
2892+
x1: "TensorLike", x2: "TensorLike", dtype: Optional["DTypeLike"] = None
2893+
) -> "TensorVariable":
29052894
"""Compute the matrix-vector product.
29062895
29072896
Parameters
@@ -2911,20 +2900,35 @@ def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29112900
x2
29122901
Input array for the vector with shape (..., K).
29132902
dtype
2914-
The desired data-type for the array. If not given, then the type will
2903+
The desired data-type for the result. If not given, then the type will
29152904
be determined as the minimum type required to hold the objects in the
29162905
sequence.
29172906
29182907
Returns
29192908
-------
2920-
out : ndarray
2909+
TensorVariable
29212910
The matrix-vector product with shape (..., M).
29222911
29232912
Notes
29242913
-----
2925-
This is similar to `matmul` where the second argument is a vector,
2926-
but with different broadcasting rules. Broadcasting happens over all but
2927-
the last dimension of x1 and all dimensions of x2 except the last.
2914+
This is similar to `numpy.matmul` where the second argument is a vector,
2915+
but with more intuitive broadcasting rules. Broadcasting happens over all but
2916+
the last two dimensions of x1 and all dimensions of x2 except the last.
2917+
2918+
Examples
2919+
--------
2920+
>>> import pytensor.tensor as pt
2921+
>>> import numpy as np
2922+
>>> # Matrix-vector product
2923+
>>> A = pt.matrix("A") # shape (M, K)
2924+
>>> v = pt.vector("v") # shape (K,)
2925+
>>> result = pt.matvec(A, v) # shape (M,)
2926+
>>> # Equivalent to np.matmul(A, v)
2927+
>>>
2928+
>>> # Batched matrix-vector product
2929+
>>> batched_A = pt.tensor3("A") # shape (B, M, K)
2930+
>>> batched_v = pt.matrix("v") # shape (B, K)
2931+
>>> result = pt.matvec(batched_A, batched_v) # shape (B, M)
29282932
"""
29292933
x1 = as_tensor_variable(x1)
29302934
x2 = as_tensor_variable(x2)
@@ -2937,7 +2941,9 @@ def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29372941
return out
29382942

29392943

2940-
def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
2944+
def vecmat(
2945+
x1: "TensorLike", x2: "TensorLike", dtype: Optional["DTypeLike"] = None
2946+
) -> "TensorVariable":
29412947
"""Compute the vector-matrix product.
29422948
29432949
Parameters
@@ -2947,20 +2953,35 @@ def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29472953
x2
29482954
Input array for the matrix with shape (..., K, N).
29492955
dtype
2950-
The desired data-type for the array. If not given, then the type will
2956+
The desired data-type for the result. If not given, then the type will
29512957
be determined as the minimum type required to hold the objects in the
29522958
sequence.
29532959
29542960
Returns
29552961
-------
2956-
out : ndarray
2962+
TensorVariable
29572963
The vector-matrix product with shape (..., N).
29582964
29592965
Notes
29602966
-----
2961-
This is similar to `matmul` where the first argument is a vector,
2962-
but with different broadcasting rules. Broadcasting happens over all but
2967+
This is similar to `numpy.matmul` where the first argument is a vector,
2968+
but with more intuitive broadcasting rules. Broadcasting happens over all but
29632969
the last dimension of x1 and all but the last two dimensions of x2.
2970+
2971+
Examples
2972+
--------
2973+
>>> import pytensor.tensor as pt
2974+
>>> import numpy as np
2975+
>>> # Vector-matrix product
2976+
>>> v = pt.vector("v") # shape (K,)
2977+
>>> A = pt.matrix("A") # shape (K, N)
2978+
>>> result = pt.vecmat(v, A) # shape (N,)
2979+
>>> # Equivalent to np.matmul(v, A)
2980+
>>>
2981+
>>> # Batched vector-matrix product
2982+
>>> batched_v = pt.matrix("v") # shape (B, K)
2983+
>>> batched_A = pt.tensor3("A") # shape (B, K, N)
2984+
>>> result = pt.vecmat(batched_v, batched_A) # shape (B, N)
29642985
"""
29652986
x1 = as_tensor_variable(x1)
29662987
x2 = as_tensor_variable(x2)

tests/tensor/test_math.py

Lines changed: 49 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,7 +2081,7 @@ def is_super_shape(var1, var2):
20812081

20822082
class TestMatrixVectorOps:
20832083
def test_vecdot(self):
2084-
"""Test vecdot function with various input shapes and axis."""
2084+
"""Test vecdot function with various input shapes."""
20852085
rng = np.random.default_rng(seed=utt.fetch_seed())
20862086

20872087
# Test vector-vector
@@ -2093,77 +2093,69 @@ def test_vecdot(self):
20932093
y_val = random(5, rng=rng).astype(config.floatX)
20942094
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
20952095

2096-
# Test with axis parameter
2097-
x = matrix()
2098-
y = matrix()
2099-
z0 = vecdot(x, y, axis=0)
2100-
z1 = vecdot(x, y, axis=1)
2101-
f0 = function([x, y], z0)
2102-
f1 = function([x, y], z1)
2103-
2104-
x_val = random(3, 4, rng=rng).astype(config.floatX)
2105-
y_val = random(3, 4, rng=rng).astype(config.floatX)
2106-
np.testing.assert_allclose(f0(x_val, y_val), np.sum(x_val * y_val, axis=0))
2107-
np.testing.assert_allclose(f1(x_val, y_val), np.sum(x_val * y_val, axis=1))
2108-
21092096
# Test batched vectors
21102097
x = tensor3()
21112098
y = tensor3()
2112-
z = vecdot(x, y, axis=2)
2099+
z = vecdot(x, y)
21132100
f = function([x, y], z)
21142101

21152102
x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
21162103
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2117-
np.testing.assert_allclose(f(x_val, y_val), np.sum(x_val * y_val, axis=2))
2118-
2119-
def test_matvec(self):
2120-
"""Test matvec function with various input shapes."""
2121-
rng = np.random.default_rng(seed=utt.fetch_seed())
2122-
2123-
# Test matrix-vector
2124-
x = matrix()
2125-
y = vector()
2126-
z = matvec(x, y)
2127-
f = function([x, y], z)
2128-
2129-
x_val = random(3, 4, rng=rng).astype(config.floatX)
2130-
y_val = random(4, rng=rng).astype(config.floatX)
2131-
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
2132-
2133-
# Test batched
2134-
x = tensor3()
2135-
y = matrix()
2136-
z = matvec(x, y)
2137-
f = function([x, y], z)
2138-
2139-
x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2140-
y_val = random(2, 4, rng=rng).astype(config.floatX)
2141-
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
2104+
expected = np.sum(x_val * y_val, axis=-1)
21422105
np.testing.assert_allclose(f(x_val, y_val), expected)
21432106

2144-
def test_vecmat(self):
2145-
"""Test vecmat function with various input shapes."""
2107+
@pytest.mark.parametrize(
2108+
"func,x_shape,y_shape,make_expected",
2109+
[
2110+
# matvec tests - Matrix(M,K) @ Vector(K) -> Vector(M)
2111+
(matvec, (3, 4), (4,), lambda x, y: np.dot(x, y)),
2112+
# matvec batch tests - Tensor3(B,M,K) @ Matrix(B,K) -> Matrix(B,M)
2113+
(
2114+
matvec,
2115+
(2, 3, 4),
2116+
(2, 4),
2117+
lambda x, y: np.array([np.dot(x[i], y[i]) for i in range(len(x))]),
2118+
),
2119+
# vecmat tests - Vector(K) @ Matrix(K,N) -> Vector(N)
2120+
(vecmat, (3,), (3, 4), lambda x, y: np.dot(x, y)),
2121+
# vecmat batch tests - Matrix(B,K) @ Tensor3(B,K,N) -> Matrix(B,N)
2122+
(
2123+
vecmat,
2124+
(2, 3),
2125+
(2, 3, 4),
2126+
lambda x, y: np.array([np.dot(x[i], y[i]) for i in range(len(x))]),
2127+
),
2128+
],
2129+
)
2130+
def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
2131+
"""Parametrized test for matvec and vecmat functions."""
21462132
rng = np.random.default_rng(seed=utt.fetch_seed())
21472133

2148-
# Test vector-matrix
2149-
x = vector()
2150-
y = matrix()
2151-
z = vecmat(x, y)
2152-
f = function([x, y], z)
2134+
# Create PyTensor variables with appropriate dimensions
2135+
if len(x_shape) == 1:
2136+
x = vector()
2137+
elif len(x_shape) == 2:
2138+
x = matrix()
2139+
else:
2140+
x = tensor3()
21532141

2154-
x_val = random(3, rng=rng).astype(config.floatX)
2155-
y_val = random(3, 4, rng=rng).astype(config.floatX)
2156-
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
2142+
if len(y_shape) == 1:
2143+
y = vector()
2144+
elif len(y_shape) == 2:
2145+
y = matrix()
2146+
else:
2147+
y = tensor3()
21572148

2158-
# Test batched
2159-
x = matrix()
2160-
y = tensor3()
2161-
z = vecmat(x, y)
2149+
# Apply the function
2150+
z = func(x, y)
21622151
f = function([x, y], z)
21632152

2164-
x_val = random(2, 3, rng=rng).astype(config.floatX)
2165-
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2166-
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
2153+
# Create random values
2154+
x_val = random(*x_shape, rng=rng).astype(config.floatX)
2155+
y_val = random(*y_shape, rng=rng).astype(config.floatX)
2156+
2157+
# Compare with the expected result
2158+
expected = make_expected(x_val, y_val)
21672159
np.testing.assert_allclose(f(x_val, y_val), expected)
21682160

21692161
def test_matmul(self):

0 commit comments

Comments
 (0)