
Add OpFromGraph wrapper around alloc_diag #915

Merged
merged 33 commits on Jul 18, 2024
Changes from 6 commits
Commits
33 commits
6857bea
Add `OpFromGraph` wrapper around `alloc_diag`
jessegrabowski Jul 10, 2024
5604d9a
Refactor `rewrite_det_diag_to_prod_diag` to use `AllocDiag2`
jessegrabowski Jul 10, 2024
b0abe17
Save arguments passed to `alloc_diag` as properties in `AllocDiag2`
jessegrabowski Jul 10, 2024
dbfe92c
Fix bug in `rewrite_det_diag_to_prod_diag` where batch case was incor…
jessegrabowski Jul 10, 2024
afe2a65
Remove `AllocDiag2` from graphs in the `specialization` phase, after …
jessegrabowski Jul 10, 2024
e810df0
Remove debugging code, formatting
jessegrabowski Jul 10, 2024
a5355b6
Remove deprecated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag`
jessegrabowski Jul 12, 2024
6e37d26
Correctly register `eagerly_inline_alloc_diag` as a JAX-only rewrite
jessegrabowski Jul 12, 2024
f6f27ec
Use `self` (not `type(self)`) in `OpFromGraph.make_node`
jessegrabowski Jul 12, 2024
9486c64
Use base class `OpFromGraph` when constructing `OpFromGraph` gradients
jessegrabowski Jul 12, 2024
abcde3f
Revert "Use `self` (not `type(self)`) in `OpFromGraph.make_node`"
jessegrabowski Jul 12, 2024
66013fc
Solve XY problem
jessegrabowski Jul 12, 2024
e7583d1
Appease mypy
jessegrabowski Jul 12, 2024
70b9cd6
Remove `inline` prop from wrapper class and set `inline=True`
jessegrabowski Jul 12, 2024
9605a0e
Set `inline = False`
jessegrabowski Jul 12, 2024
46fbc55
Add rewrite to inline all `OpFromGraph` `Op`s
jessegrabowski Jul 16, 2024
98cf641
Allow symbolic `offset`
jessegrabowski Jul 16, 2024
c76b54e
Exclude inline rewrite in JAX mode
jessegrabowski Jul 16, 2024
7b44e22
refactor `late_inline_ofg` rewrite to actually perform the correct re…
jessegrabowski Jul 16, 2024
1dfc3fc
Narrow scope of `late_inline` rewrite
jessegrabowski Jul 16, 2024
9f32661
Fix tests
jessegrabowski Jul 17, 2024
f30a63f
Remove `is_inline` prop
jessegrabowski Jul 17, 2024
bf705f9
Add JAX `OpFromGraph` test
jessegrabowski Jul 17, 2024
c8958a4
Don't omit `inline_ofg` rewrites in JAX mode
jessegrabowski Jul 17, 2024
665c766
Don't inline `KroneckerProduct`
jessegrabowski Jul 17, 2024
6487e61
Skip inline rewrite tests when `mode == FAST_COMPILE`
jessegrabowski Jul 17, 2024
43474fd
Incorporate review feedback
jessegrabowski Jul 17, 2024
6ef4084
Incorporate review feedback
jessegrabowski Jul 17, 2024
04ddb46
Add `is_zero_offset` helper to `Eye`
jessegrabowski Jul 17, 2024
c038109
Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `D…
jessegrabowski Jul 18, 2024
19f2895
Seed `test_local_lift_through_linalg` test
jessegrabowski Jul 18, 2024
fcbccde
Fix failing diag_rewrite test
jessegrabowski Jul 18, 2024
56a3ffe
Revert symbolic offset
jessegrabowski Jul 18, 2024
30 changes: 30 additions & 0 deletions pytensor/link/jax/dispatch/tensor_basic.py
@@ -3,11 +3,14 @@
import jax.numpy as jnp
import numpy as np

import pytensor
from pytensor.graph import node_rewriter
from pytensor.graph.basic import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import (
Alloc,
AllocDiag2,
AllocEmpty,
ARange,
ExtractDiag,
@@ -21,6 +24,7 @@
get_underlying_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.shape import Shape_i


@@ -205,3 +209,29 @@ def tri(*args):
return jnp.tri(*args, dtype=op.dtype)

return tri


@register_specialize
@node_rewriter([AllocDiag2])
def eagerly_inline_alloc_diag(fgraph, node):
"""
Inline `AllocDiag2` OpFromGraph into the graph so the component Ops can themselves be jaxified

Parameters
----------
fgraph: FunctionGraph
The function graph being rewritten
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable
    Single-element list containing the inlined `alloc_diag` graph output, which replaces the `AllocDiag2` node
"""
[input] = node.inputs
[output] = node.op.inner_outputs
inner_input = output.owner.inputs[1]

inline = pytensor.clone_replace(output, {inner_input: input})

return [inline]
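
For context, a minimal sketch of what this rewrite enables (hedged: it assumes a working JAX install and that `pt.diag` on a vector routes through `alloc_diag`, as in the change to `pytensor/tensor/basic.py` below):

import numpy as np
import pytensor
import pytensor.tensor as pt

v = pt.vector("v")
out = pt.diag(v)  # builds an AllocDiag2 (OpFromGraph) node around the alloc_diag graph

# During the specialization phase, the rewrite above swaps the AllocDiag2 node for its
# inner graph, so only primitive Ops reach the JAX dispatcher.
f = pytensor.function([v], out, mode="JAX")
print(f(np.arange(3, dtype=v.dtype)))  # 3x3 array with 0, 1, 2 on the diagonal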
22 changes: 21 additions & 1 deletion pytensor/tensor/basic.py
@@ -21,6 +21,7 @@
import pytensor.scalar.sharedvar
from pytensor import compile, config, printing
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
@@ -3831,6 +3832,23 @@ def __setstate__(self, state):
self.axis2 = 1


class AllocDiag2(OpFromGraph):
"""
Wrapper Op for alloc_diag graphs
"""

__props__ = ("offset", "axis1", "axis2", "inline")

def __init__(self, *args, offset, axis1, axis2, **kwargs):
inline = kwargs.pop("inline", False)
self.offset = offset
self.axis1 = axis1
self.axis2 = axis2
self.inline = inline

super().__init__(*args, **kwargs, strict=True, inline=inline)


def alloc_diag(diag, offset=0, axis1=0, axis2=1):
"""Insert a vector on the diagonal of a zero-ed matrix.

Expand Down Expand Up @@ -3865,7 +3883,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)

return result
return AllocDiag2(
inputs=[diag], outputs=[result], offset=offset, axis1=axis1, axis2=axis2
)(diag)


def diag(v, k=0):
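
A short, hedged sketch of what the new wrapper exposes to downstream rewrites (it assumes `pt.diag` on a vector routes through `alloc_diag`; the variable names are illustrative):

import pytensor.tensor as pt
from pytensor.tensor.basic import AllocDiag2

v = pt.vector("v")
m = pt.diag(v)

# The arguments passed to alloc_diag are saved as props on the Op, so rewrites can
# read them back instead of pattern-matching the inner graph.
assert isinstance(m.owner.op, AllocDiag2)
assert m.owner.op.offset == 0
assert (m.owner.op.axis1, m.owner.op.axis2) == (0, 1)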
113 changes: 69 additions & 44 deletions pytensor/tensor/rewriting/linalg.py
@@ -5,12 +5,16 @@
from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
PatternNodeRewriter,
copy_stack_trace,
node_rewriter,
)
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
from pytensor.tensor.basic import (
AllocDiag2,
Eye,
TensorVariable,
diagonal,
)
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -41,7 +45,6 @@
solve,
solve_triangular,
)
from pytensor.tensor.subtensor import advanced_set_subtensor


logger = logging.getLogger(__name__)
@@ -401,30 +404,59 @@ def _find_diag_from_eye_mul(potential_mul_input):
eye_input = [
mul_input
for mul_input in inputs_to_mul
if mul_input.owner and isinstance(mul_input.owner.op, Eye)
if mul_input.owner
and (
isinstance(mul_input.owner.op, Eye)
or
# This whole condition checks if there is an Eye hiding inside a DimShuffle.
# This arises from batched elementwise multiplication between a tensor and an eye, e.g.:
# tensor(shape=(None, 3, 3)) * eye(3). This is still potentially valid for diag rewrites.
(
isinstance(mul_input.owner.op, DimShuffle)
and mul_input.owner.inputs[0].owner is not None
and isinstance(mul_input.owner.inputs[0].owner.op, Eye)
)
)
]

# Check if 1's are being put on the main diagonal only (k = 0)
if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
if not eye_input:
return None

# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
eye_input = eye_input[0]

# If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
if isinstance(eye_input.owner.op, DimShuffle):
inner_eye = eye_input.owner.inputs[0]
if not isinstance(inner_eye.owner.op, Eye):
return None
# Check if 1's are being put on the main diagonal only (k = 0)
# and if the identity matrix is degenerate (column or row matrix)
if getattr(
inner_eye.owner.inputs[-1], "data", -1
).item() != 0 or inner_eye.broadcastable[-2:] != (False, False):
return None

elif getattr(
eye_input.owner.inputs[-1], "data", -1
).item() != 0 or eye_input.broadcastable[-2:] != (False, False):
return None

# Get all non Eye inputs (scalars/matrices/vectors)
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
non_eye_inputs = list(set(inputs_to_mul) - {eye_input})
return eye_input, non_eye_inputs


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([det])
def rewrite_det_diag_from_eye_mul(fgraph, node):
def rewrite_det_diag_to_prod_diag(fgraph, node):
"""
This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements.
This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its
diagonal elements.

The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix.
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar,
vector or a matrix.

Parameters
----------
@@ -438,53 +470,46 @@ def rewrite_det_diag_from_eye_mul(fgraph, node):
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
potential_mul_input = node.inputs[0]
eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
if eye_non_eye_inputs is None:
inputs = node.inputs[0]

# Check for use of pt.diag first
if (
inputs.owner
and isinstance(inputs.owner.op, AllocDiag2)
and inputs.owner.op.offset == 0
):
diag_input = inputs.owner.inputs[0]
det_val = diag_input.prod(axis=-1)
return [det_val]

# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none = _find_diag_from_eye_mul(inputs)

if inputs_or_none is None:
return None
eye_input, non_eye_inputs = eye_non_eye_inputs

eye_input, non_eye_inputs = inputs_or_none

# Dealing with only one other input
if len(non_eye_inputs) != 1:
return None

useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]

# Checking if original x was scalar/vector/matrix
if useful_non_eye.type.broadcastable[-2:] == (True, True):
if non_eye_input.type.broadcastable[-2:] == (True, True):
# For scalar
det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
elif useful_non_eye.type.broadcastable[-2:] == (False, False):
det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0])
elif non_eye_input.type.broadcastable[-2:] == (False, False):
# For Matrix
det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
else:
# For vector
det_val = useful_non_eye.prod(axis=(-1, -2))
det_val = non_eye_input.prod(axis=(-1, -2))
det_val = det_val.astype(node.outputs[0].type.dtype)
return [det_val]


arange = ARange("int64")
det_diag_from_diag = PatternNodeRewriter(
(
det,
(
advanced_set_subtensor,
(alloc, 0, "sh1", "sh2"),
"x",
(arange, 0, "stop", 1),
(arange, 0, "stop", 1),
),
),
(prod, "x"),
name="det_diag_from_diag",
allow_multiple_clients=True,
)
register_canonicalize(det_diag_from_diag)
register_stabilize(det_diag_from_diag)
register_specialize(det_diag_from_diag)


@register_canonicalize
@register_stabilize
@register_specialize
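
An end-to-end sketch of the two paths the renamed rewrite now handles (hedged: names and shapes are illustrative, and the tolerances assume default float precision):

import numpy as np
import pytensor
import pytensor.tensor as pt

v = pt.vector("v")
x = pt.matrix("x")

# Path 1: determinant of an explicit pt.diag -> product of the diagonal vector
d1 = pt.linalg.det(pt.diag(v))
# Path 2: determinant of an elementwise multiply with eye -> product of x's diagonal
d2 = pt.linalg.det(x * pt.eye(7))

f = pytensor.function([v, x], [d1, d2], mode="FAST_RUN")
v_val = np.arange(1.0, 4.0).astype(v.dtype)
x_val = np.random.rand(7, 7).astype(x.dtype)
r1, r2 = f(v_val, x_val)
np.testing.assert_allclose(r1, np.prod(v_val), rtol=1e-4)
np.testing.assert_allclose(r2, np.linalg.det(np.eye(7) * x_val), rtol=1e-4)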
11 changes: 9 additions & 2 deletions tests/tensor/rewriting/test_linalg.py
@@ -403,13 +403,19 @@ def test_det_diag_from_eye_mul(shape):
# Initializing x based on scalar/vector/matrix
x = pt.tensor("x", shape=shape)
y = pt.eye(7) * x

# Calculating determinant value using pt.linalg.det
z_det = pt.linalg.det(y)

# REWRITE TEST
f_rewritten = function([x], z_det, mode="FAST_RUN")
with pytensor.config.change_flags(optimizer_verbose=True):
f_rewritten = function([x], z_det, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Det) for node in nodes)

assert not any(
isinstance(node.op, Det) or isinstance(getattr(node.op, "core_op", None), Det)
for node in nodes
)

# NUMERIC VALUE TEST
if len(shape) == 0:
Expand All @@ -418,6 +424,7 @@ def test_det_diag_from_eye_mul(shape):
x_test = np.random.rand(*shape).astype(config.floatX)
else:
x_test = np.random.rand(*shape).astype(config.floatX)

x_test_matrix = np.eye(7) * x_test
det_val = np.linalg.det(x_test_matrix)
rewritten_val = f_rewritten(x_test)
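
For context on the extra `core_op` check in the assertion above: when the input has a batch dimension, the determinant Op appears in the graph wrapped in `Blockwise`, so a plain `isinstance(node.op, Det)` scan can miss it. A hedged sketch (assuming `Blockwise` exposes the wrapped Op as `core_op`, which is what the test relies on):

import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise

x = pt.tensor("x", shape=(None, 7, 7))
z = pt.linalg.det(x)  # batched determinant

op = z.owner.op
print(type(op))  # expected: Blockwise, with op.core_op an instance of Det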