Skip to content

Commit 12598e3

Browse files
committed
Expose vecdot, vecmat and matvec helpers
Add three new functions that expose the underlying Blockwise operations:

- vecdot: Computes dot products between vectors with broadcasting
- matvec: Computes matrix-vector products with broadcasting
- vecmat: Computes vector-matrix products with broadcasting

These match the NumPy API for similar operations and complement the existing matmul function. Each comes with appropriate error handling, parameter validation, and comprehensive test coverage.

Fixes #1237
1 parent 89d5366 commit 12598e3

File tree

2 files changed

+352
-0
lines changed

2 files changed

+352
-0
lines changed

pytensor/tensor/math.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4122,6 +4122,176 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
41224122
return out
41234123

41244124

4125+
def vecdot(
    x1: "ArrayLike",
    x2: "ArrayLike",
    axis: int = -1,
    dtype: Optional["DTypeLike"] = None,
):
    """Compute the dot product of two vectors along specified dimensions.

    Parameters
    ----------
    x1, x2
        Input arrays, scalars not allowed.
    axis
        The axis along which to compute the dot product. By default, the last
        axes of the inputs are used.
    dtype
        The desired data-type for the array. If not given, then the type will
        be determined as the minimum type required to hold the objects in the
        sequence.

    Returns
    -------
    out : ndarray
        The vector dot product of the inputs computed along the specified axes.

    Raises
    ------
    ValueError
        If either input is a scalar value, or if ``axis`` is out of bounds
        for either input.

    Notes
    -----
    This is similar to `dot` but with broadcasting. It computes the dot product
    along the specified axes, treating these as vectors, and broadcasts across
    the remaining axes.
    """
    x1 = as_tensor_variable(x1)
    x2 = as_tensor_variable(x2)

    if x1.type.ndim == 0 or x2.type.ndim == 0:
        raise ValueError("vecdot operand cannot be scalar")

    def _normalize_axis(ndim: int, operand_name: str) -> int:
        # Validate explicitly: a bare `axis % ndim` silently wraps
        # out-of-range negative axes (e.g. axis=-5 on a 3d input maps to
        # axis 1), and a too-large positive axis would otherwise surface as
        # an opaque IndexError from list.pop below. NumPy raises for both.
        if not -ndim <= axis < ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for {operand_name} with {ndim} dimension(s)"
            )
        return axis % ndim if axis < 0 else axis

    x1_axis = _normalize_axis(x1.type.ndim, "x1")
    x2_axis = _normalize_axis(x2.type.ndim, "x2")

    # Move the contracted axis to the end so the inner-product Blockwise,
    # which contracts the trailing dimension of each operand, applies directly.
    x1_perm = list(range(x1.type.ndim))
    x1_perm.append(x1_perm.pop(x1_axis))
    x1_transposed = x1.transpose(x1_perm)

    x2_perm = list(range(x2.type.ndim))
    x2_perm.append(x2_perm.pop(x2_axis))
    x2_transposed = x2.transpose(x2_perm)

    # Use the inner product operation (broadcasts over leading dimensions).
    out = _inner_prod(x1_transposed, x2_transposed)

    if dtype is not None:
        out = out.astype(dtype)

    return out
4191+
4192+
4193+
def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
    """Compute the matrix-vector product.

    Parameters
    ----------
    x1
        Input array for the matrix with shape (..., M, K).
    x2
        Input array for the vector with shape (..., K).
    dtype
        The desired data-type for the array. If not given, then the type will
        be determined as the minimum type required to hold the objects in the
        sequence.

    Returns
    -------
    out : ndarray
        The matrix-vector product with shape (..., M).

    Raises
    ------
    ValueError
        If any input is a scalar or if the trailing dimension of x2 does not match
        the second-to-last dimension of x1.

    Notes
    -----
    This is similar to `matmul` where the second argument is a vector,
    but with different broadcasting rules. Broadcasting happens over all but
    the last dimension of x1 and all dimensions of x2 except the last.
    """
    x1 = as_tensor_variable(x1)
    x2 = as_tensor_variable(x2)

    if x1.type.ndim == 0 or x2.type.ndim == 0:
        raise ValueError("matvec operand cannot be scalar")

    if x1.type.ndim < 2:
        raise ValueError("First input to matvec must have at least 2 dimensions")

    # NOTE: no separate ndim check is needed for x2 — the scalar check above
    # already guarantees x2.type.ndim >= 1.

    # Blockwise matrix-vector product; shape mismatches along the contracted
    # dimension are reported by the underlying op.
    out = _matrix_vec_prod(x1, x2)

    if dtype is not None:
        out = out.astype(dtype)

    return out
