
Commit 65967fe

Implement vectorize_node for Softmax and Argmax Ops
Also refactors shared logic for other batch-axed Ops
1 parent 08a9ba3 commit 65967fe
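The idea behind the axis remapping, shown with plain NumPy/SciPy rather than PyTensor (an editorial sketch, not part of the diff): applying an Op along a core axis of every batch entry is the same as applying it along that axis shifted by the number of leading batch dimensions, which is exactly what the new vectorize_node registrations do.

import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(0)
batch = rng.normal(size=(3, 5, 5))  # one batch dimension in front of a (5, 5) core

# softmax over core axis 0 of each entry ...
per_entry = np.stack([softmax(entry, axis=0) for entry in batch])
# ... equals softmax over axis 0 + batch_ndim = 1 of the stacked tensor
batched = softmax(batch, axis=1)

np.testing.assert_allclose(per_entry, batched)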

7 files changed: +169 −29 lines changed

pytensor/tensor/basic.py

Lines changed: 16 additions & 6 deletions
@@ -43,7 +43,12 @@
     get_vector_length,
 )
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
+from pytensor.tensor.elemwise import (
+    DimShuffle,
+    Elemwise,
+    get_normalized_batch_axes,
+    scalar_elemwise,
+)
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import (
     Shape,
@@ -3614,13 +3619,18 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
 
 
 @_vectorize_node.register(ExtractDiag)
-def vectorize_extract_diag(op: ExtractDiag, node, batched_x):
-    batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim
+def vectorize_extract_diag(op: ExtractDiag, node, batch_x):
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+    batch_axis1, batch_axis2 = get_normalized_batch_axes(
+        (op.axis1, op.axis2), core_ndim, batch_ndim
+    )
+
     return diagonal(
-        batched_x,
+        batch_x,
         offset=op.offset,
-        axis1=op.axis1 + batched_ndims,
-        axis2=op.axis2 + batched_ndims,
+        axis1=batch_axis1,
+        axis2=batch_axis2,
     ).owner
pytensor/tensor/elemwise.py

Lines changed: 35 additions & 15 deletions
@@ -1,6 +1,8 @@
 from copy import copy
+from typing import Union
 
 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple
 
 import pytensor.tensor.basic
 from pytensor.configdefaults import config
@@ -1399,7 +1401,7 @@ def make_node(self, input):
         # scalar inputs are treated as 1D regarding axis in this `Op`
         if axis is not None:
             try:
-                axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims))
+                axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims))
             except np.AxisError:
                 raise np.AxisError(axis, ndim=inp_dims)
 
@@ -1757,18 +1759,36 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
     return DimShuffle(input_broadcastable, new_order).make_node(x)
 
 
-@_vectorize_node.register(CAReduce)
-def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
-    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
-    if not batched_ndims:
-        return node.op.make_node(x)
-    axes = op.axis
-    # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
-    # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
-    if axes is None:
-        axes = list(range(node.inputs[0].type.ndim))
+def get_normalized_batch_axes(
+    core_axes: Union[None, int, tuple[int, ...]],
+    core_ndim: int,
+    batch_ndim: int,
+) -> tuple[int, ...]:
+    """Compute batch axes for a batched operation, from the core input ndim and axes.
+
+    e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
+    batch_axes(None, 2, 4) -> (2, 3)
+
+    e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
+    batch_axes(0, 2, 4) -> (2,)
+
+    e.g., sum(tensor3, axis=(0, -1)) -> sum(tensor4, axis=(1, 3))
+    batch_axes((0, -1), 3, 4) -> (1, 3)
+    """
+    if core_axes is None:
+        core_axes = tuple(range(core_ndim))
     else:
-        axes = list(axes)
-    new_axes = [axis + batched_ndims for axis in axes]
-    new_op = op.clone(axis=new_axes)
-    return new_op.make_node(x)
+        core_axes = normalize_axis_tuple(core_axes, core_ndim)
+    return tuple(core_axis + batch_ndim for core_axis in core_axes)
+
+
+@_vectorize_node.register(CAReduce)
+def vectorize_careduce(op: CAReduce, node: Apply, batch_x: TensorVariable) -> Apply:
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+
+    if not batch_ndim:
+        return node.op.make_node(batch_x)
+
+    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)
+    return op.clone(axis=batch_axes).make_node(batch_x)
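Note that the docstring examples above state the sum equivalences in terms of the batched tensor's total ndim, while the function itself takes batch_ndim, the number of leading batch dimensions. A hedged sketch of the corresponding calls, matching those examples:

from pytensor.tensor.elemwise import get_normalized_batch_axes

# sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3)); batch_ndim = 4 - 2 = 2
assert get_normalized_batch_axes(None, 2, 2) == (2, 3)

# sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
assert get_normalized_batch_axes(0, 2, 2) == (2,)

# sum(tensor3, axis=(0, -1)) -> sum(tensor4, axis=(1, 3)); negative axes are normalized first
assert get_normalized_batch_axes((0, -1), 3, 1) == (1, 3)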

pytensor/tensor/math.py

Lines changed: 21 additions & 2 deletions
@@ -27,7 +27,13 @@
     switch,
 )
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
-from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
+from pytensor.tensor.elemwise import (
+    CAReduce,
+    DimShuffle,
+    Elemwise,
+    get_normalized_batch_axes,
+    scalar_elemwise,
+)
 from pytensor.tensor.shape import shape, specify_broadcastable
 from pytensor.tensor.type import (
     DenseTensorType,
@@ -134,7 +140,7 @@ class MaxAndArgmax(COp):
     _f16_ok = True
 
     def __init__(self, axis):
-        assert isinstance(axis, list)
+        assert isinstance(axis, (tuple, list))
         self.axis = tuple(axis)
 
     def get_params(self, node):
@@ -465,6 +471,19 @@ def grad(self, inp, grads):
         return [x.zeros_like()]
 
 
+@_vectorize_node.register(Argmax)
+@_vectorize_node.register(MaxAndArgmax)
+def vectorize_argmax_node(op, node, batch_x):
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+
+    if not batch_ndim:
+        return node.op.make_node(batch_x)
+
+    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)
+    return type(op)(axis=batch_axes).make_node(batch_x)
+
+
 def makeKeepDims(x, y, axis):
     """
     Reintroduces in y with length one the axes of x which have been left out
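A minimal sketch of the new dispatch for MaxAndArgmax, mirroring the test added in tests/tensor/test_math.py below (the shapes and axes here are illustrative):

from pytensor.graph.replace import vectorize_node
from pytensor.tensor.math import MaxAndArgmax, max_and_argmax
from pytensor.tensor.type import tensor

x = tensor(shape=(5, 5, 5))
batch_x = tensor(shape=(2, 5, 5, 5))

max_x, argmax_x = max_and_argmax(x, axis=(0, -1))
node = max_x.owner
assert isinstance(node.op, MaxAndArgmax)

# One leading batch dimension: core axes (0, -1) on a 3d input become (1, 3)
new_node = vectorize_node(node, batch_x)
assert isinstance(new_node.op, MaxAndArgmax)
assert new_node.op.axis == (1, 3)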

pytensor/tensor/shape.py

Lines changed: 8 additions & 5 deletions
@@ -18,6 +18,7 @@
 from pytensor.tensor import _get_vector_length, as_tensor_variable
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import get_vector_length
+from pytensor.tensor.elemwise import get_normalized_batch_axes
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
 from pytensor.tensor.type_other import NoneConst
@@ -1103,8 +1104,10 @@ def unbroadcast(x, *axes):
 
 
 @_vectorize_node.register(Unbroadcast)
-def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
-    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
-    old_axes = op.axes
-    new_axes = (old_axis + batched_ndims for old_axis in old_axes)
-    return cast(Apply, unbroadcast(x, *new_axes).owner)
+def _vectorize_unbroadcast(
+    op: Unbroadcast, node: Apply, batch_x: TensorVariable
+) -> Apply:
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+    batch_axes = get_normalized_batch_axes(op.axes, core_ndim, batch_ndim)
+    return cast(Apply, unbroadcast(batch_x, *batch_axes).owner)

pytensor/tensor/special.py

Lines changed: 28 additions & 0 deletions
@@ -4,8 +4,10 @@
 import scipy
 
 from pytensor.graph.basic import Apply
+from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.tensor.basic import as_tensor_variable
+from pytensor.tensor.elemwise import get_normalized_batch_axes
 from pytensor.tensor.math import gamma, gammaln, neg, sum
 
 
@@ -736,6 +738,32 @@ def log_softmax(c, axis=None):
     return LogSoftmax(axis=axis)(c)
 
 
+@_vectorize_node.register(Softmax)
+@_vectorize_node.register(LogSoftmax)
+def vectorize_softmax_node(op, node, batched_x):
+    """
+    Vectorize Softmax and LogSoftmax nodes.
+
+    """
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batched_x.type.ndim - core_ndim
+
+    if not batch_ndim:
+        return op.make_node(batched_x)
+
+    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)
+
+    if len(batch_axes) > 1:
+        from pytensor.tensor.blockwise import vectorize_node_fallback
+
+        # The softmax Ops only allow a specific axis (integer) or all axes (None).
+        # If the vectorized operation requires more than one axis we have to default to a Blockwise
+        return vectorize_node_fallback(op, node, batched_x)
+
+    [batch_axis] = batch_axes
+    return type(op)(axis=batch_axis).make_node(batched_x)
+
+
 def poch(z, m):
     """
     Pochhammer symbol (rising factorial) function.
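And the same dispatch for Softmax/LogSoftmax, including the Blockwise fallback when more than one batch axis would be needed, sketched along the lines of the new test in tests/tensor/test_special.py:

from pytensor.graph.replace import vectorize_node
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.special import Softmax, softmax
from pytensor.tensor.type import tensor

x = tensor(shape=(5, 5))
batch_x = tensor(shape=(3, 5, 5))

# A single core axis stays a Softmax, shifted by the one batch dimension
node = softmax(x, axis=0).owner
new_node = vectorize_node(node, batch_x)
assert isinstance(new_node.op, Softmax)
assert new_node.op.axis == 1

# axis=None on a matrix needs two axes after batching, so it falls back to Blockwise
node = softmax(x, axis=None).owner
new_node = vectorize_node(node, batch_x)
assert isinstance(new_node.op, Blockwise)
assert isinstance(new_node.op.core_op, Softmax)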

tests/tensor/test_math.py

Lines changed: 30 additions & 0 deletions
@@ -20,6 +20,7 @@
 from pytensor.gradient import NullTypeGradError, grad, numeric_grad
 from pytensor.graph.basic import Variable, applys_between
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.replace import vectorize_node
 from pytensor.link.c.basic import DualLinker
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import pprint
@@ -1010,6 +1011,35 @@ def test_numpy_input(self):
         assert max_pt.eval() == 3
         assert argmax_pt.eval() == 2
 
+    @pytest.mark.parametrize(
+        "core_axis, batch_axis",
+        [
+            (None, (1, 2, 3, 4)),
+            (0, (1,)),
+            ((1, -1), (2, 4)),
+        ],
+    )
+    def test_vectorize(self, core_axis, batch_axis):
+        x = tensor(shape=(5, 5, 5, 5))
+        batch_x = tensor(shape=(3, 5, 5, 5, 5))
+
+        # Test MaxAndArgmax
+        max_x, argmax_x = max_and_argmax(x, axis=core_axis)
+        node = max_x.owner
+        assert isinstance(node.op, MaxAndArgmax)
+
+        new_node = vectorize_node(node, batch_x)
+        assert isinstance(new_node.op, MaxAndArgmax)
+        assert new_node.op.axis == batch_axis
+
+        # Test Argmax
+        # Argmax is not user-facing, so we have to create it manually
+        node = Argmax(axis=node.op.axis).make_node(x)
+
+        new_node = vectorize_node(node, batch_x)
+        assert isinstance(new_node.op, Argmax)
+        assert new_node.op.axis == batch_axis
+
 
 class TestArgminArgmax:
     def setup_method(self):

tests/tensor/test_special.py

Lines changed: 31 additions & 1 deletion
@@ -8,6 +8,8 @@
 
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
+from pytensor.graph.replace import vectorize_node
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.special import (
     LogSoftmax,
     Softmax,
@@ -19,7 +21,7 @@
     poch,
     softmax,
 )
-from pytensor.tensor.type import matrix, tensor3, tensor4, vector, vectors
+from pytensor.tensor.type import matrix, tensor, tensor3, tensor4, vector, vectors
 from tests import unittest_tools as utt
 from tests.tensor.utils import random_ranged
 
@@ -150,6 +152,34 @@ def test_valid_axis(self):
         SoftmaxGrad(-4)(*x)
 
 
+@pytest.mark.parametrize(
+    "core_axis, batch_axis",
+    [
+        (None, (1, 2, 3, 4)),
+        (0, (1,)),
+    ],
+)
+@pytest.mark.parametrize(
+    "op, constructor", [(Softmax, softmax), (LogSoftmax, log_softmax)]
+)
+def test_vectorize_softmax(op, constructor, core_axis, batch_axis):
+    x = tensor(shape=(5, 5, 5, 5))
+    batch_x = tensor(shape=(3, 5, 5, 5, 5))
+
+    node = constructor(x, axis=core_axis).owner
+    assert isinstance(node.op, op)
+
+    new_node = vectorize_node(node, batch_x)
+    if len(batch_axis) == 1:
+        assert isinstance(new_node.op, op)
+        assert (new_node.op.axis,) == batch_axis
+    else:
+        assert isinstance(new_node.op, Blockwise) and isinstance(
+            new_node.op.core_op, op
+        )
+        assert new_node.op.core_op.axis == core_axis
+
+
 def test_poch():
     _z, _m = vectors("z", "m")
     actual_fn = function([_z, _m], poch(_z, _m))
