Commit 3bd1bcf

Expose vecdot, vecmat and matvec helpers (#1250)
1 parent 110e128 commit 3bd1bcf

File tree: 2 files changed (+219, -0 lines)

pytensor/tensor/math.py (+151 lines)
@@ -4122,6 +4122,154 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     return out
 
 
+def vecdot(
+    x1: TensorLike,
+    x2: TensorLike,
+    dtype: Optional["DTypeLike"] = None,
+) -> TensorVariable:
+    """Compute the vector dot product of two arrays.
+
+    Parameters
+    ----------
+    x1, x2
+        Input arrays with the same shape.
+    dtype
+        The desired data-type for the result. If not given, then the type will
+        be determined as the minimum type required to hold the objects in the
+        sequence.
+
+    Returns
+    -------
+    TensorVariable
+        The vector dot product of the inputs.
+
+    Notes
+    -----
+    This is equivalent to `numpy.vecdot` and computes the dot product of
+    vectors along the last axis of both inputs. Broadcasting is supported
+    across all other dimensions.
+
+    Examples
+    --------
+    >>> import pytensor.tensor as pt
+    >>> # Vector dot product with shape (5,) inputs
+    >>> x = pt.vector("x", shape=(5,))  # shape (5,)
+    >>> y = pt.vector("y", shape=(5,))  # shape (5,)
+    >>> z = pt.vecdot(x, y)  # scalar output
+    >>> # Equivalent to numpy.vecdot(x, y)
+    >>>
+    >>> # With batched inputs of shape (3, 5)
+    >>> x_batch = pt.matrix("x", shape=(3, 5))  # shape (3, 5)
+    >>> y_batch = pt.matrix("y", shape=(3, 5))  # shape (3, 5)
+    >>> z_batch = pt.vecdot(x_batch, y_batch)  # shape (3,)
+    >>> # Equivalent to numpy.vecdot(x_batch, y_batch)
+    """
+    out = _inner_prod(x1, x2)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
+
+
+def matvec(
+    x1: TensorLike, x2: TensorLike, dtype: Optional["DTypeLike"] = None
+) -> TensorVariable:
+    """Compute the matrix-vector product.
+
+    Parameters
+    ----------
+    x1
+        Input array for the matrix with shape (..., M, K).
+    x2
+        Input array for the vector with shape (..., K).
+    dtype
+        The desired data-type for the result. If not given, then the type will
+        be determined as the minimum type required to hold the objects in the
+        sequence.
+
+    Returns
+    -------
+    TensorVariable
+        The matrix-vector product with shape (..., M).
+
+    Notes
+    -----
+    This is equivalent to `numpy.matvec` and computes the matrix-vector product
+    with broadcasting over batch dimensions.
+
+    Examples
+    --------
+    >>> import pytensor.tensor as pt
+    >>> # Matrix-vector product
+    >>> A = pt.matrix("A", shape=(3, 4))  # shape (3, 4)
+    >>> v = pt.vector("v", shape=(4,))  # shape (4,)
+    >>> result = pt.matvec(A, v)  # shape (3,)
+    >>> # Equivalent to numpy.matvec(A, v)
+    >>>
+    >>> # Batched matrix-vector product
+    >>> batched_A = pt.tensor3("A", shape=(2, 3, 4))  # shape (2, 3, 4)
+    >>> batched_v = pt.matrix("v", shape=(2, 4))  # shape (2, 4)
+    >>> result = pt.matvec(batched_A, batched_v)  # shape (2, 3)
+    >>> # Equivalent to numpy.matvec(batched_A, batched_v)
+    """
+    out = _matrix_vec_prod(x1, x2)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
+
+
+def vecmat(
+    x1: TensorLike, x2: TensorLike, dtype: Optional["DTypeLike"] = None
+) -> TensorVariable:
+    """Compute the vector-matrix product.
+
+    Parameters
+    ----------
+    x1
+        Input array for the vector with shape (..., K).
+    x2
+        Input array for the matrix with shape (..., K, N).
+    dtype
+        The desired data-type for the result. If not given, then the type will
+        be determined as the minimum type required to hold the objects in the
+        sequence.
+
+    Returns
+    -------
+    TensorVariable
+        The vector-matrix product with shape (..., N).
+
+    Notes
+    -----
+    This is equivalent to `numpy.vecmat` and computes the vector-matrix product
+    with broadcasting over batch dimensions.
+
+    Examples
+    --------
+    >>> import pytensor.tensor as pt
+    >>> # Vector-matrix product
+    >>> v = pt.vector("v", shape=(3,))  # shape (3,)
+    >>> A = pt.matrix("A", shape=(3, 4))  # shape (3, 4)
+    >>> result = pt.vecmat(v, A)  # shape (4,)
+    >>> # Equivalent to numpy.vecmat(v, A)
+    >>>
+    >>> # Batched vector-matrix product
+    >>> batched_v = pt.matrix("v", shape=(2, 3))  # shape (2, 3)
+    >>> batched_A = pt.tensor3("A", shape=(2, 3, 4))  # shape (2, 3, 4)
+    >>> result = pt.vecmat(batched_v, batched_A)  # shape (2, 4)
+    >>> # Equivalent to numpy.vecmat(batched_v, batched_A)
+    """
+    out = _vec_matrix_prod(x1, x2)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
+
+
 @_vectorize_node.register(Dot)
 def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
@@ -4218,6 +4366,9 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
     "max_and_argmax",
     "max",
     "matmul",
+    "vecdot",
+    "matvec",
+    "vecmat",
     "argmax",
     "min",
     "argmin",

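The sketch below is not part of the commit; it only illustrates how the newly exposed helpers can be compiled and evaluated, checking them against the equivalent einsum contractions in NumPy. All variable names and shapes here are illustrative assumptions.

# Usage sketch (assumed names/shapes): compile the new helpers and compare
# them against batched einsum contractions in NumPy.
import numpy as np

import pytensor
import pytensor.tensor as pt

A = pt.tensor3("A", shape=(2, 3, 4))  # batched matrix, shape (batch, M, K)
v = pt.matrix("v", shape=(2, 4))  # batched vector, shape (batch, K)
w = pt.matrix("w", shape=(2, 3))  # batched vector, shape (batch, M)

f = pytensor.function(
    [A, v, w],
    [pt.matvec(A, v), pt.vecmat(w, A), pt.vecdot(v, v)],
)

rng = np.random.default_rng(0)
A_val = rng.normal(size=(2, 3, 4)).astype(pytensor.config.floatX)
v_val = rng.normal(size=(2, 4)).astype(pytensor.config.floatX)
w_val = rng.normal(size=(2, 3)).astype(pytensor.config.floatX)

matvec_res, vecmat_res, vecdot_res = f(A_val, v_val, w_val)
np.testing.assert_allclose(matvec_res, np.einsum("bmk,bk->bm", A_val, v_val), rtol=1e-5)
np.testing.assert_allclose(vecmat_res, np.einsum("bm,bmk->bk", w_val, A_val), rtol=1e-5)
np.testing.assert_allclose(vecdot_res, np.einsum("bk,bk->b", v_val, v_val), rtol=1e-5)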
tests/tensor/test_math.py (+68 lines)
@@ -89,6 +89,7 @@
     logaddexp,
     logsumexp,
     matmul,
+    matvec,
     max,
     max_and_argmax,
     maximum,
@@ -123,6 +124,8 @@
     true_div,
     trunc,
     var,
+    vecdot,
+    vecmat,
 )
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.type import (
@@ -2076,6 +2079,71 @@ def is_super_shape(var1, var2):
     assert is_super_shape(y, g)
 
 
+def test_matrix_vector_ops():
+    """Test vecdot, matvec, and vecmat helper functions."""
+    rng = np.random.default_rng(seed=utt.fetch_seed())
+
+    # Create test data with batch dimension (2)
+    batch_size = 2
+    dim_k = 4  # Common dimension
+    dim_m = 3  # Matrix rows
+    dim_n = 5  # Matrix columns
+
+    # Create input tensors with appropriate shapes
+    # For matvec: x1(b,m,k) @ x2(b,k) -> out(b,m)
+    # For vecmat: x1(b,k) @ x2(b,k,n) -> out(b,n)
+
+    # Create test values using config.floatX to match PyTensor's default dtype
+    mat_mk_val = random(batch_size, dim_m, dim_k, rng=rng).astype(config.floatX)
+    mat_kn_val = random(batch_size, dim_k, dim_n, rng=rng).astype(config.floatX)
+    vec_k_val = random(batch_size, dim_k, rng=rng).astype(config.floatX)
+
+    # Create tensor variables with matching dtype
+    mat_mk = tensor(
+        name="mat_mk", shape=(batch_size, dim_m, dim_k), dtype=config.floatX
+    )
+    mat_kn = tensor(
+        name="mat_kn", shape=(batch_size, dim_k, dim_n), dtype=config.floatX
+    )
+    vec_k = tensor(name="vec_k", shape=(batch_size, dim_k), dtype=config.floatX)
+
+    # Test 1: vecdot with matching dimensions
+    vecdot_out = vecdot(vec_k, vec_k, dtype="int32")
+    vecdot_fn = function([vec_k], vecdot_out)
+    result = vecdot_fn(vec_k_val)
+
+    # Check dtype
+    assert result.dtype == np.int32
+
+    # Calculate expected manually
+    expected_vecdot = np.zeros((batch_size,), dtype=np.int32)
+    for i in range(batch_size):
+        expected_vecdot[i] = np.sum(vec_k_val[i] * vec_k_val[i])
+    np.testing.assert_allclose(result, expected_vecdot)
+
+    # Test 2: matvec - matrix-vector product
+    matvec_out = matvec(mat_mk, vec_k)
+    matvec_fn = function([mat_mk, vec_k], matvec_out)
+    result_matvec = matvec_fn(mat_mk_val, vec_k_val)
+
+    # Calculate expected manually
+    expected_matvec = np.zeros((batch_size, dim_m), dtype=config.floatX)
+    for i in range(batch_size):
+        expected_matvec[i] = np.dot(mat_mk_val[i], vec_k_val[i])
+    np.testing.assert_allclose(result_matvec, expected_matvec)
+
+    # Test 3: vecmat - vector-matrix product
+    vecmat_out = vecmat(vec_k, mat_kn)
+    vecmat_fn = function([vec_k, mat_kn], vecmat_out)
+    result_vecmat = vecmat_fn(vec_k_val, mat_kn_val)
+
+    # Calculate expected manually
+    expected_vecmat = np.zeros((batch_size, dim_n), dtype=config.floatX)
+    for i in range(batch_size):
+        expected_vecmat[i] = np.dot(vec_k_val[i], mat_kn_val[i])
+    np.testing.assert_allclose(result_vecmat, expected_vecmat)
+
+
 class TestTensordot:
     def TensorDot(self, axes):
         # Since tensordot is no longer an op, mimic the old op signature
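As a side note (not part of the commit), the hand-rolled loops that build the expected values in the test above are equivalent to batched einsum contractions (the test's vecdot expectation is additionally cast to int32); a minimal sketch with assumed shapes:

# Sketch (assumed shapes): einsum equivalents of the loop-based expectations.
import numpy as np

rng = np.random.default_rng(42)
vec_k_val = rng.normal(size=(2, 4))  # (batch, K)
mat_mk_val = rng.normal(size=(2, 3, 4))  # (batch, M, K)
mat_kn_val = rng.normal(size=(2, 4, 5))  # (batch, K, N)

expected_vecdot = np.einsum("bk,bk->b", vec_k_val, vec_k_val)  # shape (batch,)
expected_matvec = np.einsum("bmk,bk->bm", mat_mk_val, vec_k_val)  # shape (batch, M)
expected_vecmat = np.einsum("bk,bkn->bn", vec_k_val, mat_kn_val)  # shape (batch, N)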

0 commit comments