4242+
4243+
4244+
def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
    """Compute the vector-matrix product.

    Parameters
    ----------
    x1
        Input array for the vector with shape (..., K).
    x2
        Input array for the matrix with shape (..., K, N).
    dtype
        The desired data-type for the array. If not given, then the type will
        be determined as the minimum type required to hold the objects in the
        sequence.

    Returns
    -------
    out : ndarray
        The vector-matrix product with shape (..., N).

    Raises
    ------
    ValueError
        If any input is a scalar or if the last dimension of x1 does not match
        the second-to-last dimension of x2.

    Notes
    -----
    This is similar to `matmul` where the first argument is a vector,
    but with different broadcasting rules. Broadcasting happens over all but
    the last dimension of x1 and all but the last two dimensions of x2.
    """
    x1 = as_tensor_variable(x1)
    x2 = as_tensor_variable(x2)

    if x1.type.ndim == 0 or x2.type.ndim == 0:
        raise ValueError("vecmat operand cannot be scalar")

    # NOTE: no separate ndim check is needed for x1 — the scalar check above
    # already guarantees x1.type.ndim >= 1.

    if x2.type.ndim < 2:
        raise ValueError("Second input to vecmat must have at least 2 dimensions")

    # Blockwise vector-matrix product; shape mismatches along the contracted
    # dimension are reported by the underlying op.
    out = _vec_matrix_prod(x1, x2)

    if dtype is not None:
        out = out.astype(dtype)

    return out
4293+
4294+
41254295
@_vectorize_node.register(Dot)
41264296
def vectorize_node_dot(op, node, batched_x, batched_y):
41274297
old_x, old_y = node.inputs
@@ -4218,6 +4388,9 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
42184388
"max_and_argmax",
42194389
"max",
42204390
"matmul",
4391+
"vecdot",
4392+
"matvec",
4393+
"vecmat",
42214394
"argmax",
42224395
"min",
42234396
"argmin",

