Commit c52154d

Add rewrite for matmul when only one of the inputs has batched dimensions
This rewrite replaces a batched matmul with a core matmul by raveling the batched dimensions of the batched input, performing a 2D matmul, and then unraveling the batched dimensions of the output. This sidesteps the Blockwise Dot operation and allows specialization into BLAS routines. The idea was taken from these two discussions: numpy/numpy#7569 numpy/numpy#8957
1 parent 0fc2cd8 commit c52154d
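
The transformation the commit message describes can be checked directly in NumPy. The following sketch is illustrative only (it is not part of the commit) and covers the case where x is batched and y is not; all names are placeholders.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 2))   # batched input: (*batch, m, k)
y = rng.normal(size=(2, 4))      # core input:    (k, n)

# Batched matmul (what Blockwise(Dot) would compute)
batched = x @ y

# Ravel the batch dimensions, do a single 2D matmul, then unravel the result
core = (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])

np.testing.assert_allclose(batched, core)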

2 files changed: +102 −0 lines changed


pytensor/tensor/rewriting/math.py

Lines changed: 53 additions & 0 deletions
@@ -31,11 +31,13 @@
     constant,
     extract_constant,
     get_underlying_scalar_constant_value,
+    moveaxis,
     ones_like,
     register_infer_shape,
     switch,
     zeros_like,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
@@ -217,6 +219,57 @@ def local_lift_transpose_through_dot(fgraph, node):
         return ret
 
 
+@register_stabilize
+@register_specialize
+@node_rewriter(tracks=[Blockwise])
+def local_batched_matmul_to_core_matmul(fgraph, node):
+    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
+
+    For example, if x has batch dimensions but y does not:
+    x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
+
+    It also works when y has batch dimensions but x does not.
+    """
+
+    # Check whether we have a matmul operation in this node
+    if not (
+        isinstance(node.op.core_op, Dot)
+        and len(node.op.inputs_sig[0]) == 2
+        and len(node.op.inputs_sig[1]) == 2
+    ):
+        return None
+
+    x, y = node.inputs
+    batch_ndim = node.op.batch_ndim(node)
+
+    # Check if x has batch dimensions but y does not (or only broadcastable ones)
+    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
+        y.type.broadcastable[:-2]
+    ):
+        x_stacked = x.reshape((-1, x.shape[-1]))
+        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))
+        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))
+        return [out]
+
+    # Otherwise, check if y has batch dimensions but x does not
+    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
+        x.type.broadcastable[:-2]
+    ):
+        # For the y batch case we need to first move the batch axes and then reshape
+        # y.shape == (*b, k, n)
+        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
+        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, *b * n)
+        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, *b * n)
+        out_stacked_tr = out_stacked.reshape(
+            (x.shape[-2], *y.shape[:-2], y.shape[-1])
+        )  # (m, *b, n)
+        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
+        return [out]
+
+    # Both x and y have batch dimensions, nothing to do here
+    return None
+
+
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
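
The y-batched branch above is less obvious than the x-batched one, because the batch axes have to be moved out of the way before reshaping. The following NumPy sketch mirrors the same moveaxis/reshape sequence; it is an illustration only (not part of the diff), and the shapes are arbitrary.

import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(3, 2))          # core matrix: (m, k)
y = rng.normal(size=(5, 7, 2, 4))    # batched:     (*b, k, n)

y_tr = y.swapaxes(0, 0)                       # placeholder to keep names explicit
y_tr = np.moveaxis(y, -2, 0)                  # (k, *b, n)
y_stacked = y_tr.reshape(y.shape[-2], -1)     # (k, prod(b) * n)
out_stacked = x @ y_stacked                   # (m, prod(b) * n)
out_tr = out_stacked.reshape(x.shape[-2], *y.shape[:-2], y.shape[-1])  # (m, *b, n)
out = np.moveaxis(out_tr, 0, -2)              # (*b, m, n)

np.testing.assert_allclose(out, x @ y)        # matches the batched matmul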

tests/tensor/rewriting/test_math.py

Lines changed: 49 additions & 0 deletions
@@ -34,6 +34,7 @@
 from pytensor.tensor.basic import Alloc, constant, join, second, switch
 from pytensor.tensor.blas import Dot22, Gemv
 from pytensor.tensor.blas_c import CGemv
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
 from pytensor.tensor.math import abs as pt_abs
@@ -4427,3 +4428,51 @@ def test_polygamma_specialization():
     assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
     assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
     assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
+
+
+@pytest.mark.skipif(
+    config.mode == "FAST_COMPILE",
+    reason="Rewrite is only relevant in FAST_RUN",
+)
+def test_local_batched_matmul_to_core_matmul():
+    rng = np.random.default_rng(seed=4433)
+
+    # x is batched but not y
+    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
+    y = pt.tensor("y", shape=(2, 2), dtype="float64")
+    out = x @ y
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = pytensor.function([x, y], out)
+    assert not any(
+        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    )
+
+    x_test = rng.normal(size=(5, 3, 2))
+    y_test = rng.normal(size=(2, 2))
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+    # y is batched but not x
+    x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")
+    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
+    out = x @ y
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = pytensor.function([x, y], out)
+    assert not any(
+        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    )
+
+    x_test = rng.normal(size=(1, 3, 2))
+    y_test = rng.normal(size=(5, 2, 2))
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+    # Both x and y are batched, rewrite does not apply
+    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
+    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
+    out = x @ y
+
+    fn = pytensor.function([x, y], out)
+    x_test = rng.normal(size=(5, 3, 2))
+    y_test = rng.normal(size=(5, 2, 2))
+    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
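
Outside the test suite, one way to see the rewrite take effect is to compile a function in the default FAST_RUN mode and print the optimized graph: after this commit it should contain Reshape and core dot (or BLAS) nodes rather than a Blockwise Dot, which is what lets the BLAS specializations kick in. A hedged usage sketch, assuming pytensor.dprint is available as in current PyTensor releases:

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
y = pt.tensor("y", shape=(2, 2), dtype="float64")

fn = pytensor.function([x, y], x @ y)
pytensor.dprint(fn)  # expect Reshape + dot/BLAS nodes, no Blockwise(Dot)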
