Vectorize softmax and argmax nodes #571

Merged

22 changes: 16 additions & 6 deletions pytensor/tensor/basic.py
@@ -43,7 +43,12 @@
     get_vector_length,
 )
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
+from pytensor.tensor.elemwise import (
+    DimShuffle,
+    Elemwise,
+    get_normalized_batch_axes,
+    scalar_elemwise,
+)
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import (
     Shape,
@@ -3614,13 +3619,18 @@ def diagonal(a, offset=0, axis1=0, axis2=1):


 @_vectorize_node.register(ExtractDiag)
-def vectorize_extract_diag(op: ExtractDiag, node, batched_x):
-    batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim
+def vectorize_extract_diag(op: ExtractDiag, node, batch_x):
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+    batch_axis1, batch_axis2 = get_normalized_batch_axes(
+        (op.axis1, op.axis2), core_ndim, batch_ndim
+    )
+
     return diagonal(
-        batched_x,
+        batch_x,
         offset=op.offset,
-        axis1=op.axis1 + batched_ndims,
-        axis2=op.axis2 + batched_ndims,
+        axis1=batch_axis1,
+        axis2=batch_axis2,
     ).owner

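A minimal sketch of what the updated ExtractDiag vectorization does, assuming a PyTensor build that includes this change (the variable names are illustrative, not from the diff):

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_node
    from pytensor.tensor.basic import ExtractDiag

    x = pt.tensor(shape=(5, 5))           # core matrix input
    batch_x = pt.tensor(shape=(3, 5, 5))  # same input with one leading batch dimension

    node = pt.diagonal(x).owner
    assert isinstance(node.op, ExtractDiag)

    # axis1/axis2 are shifted past the batch dimension, so the diagonal is still
    # taken over the original core axes and the batch axis is left untouched.
    new_node = vectorize_node(node, batch_x)
    assert new_node.outputs[0].type.ndim == 2
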
50 changes: 35 additions & 15 deletions pytensor/tensor/elemwise.py
@@ -1,6 +1,8 @@
 from copy import copy
+from typing import Union

 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple

 import pytensor.tensor.basic
 from pytensor.configdefaults import config
@@ -1399,7 +1401,7 @@ def make_node(self, input):
         # scalar inputs are treated as 1D regarding axis in this `Op`
         if axis is not None:
             try:
-                axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims))
+                axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims))
             except np.AxisError:
                 raise np.AxisError(axis, ndim=inp_dims)

@@ -1757,18 +1759,36 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
     return DimShuffle(input_broadcastable, new_order).make_node(x)


-@_vectorize_node.register(CAReduce)
-def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
-    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
-    if not batched_ndims:
-        return node.op.make_node(x)
-    axes = op.axis
-    # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
-    # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
-    if axes is None:
-        axes = list(range(node.inputs[0].type.ndim))
+def get_normalized_batch_axes(
+    core_axes: Union[None, int, tuple[int, ...]],
+    core_ndim: int,
+    batch_ndim: int,
+) -> tuple[int, ...]:
+    """Compute batch axes for a batched operation, from the core input ndim and axes.
+
+    e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
+    batch_axes(None, 2, 2) -> (2, 3)
+
+    e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
+    batch_axes(0, 2, 2) -> (2,)
+
+    e.g., sum(tensor3, axis=(0, -1)) -> sum(tensor4, axis=(1, 3))
+    batch_axes((0, -1), 3, 1) -> (1, 3)
+    """
+    if core_axes is None:
+        core_axes = tuple(range(core_ndim))
     else:
-        axes = list(axes)
-    new_axes = [axis + batched_ndims for axis in axes]
-    new_op = op.clone(axis=new_axes)
-    return new_op.make_node(x)
+        core_axes = normalize_axis_tuple(core_axes, core_ndim)
+    return tuple(core_axis + batch_ndim for core_axis in core_axes)
+
+
+@_vectorize_node.register(CAReduce)
+def vectorize_careduce(op: CAReduce, node: Apply, batch_x: TensorVariable) -> Apply:
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+
+    if not batch_ndim:
+        return node.op.make_node(batch_x)
+
+    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)
+    return op.clone(axis=batch_axes).make_node(batch_x)
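
A quick illustration of the helper added above (not part of the diff); it assumes get_normalized_batch_axes is importable from pytensor.tensor.elemwise, as introduced in this PR:

    from pytensor.tensor.elemwise import get_normalized_batch_axes

    # sum(matrix, axis=None) batched with two leading dims -> sum over the two trailing core axes
    assert get_normalized_batch_axes(None, core_ndim=2, batch_ndim=2) == (2, 3)

    # a single integer axis is shifted past the batch dimensions
    assert get_normalized_batch_axes(0, core_ndim=2, batch_ndim=2) == (2,)

    # negative axes are normalized against the core ndim before being shifted
    assert get_normalized_batch_axes((0, -1), core_ndim=3, batch_ndim=1) == (1, 3)
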
35 changes: 25 additions & 10 deletions pytensor/tensor/math.py
@@ -27,7 +27,13 @@
     switch,
 )
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
-from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
+from pytensor.tensor.elemwise import (
+    CAReduce,
+    DimShuffle,
+    Elemwise,
+    get_normalized_batch_axes,
+    scalar_elemwise,
+)
 from pytensor.tensor.shape import shape, specify_broadcastable
 from pytensor.tensor.type import (
     DenseTensorType,
@@ -134,7 +140,7 @@ class MaxAndArgmax(COp):
     _f16_ok = True

     def __init__(self, axis):
-        assert isinstance(axis, list)
+        assert isinstance(axis, (tuple, list))
         self.axis = tuple(axis)

     def get_params(self, node):
@@ -465,6 +471,19 @@ def grad(self, inp, grads):
         return [x.zeros_like()]


+@_vectorize_node.register(Argmax)
+@_vectorize_node.register(MaxAndArgmax)
+def vectorize_argmax_node(op, node, batch_x):
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+
+    if not batch_ndim:
+        return node.op.make_node(batch_x)
+
+    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)
+    return type(op)(axis=batch_axes).make_node(batch_x)
+
+
 def makeKeepDims(x, y, axis):
     """
     Reintroduces in y with length one the axes of x which have been left out
@@ -671,13 +690,7 @@ def max(x, axis=None, keepdims=False):
     # thing is supporting all user interface features, not speed.
     # Some cases can be implemented only with CAReduce.

-    # We thus prefer to use MaxAndArgmax, if possible. It does not
-    # support all axis arguments, so we may need to fall back to CAReduce.
-
-    try:
-        out = max_and_argmax(x, axis)[0]
-    except Exception:
-        out = Max(axis)(x)
+    out = max_and_argmax(x, axis)[0]

     if keepdims:
         out = makeKeepDims(x, out, axis)
@@ -2948,9 +2961,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None


 @_vectorize_node.register(Dot)
-def vectorize_node_to_matmul(op, node, batched_x, batched_y):
+def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
     if old_x.type.ndim == 2 and old_y.type.ndim == 2:
+        # If original input is equivalent to a matrix-matrix product,
+        # return specialized Matmul Op to avoid unnecessary new Ops.
         return matmul(batched_x, batched_y).owner
     else:
         return vectorize_node_fallback(op, node, batched_x, batched_y)
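A hedged sketch of the new MaxAndArgmax vectorization (not part of the diff; it mirrors the test added below and assumes a PyTensor build with this change):

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_node
    from pytensor.tensor.math import MaxAndArgmax, max_and_argmax

    x = pt.tensor(shape=(5, 5))
    batch_x = pt.tensor(shape=(3, 5, 5))

    max_x, argmax_x = max_and_argmax(x, axis=0)
    core_node = max_x.owner
    assert isinstance(core_node.op, MaxAndArgmax)

    # The core reduction axis 0 is shifted past the single batch dimension.
    batch_node = vectorize_node(core_node, batch_x)
    assert isinstance(batch_node.op, MaxAndArgmax)
    assert batch_node.op.axis == (1,)
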
13 changes: 8 additions & 5 deletions pytensor/tensor/shape.py
@@ -18,6 +18,7 @@
 from pytensor.tensor import _get_vector_length, as_tensor_variable
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import get_vector_length
+from pytensor.tensor.elemwise import get_normalized_batch_axes
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
 from pytensor.tensor.type_other import NoneConst
@@ -1103,8 +1104,10 @@ def unbroadcast(x, *axes):


 @_vectorize_node.register(Unbroadcast)
-def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
-    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
-    old_axes = op.axes
-    new_axes = (old_axis + batched_ndims for old_axis in old_axes)
-    return cast(Apply, unbroadcast(x, *new_axes).owner)
+def _vectorize_unbroadcast(
+    op: Unbroadcast, node: Apply, batch_x: TensorVariable
+) -> Apply:
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+    batch_axes = get_normalized_batch_axes(op.axes, core_ndim, batch_ndim)
+    return cast(Apply, unbroadcast(batch_x, *batch_axes).owner)
28 changes: 28 additions & 0 deletions pytensor/tensor/special.py
@@ -4,8 +4,10 @@
 import scipy

 from pytensor.graph.basic import Apply
+from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.tensor.basic import as_tensor_variable
+from pytensor.tensor.elemwise import get_normalized_batch_axes
 from pytensor.tensor.math import gamma, gammaln, neg, sum


Expand Down Expand Up @@ -736,6 +738,32 @@ def log_softmax(c, axis=None):
return LogSoftmax(axis=axis)(c)


@_vectorize_node.register(Softmax)
@_vectorize_node.register(LogSoftmax)
def vectorize_softmax_node(op, node, batched_x):
"""
Vectorize Softmax and LogSoftmax nodes.

"""
core_ndim = node.inputs[0].type.ndim
batch_ndim = batched_x.type.ndim - core_ndim

if not batch_ndim:
return op.make_node(batched_x)

batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)

if len(batch_axes) > 1:
from pytensor.tensor.blockwise import vectorize_node_fallback

# The softmax Ops only allow a specific axis (integer) or all axis (None).
# If the vectorized operation requires more than one axis we have to default to a Blockwise
return vectorize_node_fallback(op, node, batched_x)

[batch_axis] = batch_axes
return type(op)(axis=batch_axis).make_node(batched_x)


def poch(z, m):
"""
Pochhammer symbol (rising factorial) function.
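A hedged sketch of the resulting behavior (not part of the diff; it assumes a PyTensor build with this change): a Softmax with a single integer axis stays a Softmax after vectorization, while axis=None over a batched input needs several axes and falls back to Blockwise.

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_node
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.special import Softmax, softmax

    x = pt.tensor(shape=(5, 5))
    batch_x = pt.tensor(shape=(3, 5, 5))

    # axis=0 can simply be remapped to axis=1, so the core Op is reused
    node = softmax(x, axis=0).owner
    assert isinstance(vectorize_node(node, batch_x).op, Softmax)

    # axis=None would require reducing over two axes, which Softmax does not support
    node = softmax(x, axis=None).owner
    assert isinstance(vectorize_node(node, batch_x).op, Blockwise)
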
30 changes: 30 additions & 0 deletions tests/tensor/test_math.py
@@ -20,6 +20,7 @@
 from pytensor.gradient import NullTypeGradError, grad, numeric_grad
 from pytensor.graph.basic import Variable, applys_between
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.replace import vectorize_node
 from pytensor.link.c.basic import DualLinker
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import pprint
@@ -1010,6 +1011,35 @@ def test_numpy_input(self):
         assert max_pt.eval() == 3
         assert argmax_pt.eval() == 2

+    @pytest.mark.parametrize(
+        "core_axis, batch_axis",
+        [
+            (None, (1, 2, 3, 4)),
+            (0, (1,)),
+            ((1, -1), (2, 4)),
+        ],
+    )
+    def test_vectorize(self, core_axis, batch_axis):
+        x = tensor(shape=(5, 5, 5, 5))
+        batch_x = tensor(shape=(3, 5, 5, 5, 5))
+
+        # Test MaxAndArgmax
+        max_x, argmax_x = max_and_argmax(x, axis=core_axis)
+        node = max_x.owner
+        assert isinstance(node.op, MaxAndArgmax)
+
+        new_node = vectorize_node(node, batch_x)
+        assert isinstance(new_node.op, MaxAndArgmax)
+        assert new_node.op.axis == batch_axis
+
+        # Test Argmax
+        # Argmax is not user-facing, so we have to create it manually
+        node = Argmax(axis=node.op.axis).make_node(x)
+
+        new_node = vectorize_node(node, batch_x)
+        assert isinstance(new_node.op, Argmax)
+        assert new_node.op.axis == batch_axis
+

 class TestArgminArgmax:
     def setup_method(self):
32 changes: 31 additions & 1 deletion tests/tensor/test_special.py
@@ -8,6 +8,8 @@

 from pytensor.compile.function import function
 from pytensor.configdefaults import config
+from pytensor.graph.replace import vectorize_node
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.special import (
     LogSoftmax,
     Softmax,
@@ -19,7 +21,7 @@
     poch,
     softmax,
 )
-from pytensor.tensor.type import matrix, tensor3, tensor4, vector, vectors
+from pytensor.tensor.type import matrix, tensor, tensor3, tensor4, vector, vectors
 from tests import unittest_tools as utt
 from tests.tensor.utils import random_ranged

@@ -150,6 +152,34 @@ def test_valid_axis(self):
             SoftmaxGrad(-4)(*x)


+@pytest.mark.parametrize(
+    "core_axis, batch_axis",
+    [
+        (None, (1, 2, 3, 4)),
+        (0, (1,)),
+    ],
+)
+@pytest.mark.parametrize(
+    "op, constructor", [(Softmax, softmax), (LogSoftmax, log_softmax)]
+)
+def test_vectorize_softmax(op, constructor, core_axis, batch_axis):
+    x = tensor(shape=(5, 5, 5, 5))
+    batch_x = tensor(shape=(3, 5, 5, 5, 5))
+
+    node = constructor(x, axis=core_axis).owner
+    assert isinstance(node.op, op)
+
+    new_node = vectorize_node(node, batch_x)
+    if len(batch_axis) == 1:
+        assert isinstance(new_node.op, op)
+        assert (new_node.op.axis,) == batch_axis
+    else:
+        assert isinstance(new_node.op, Blockwise) and isinstance(
+            new_node.op.core_op, op
+        )
+        assert new_node.op.core_op.axis == core_axis
+
+
def test_poch():
_z, _m = vectors("z", "m")
actual_fn = function([_z, _m], poch(_z, _m))