tests/tensor/test_math.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
logaddexp,
9090
logsumexp,
9191
matmul,
92+
matvec,
9293
max,
9394
max_and_argmax,
9495
maximum,
@@ -123,6 +124,8 @@
123124
true_div,
124125
trunc,
125126
var,
127+
vecdot,
128+
vecmat,
126129
)
127130
from pytensor.tensor.math import sum as pt_sum
128131
from pytensor.tensor.type import (
@@ -2076,6 +2079,182 @@ def is_super_shape(var1, var2):
20762079
assert is_super_shape(y, g)
20772080

20782081

2082+
class TestMatrixVectorOps:
    # Validates vecdot / matvec / vecmat / matmul against the corresponding
    # NumPy computations, including batched inputs and scalar-rejection errors.

    def test_vecdot(self):
        """Test vecdot function with various input shapes and axis."""
        rng = np.random.default_rng(seed=utt.fetch_seed())

        # Test vector-vector: plain 1d dot product along the default last axis
        x = vector()
        y = vector()
        z = vecdot(x, y)
        f = function([x, y], z)
        x_val = random(5, rng=rng).astype(config.floatX)
        y_val = random(5, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))

        # Test with axis parameter: reduction over rows (axis=0) vs columns (axis=1)
        x = matrix()
        y = matrix()
        z0 = vecdot(x, y, axis=0)
        z1 = vecdot(x, y, axis=1)
        f0 = function([x, y], z0)
        f1 = function([x, y], z1)

        x_val = random(3, 4, rng=rng).astype(config.floatX)
        y_val = random(3, 4, rng=rng).astype(config.floatX)
        # Reference: elementwise product summed along the chosen axis
        np.testing.assert_allclose(f0(x_val, y_val), np.sum(x_val * y_val, axis=0))
        np.testing.assert_allclose(f1(x_val, y_val), np.sum(x_val * y_val, axis=1))

        # Test batched vectors: leading dimensions broadcast, axis=2 contracted
        x = tensor3()
        y = tensor3()
        z = vecdot(x, y, axis=2)
        f = function([x, y], z)

        x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
        y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.sum(x_val * y_val, axis=2))

        # Test error cases: scalar operands are rejected at graph-build time
        x = scalar()
        y = scalar()
        with pytest.raises(ValueError):
            vecdot(x, y)

    def test_matvec(self):
        """Test matvec function with various input shapes."""
        rng = np.random.default_rng(seed=utt.fetch_seed())

        # Test matrix-vector: (3, 4) @ (4,) -> (3,)
        x = matrix()
        y = vector()
        z = matvec(x, y)
        f = function([x, y], z)

        x_val = random(3, 4, rng=rng).astype(config.floatX)
        y_val = random(4, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))

        # Test batched: one matrix and one vector per leading batch index
        x = tensor3()
        y = matrix()
        z = matvec(x, y)
        f = function([x, y], z)

        x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
        y_val = random(2, 4, rng=rng).astype(config.floatX)
        # Reference built per-batch with plain np.dot
        expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
        np.testing.assert_allclose(f(x_val, y_val), expected)

        # Test error cases: first operand must be at least 2d; scalars rejected
        x = vector()
        y = vector()
        with pytest.raises(ValueError):
            matvec(x, y)

        x = scalar()
        y = vector()
        with pytest.raises(ValueError):
            matvec(x, y)

    def test_vecmat(self):
        """Test vecmat function with various input shapes."""
        rng = np.random.default_rng(seed=utt.fetch_seed())

        # Test vector-matrix: (3,) @ (3, 4) -> (4,)
        x = vector()
        y = matrix()
        z = vecmat(x, y)
        f = function([x, y], z)

        x_val = random(3, rng=rng).astype(config.floatX)
        y_val = random(3, 4, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))

        # Test batched: one vector and one matrix per leading batch index
        x = matrix()
        y = tensor3()
        z = vecmat(x, y)
        f = function([x, y], z)

        x_val = random(2, 3, rng=rng).astype(config.floatX)
        y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
        # Reference built per-batch with plain np.dot
        expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
        np.testing.assert_allclose(f(x_val, y_val), expected)

        # Test error cases: second operand must be at least 2d; scalars rejected
        x = matrix()
        y = vector()
        with pytest.raises(ValueError):
            vecmat(x, y)

        x = scalar()
        y = matrix()
        with pytest.raises(ValueError):
            vecmat(x, y)

    def test_matmul(self):
        """Test matmul function with various input shapes."""
        rng = np.random.default_rng(seed=utt.fetch_seed())

        # Test matrix-matrix: (3, 4) @ (4, 5) -> (3, 5)
        x = matrix()
        y = matrix()
        z = matmul(x, y)
        f = function([x, y], z)

        x_val = random(3, 4, rng=rng).astype(config.floatX)
        y_val = random(4, 5, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

        # Test vector-matrix: 1d first operand is treated as a row vector
        x = vector()
        y = matrix()
        z = matmul(x, y)
        f = function([x, y], z)

        x_val = random(3, rng=rng).astype(config.floatX)
        y_val = random(3, 4, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

        # Test matrix-vector: 1d second operand is treated as a column vector
        x = matrix()
        y = vector()
        z = matmul(x, y)
        f = function([x, y], z)

        x_val = random(3, 4, rng=rng).astype(config.floatX)
        y_val = random(4, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

        # Test vector-vector: degenerates to a scalar dot product
        x = vector()
        y = vector()
        z = matmul(x, y)
        f = function([x, y], z)

        x_val = random(3, rng=rng).astype(config.floatX)
        y_val = random(3, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

        # Test batched: leading batch dimension broadcast through matmul
        x = tensor3()
        y = tensor3()
        z = matmul(x, y)
        f = function([x, y], z)

        x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
        y_val = random(2, 4, 5, rng=rng).astype(config.floatX)
        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

        # Test error cases: scalar operands are rejected
        x = scalar()
        y = scalar()
        with pytest.raises(ValueError):
            matmul(x, y)
2256+
2257+
20792258
class TestTensordot:
20802259
def TensorDot(self, axes):
20812260
# Since tensordot is no longer an op, mimic the old op signature

0 commit comments

Comments
 (0)