Skip to content

Commit 6f0f14c

Browse files
twieckiclaude
andcommitted
Address PR feedback
- Remove axis parameter from vecdot (no longer needed) - Update type annotations to use TensorLike - Add proper return type annotations - Improve docstrings with examples - Simplify test implementation and use pytest.parametrize - Use simpler implementation for batched operations 🤖 Generated with Claude Code Co-Authored-By: Claude <[email protected]>
1 parent 0ef1ffd commit 6f0f14c

File tree

2 files changed

+115
-102
lines changed

2 files changed

+115
-102
lines changed

pytensor/tensor/math.py

Lines changed: 66 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4123,66 +4123,55 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
41234123

41244124

41254125
def vecdot(
4126-
x1: "ArrayLike",
4127-
x2: "ArrayLike",
4128-
axis: int = -1,
4126+
x1: "TensorLike",
4127+
x2: "TensorLike",
41294128
dtype: Optional["DTypeLike"] = None,
4130-
):
4131-
"""Compute the dot product of two vectors along specified dimensions.
4129+
) -> "TensorVariable":
4130+
"""Compute the vector dot product of two arrays.
41324131
41334132
Parameters
41344133
----------
41354134
x1, x2
4136-
Input arrays, scalars not allowed.
4137-
axis
4138-
The axis along which to compute the dot product. By default, the last
4139-
axes of the inputs are used.
4135+
Input arrays with the same shape.
41404136
dtype
4141-
The desired data-type for the array. If not given, then the type will
4137+
The desired data-type for the result. If not given, then the type will
41424138
be determined as the minimum type required to hold the objects in the
41434139
sequence.
41444140
41454141
Returns
41464142
-------
4147-
out : ndarray
4148-
The vector dot product of the inputs computed along the specified axes.
4143+
TensorVariable
4144+
The vector dot product of the inputs.
41494145
41504146
Notes
41514147
-----
4152-
This is similar to `dot` but with broadcasting. It computes the dot product
4153-
along the specified axes, treating these as vectors, and broadcasts across
4154-
the remaining axes.
4148+
This is similar to `np.vecdot` and computes the dot product of
4149+
vectors along the last axis of both inputs. Broadcasting is supported
4150+
across all other dimensions.
4151+
4152+
Examples
4153+
--------
4154+
>>> import pytensor.tensor as pt
4155+
>>> x = pt.matrix("x")
4156+
>>> y = pt.matrix("y")
4157+
>>> z = pt.vecdot(x, y)
4158+
>>> # Equivalent to np.sum(x * y, axis=-1)
41554159
"""
41564160
x1 = as_tensor_variable(x1)
41574161
x2 = as_tensor_variable(x2)
41584162

4159-
# Handle negative axis
4160-
if axis < 0:
4161-
x1_axis = axis % x1.type.ndim
4162-
x2_axis = axis % x2.type.ndim
4163-
else:
4164-
x1_axis = axis
4165-
x2_axis = axis
4166-
4167-
# Move the axes to the end for dot product calculation
4168-
x1_perm = list(range(x1.type.ndim))
4169-
x1_perm.append(x1_perm.pop(x1_axis))
4170-
x1_transposed = x1.transpose(x1_perm)
4171-
4172-
x2_perm = list(range(x2.type.ndim))
4173-
x2_perm.append(x2_perm.pop(x2_axis))
4174-
x2_transposed = x2.transpose(x2_perm)
4175-
4176-
# Use the inner product operation
4177-
out = _inner_prod(x1_transposed, x2_transposed)
4163+
# Use the inner product operation along the last axis
4164+
out = _inner_prod(x1, x2)
41784165

41794166
if dtype is not None:
41804167
out = out.astype(dtype)
41814168

41824169
return out
41834170

41844171

4185-
def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
4172+
def matvec(
4173+
x1: "TensorLike", x2: "TensorLike", dtype: Optional["DTypeLike"] = None
4174+
) -> "TensorVariable":
41864175
"""Compute the matrix-vector product.
41874176
41884177
Parameters
@@ -4192,20 +4181,35 @@ def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
41924181
x2
41934182
Input array for the vector with shape (..., K).
41944183
dtype
4195-
The desired data-type for the array. If not given, then the type will
4184+
The desired data-type for the result. If not given, then the type will
41964185
be determined as the minimum type required to hold the objects in the
41974186
sequence.
41984187
41994188
Returns
42004189
-------
4201-
out : ndarray
4190+
TensorVariable
42024191
The matrix-vector product with shape (..., M).
42034192
42044193
Notes
42054194
-----
4206-
This is similar to `matmul` where the second argument is a vector,
4207-
but with different broadcasting rules. Broadcasting happens over all but
4208-
the last dimension of x1 and all dimensions of x2 except the last.
4195+
This is equivalent to `numpy.matmul` where the second argument is a vector,
4196+
but with more intuitive broadcasting rules. Broadcasting happens over all but
4197+
the last two dimensions of x1 and all dimensions of x2 except the last.
4198+
4199+
Examples
4200+
--------
4201+
>>> import pytensor.tensor as pt
4202+
>>> import numpy as np
4203+
>>> # Matrix-vector product
4204+
>>> A = pt.matrix("A") # shape (M, K)
4205+
>>> v = pt.vector("v") # shape (K,)
4206+
>>> result = pt.matvec(A, v) # shape (M,)
4207+
>>> # Equivalent to np.matmul(A, v)
4208+
>>>
4209+
>>> # Batched matrix-vector product
4210+
>>> batched_A = pt.tensor3("A") # shape (B, M, K)
4211+
>>> batched_v = pt.matrix("v") # shape (B, K)
4212+
>>> result = pt.matvec(batched_A, batched_v) # shape (B, M)
42094213
"""
42104214
x1 = as_tensor_variable(x1)
42114215
x2 = as_tensor_variable(x2)
@@ -4218,7 +4222,9 @@ def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
42184222
return out
42194223

42204224

4221-
def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
4225+
def vecmat(
4226+
x1: "TensorLike", x2: "TensorLike", dtype: Optional["DTypeLike"] = None
4227+
) -> "TensorVariable":
42224228
"""Compute the vector-matrix product.
42234229
42244230
Parameters
@@ -4228,20 +4234,35 @@ def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
42284234
x2
42294235
Input array for the matrix with shape (..., K, N).
42304236
dtype
4231-
The desired data-type for the array. If not given, then the type will
4237+
The desired data-type for the result. If not given, then the type will
42324238
be determined as the minimum type required to hold the objects in the
42334239
sequence.
42344240
42354241
Returns
42364242
-------
4237-
out : ndarray
4243+
TensorVariable
42384244
The vector-matrix product with shape (..., N).
42394245
42404246
Notes
42414247
-----
4242-
This is similar to `matmul` where the first argument is a vector,
4243-
but with different broadcasting rules. Broadcasting happens over all but
4248+
This is equivalent to `numpy.matmul` where the first argument is a vector,
4249+
but with more intuitive broadcasting rules. Broadcasting happens over all but
42444250
the last dimension of x1 and all but the last two dimensions of x2.
4251+
4252+
Examples
4253+
--------
4254+
>>> import pytensor.tensor as pt
4255+
>>> import numpy as np
4256+
>>> # Vector-matrix product
4257+
>>> v = pt.vector("v") # shape (K,)
4258+
>>> A = pt.matrix("A") # shape (K, N)
4259+
>>> result = pt.vecmat(v, A) # shape (N,)
4260+
>>> # Equivalent to np.matmul(v, A)
4261+
>>>
4262+
>>> # Batched vector-matrix product
4263+
>>> batched_v = pt.matrix("v") # shape (B, K)
4264+
>>> batched_A = pt.tensor3("A") # shape (B, K, N)
4265+
>>> result = pt.vecmat(batched_v, batched_A) # shape (B, N)
42454266
"""
42464267
x1 = as_tensor_variable(x1)
42474268
x2 = as_tensor_variable(x2)

tests/tensor/test_math.py

Lines changed: 49 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,7 +2081,7 @@ def is_super_shape(var1, var2):
20812081

20822082
class TestMatrixVectorOps:
20832083
def test_vecdot(self):
2084-
"""Test vecdot function with various input shapes and axis."""
2084+
"""Test vecdot function with various input shapes."""
20852085
rng = np.random.default_rng(seed=utt.fetch_seed())
20862086

20872087
# Test vector-vector
@@ -2093,77 +2093,69 @@ def test_vecdot(self):
20932093
y_val = random(5, rng=rng).astype(config.floatX)
20942094
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
20952095

2096-
# Test with axis parameter
2097-
x = matrix()
2098-
y = matrix()
2099-
z0 = vecdot(x, y, axis=0)
2100-
z1 = vecdot(x, y, axis=1)
2101-
f0 = function([x, y], z0)
2102-
f1 = function([x, y], z1)
2103-
2104-
x_val = random(3, 4, rng=rng).astype(config.floatX)
2105-
y_val = random(3, 4, rng=rng).astype(config.floatX)
2106-
np.testing.assert_allclose(f0(x_val, y_val), np.sum(x_val * y_val, axis=0))
2107-
np.testing.assert_allclose(f1(x_val, y_val), np.sum(x_val * y_val, axis=1))
2108-
21092096
# Test batched vectors
21102097
x = tensor3()
21112098
y = tensor3()
2112-
z = vecdot(x, y, axis=2)
2099+
z = vecdot(x, y)
21132100
f = function([x, y], z)
21142101

21152102
x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
21162103
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2117-
np.testing.assert_allclose(f(x_val, y_val), np.sum(x_val * y_val, axis=2))
2118-
2119-
def test_matvec(self):
2120-
"""Test matvec function with various input shapes."""
2121-
rng = np.random.default_rng(seed=utt.fetch_seed())
2122-
2123-
# Test matrix-vector
2124-
x = matrix()
2125-
y = vector()
2126-
z = matvec(x, y)
2127-
f = function([x, y], z)
2128-
2129-
x_val = random(3, 4, rng=rng).astype(config.floatX)
2130-
y_val = random(4, rng=rng).astype(config.floatX)
2131-
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
2132-
2133-
# Test batched
2134-
x = tensor3()
2135-
y = matrix()
2136-
z = matvec(x, y)
2137-
f = function([x, y], z)
2138-
2139-
x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2140-
y_val = random(2, 4, rng=rng).astype(config.floatX)
2141-
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
2104+
expected = np.sum(x_val * y_val, axis=-1)
21422105
np.testing.assert_allclose(f(x_val, y_val), expected)
21432106

2144-
def test_vecmat(self):
2145-
"""Test vecmat function with various input shapes."""
2107+
@pytest.mark.parametrize(
2108+
"func,x_shape,y_shape,make_expected",
2109+
[
2110+
# matvec tests - Matrix(M,K) @ Vector(K) -> Vector(M)
2111+
(matvec, (3, 4), (4,), lambda x, y: np.dot(x, y)),
2112+
# matvec batch tests - Tensor3(B,M,K) @ Matrix(B,K) -> Matrix(B,M)
2113+
(
2114+
matvec,
2115+
(2, 3, 4),
2116+
(2, 4),
2117+
lambda x, y: np.array([np.dot(x[i], y[i]) for i in range(len(x))]),
2118+
),
2119+
# vecmat tests - Vector(K) @ Matrix(K,N) -> Vector(N)
2120+
(vecmat, (3,), (3, 4), lambda x, y: np.dot(x, y)),
2121+
# vecmat batch tests - Matrix(B,K) @ Tensor3(B,K,N) -> Matrix(B,N)
2122+
(
2123+
vecmat,
2124+
(2, 3),
2125+
(2, 3, 4),
2126+
lambda x, y: np.array([np.dot(x[i], y[i]) for i in range(len(x))]),
2127+
),
2128+
],
2129+
)
2130+
def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
2131+
"""Parametrized test for matvec and vecmat functions."""
21462132
rng = np.random.default_rng(seed=utt.fetch_seed())
21472133

2148-
# Test vector-matrix
2149-
x = vector()
2150-
y = matrix()
2151-
z = vecmat(x, y)
2152-
f = function([x, y], z)
2134+
# Create PyTensor variables with appropriate dimensions
2135+
if len(x_shape) == 1:
2136+
x = vector()
2137+
elif len(x_shape) == 2:
2138+
x = matrix()
2139+
else:
2140+
x = tensor3()
21532141

2154-
x_val = random(3, rng=rng).astype(config.floatX)
2155-
y_val = random(3, 4, rng=rng).astype(config.floatX)
2156-
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
2142+
if len(y_shape) == 1:
2143+
y = vector()
2144+
elif len(y_shape) == 2:
2145+
y = matrix()
2146+
else:
2147+
y = tensor3()
21572148

2158-
# Test batched
2159-
x = matrix()
2160-
y = tensor3()
2161-
z = vecmat(x, y)
2149+
# Apply the function
2150+
z = func(x, y)
21622151
f = function([x, y], z)
21632152

2164-
x_val = random(2, 3, rng=rng).astype(config.floatX)
2165-
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
2166-
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
2153+
# Create random values
2154+
x_val = random(*x_shape, rng=rng).astype(config.floatX)
2155+
y_val = random(*y_shape, rng=rng).astype(config.floatX)
2156+
2157+
# Compare with the expected result
2158+
expected = make_expected(x_val, y_val)
21672159
np.testing.assert_allclose(f(x_val, y_val), expected)
21682160

21692161
def test_matmul(self):

0 commit comments

Comments
 (0)